Add support for unCLIP SD2.x models.
See _for_testing/unclip in the UI for the new nodes. unCLIPCheckpointLoader is used to load them. unCLIPConditioning is used to add the image cond and takes as input a CLIPVisionEncode output which has been moved to the conditioning section.
This commit is contained in:
@@ -307,7 +307,16 @@ def model_wrapper(
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t_continuous] * 2)
|
||||
c_in = torch.cat([unconditional_condition, condition])
|
||||
if isinstance(condition, dict):
|
||||
assert isinstance(unconditional_condition, dict)
|
||||
c_in = dict()
|
||||
for k in condition:
|
||||
if isinstance(condition[k], list):
|
||||
c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))]
|
||||
else:
|
||||
c_in[k] = torch.cat([unconditional_condition[k], condition[k]])
|
||||
else:
|
||||
c_in = torch.cat([unconditional_condition, condition])
|
||||
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
|
||||
return noise_uncond + guidance_scale * (noise - noise_uncond)
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import torch
|
||||
|
||||
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
|
||||
|
||||
|
||||
MODEL_TYPES = {
|
||||
"eps": "noise",
|
||||
"v": "v"
|
||||
@@ -51,12 +50,20 @@ class DPMSolverSampler(object):
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list): ctmp = ctmp[0]
|
||||
if isinstance(ctmp, torch.Tensor):
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
elif isinstance(conditioning, list):
|
||||
for ctmp in conditioning:
|
||||
if ctmp.shape[0] != batch_size:
|
||||
print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
if isinstance(conditioning, torch.Tensor):
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
@@ -83,6 +90,7 @@ class DPMSolverSampler(object):
|
||||
)
|
||||
|
||||
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
|
||||
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
|
||||
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2,
|
||||
lower_order_final=True)
|
||||
|
||||
return x.to(device), None
|
||||
return x.to(device), None
|
||||
|
||||
Reference in New Issue
Block a user