Add a way to pass options to the transformers blocks.

This commit is contained in:
comfyanonymous
2023-03-31 13:04:39 -04:00
parent 04b42bad87
commit 61ec3c9d5d
5 changed files with 33 additions and 29 deletions

View File

@@ -78,7 +78,7 @@ class DDIMSampler(object):
dynamic_threshold=None,
ucg_schedule=None,
denoise_function=None,
cond_concat=None,
extra_args=None,
to_zero=True,
end_step=None,
**kwargs
@@ -101,7 +101,7 @@ class DDIMSampler(object):
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule,
denoise_function=denoise_function,
cond_concat=cond_concat,
extra_args=extra_args,
to_zero=to_zero,
end_step=end_step
)
@@ -174,7 +174,7 @@ class DDIMSampler(object):
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule,
denoise_function=None,
cond_concat=None
extra_args=None
)
return samples, intermediates
@@ -185,7 +185,7 @@ class DDIMSampler(object):
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
ucg_schedule=None, denoise_function=None, cond_concat=None, to_zero=True, end_step=None):
ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None):
device = self.model.betas.device
b = shape[0]
if x_T is None:
@@ -225,7 +225,7 @@ class DDIMSampler(object):
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, cond_concat=cond_concat)
dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, extra_args=extra_args)
img, pred_x0 = outs
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
@@ -249,11 +249,11 @@ class DDIMSampler(object):
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,
dynamic_threshold=None, denoise_function=None, cond_concat=None):
dynamic_threshold=None, denoise_function=None, extra_args=None):
b, *_, device = *x.shape, x.device
if denoise_function is not None:
model_output = denoise_function(self.model.apply_model, x, t, unconditional_conditioning, c, unconditional_guidance_scale, cond_concat)
model_output = denoise_function(self.model.apply_model, x, t, **extra_args)
elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output = self.model.apply_model(x, t, c)
else: