Simplify differential diffusion code.
This commit is contained in:
@@ -67,6 +67,9 @@ class ModelPatcher:
|
||||
def set_model_unet_function_wrapper(self, unet_wrapper_function):
|
||||
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
||||
|
||||
def set_model_denoise_mask_function(self, denoise_mask_function):
|
||||
self.model_options["denoise_mask_function"] = denoise_mask_function
|
||||
|
||||
def set_model_patch(self, patch, name):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" not in to:
|
||||
|
||||
@@ -272,13 +272,14 @@ class CFGNoisePredictor(torch.nn.Module):
|
||||
return self.apply_model(*args, **kwargs)
|
||||
|
||||
class KSamplerX0Inpaint(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
def __init__(self, model, sigmas):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.sigmas = sigmas
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
|
||||
if denoise_mask is not None:
|
||||
if "denoise_mask_function" in model_options:
|
||||
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask)
|
||||
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
|
||||
latent_mask = 1. - denoise_mask
|
||||
x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
|
||||
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
|
||||
@@ -528,7 +529,7 @@ class KSAMPLER(Sampler):
|
||||
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
extra_args["denoise_mask"] = denoise_mask
|
||||
model_k = KSamplerX0Inpaint(model_wrap)
|
||||
model_k = KSamplerX0Inpaint(model_wrap, sigmas)
|
||||
model_k.latent_image = latent_image
|
||||
if self.inpaint_options.get("random", False): #TODO: Should this be the default?
|
||||
generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
|
||||
|
||||
Reference in New Issue
Block a user