feat: SUPIR model support (CORE-17) (#13250)
This commit is contained in:
@@ -7,7 +7,10 @@ import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.latent_formats
|
||||
import comfy.ldm.lumina.controlnet
|
||||
import comfy.ldm.supir.supir_modules
|
||||
from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel
|
||||
from comfy_api.latest import io
|
||||
from comfy.ldm.supir.supir_patch import SUPIRPatch
|
||||
|
||||
|
||||
class BlockWiseControlBlock(torch.nn.Module):
|
||||
@@ -266,6 +269,27 @@ class ModelPatchLoader:
|
||||
out_dim=sd["audio_proj.norm.weight"].shape[0],
|
||||
device=comfy.model_management.unet_offload_device(),
|
||||
operations=comfy.ops.manual_cast)
|
||||
elif 'model.control_model.input_hint_block.0.weight' in sd or 'control_model.input_hint_block.0.weight' in sd:
|
||||
prefix_replace = {}
|
||||
if 'model.control_model.input_hint_block.0.weight' in sd:
|
||||
prefix_replace["model.control_model."] = "control_model."
|
||||
prefix_replace["model.diffusion_model.project_modules."] = "project_modules."
|
||||
else:
|
||||
prefix_replace["control_model."] = "control_model."
|
||||
prefix_replace["project_modules."] = "project_modules."
|
||||
|
||||
# Extract denoise_encoder weights before filter_keys discards them
|
||||
de_prefix = "first_stage_model.denoise_encoder."
|
||||
denoise_encoder_sd = {}
|
||||
for k in list(sd.keys()):
|
||||
if k.startswith(de_prefix):
|
||||
denoise_encoder_sd[k[len(de_prefix):]] = sd.pop(k)
|
||||
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, prefix_replace, filter_keys=True)
|
||||
sd.pop("control_model.mask_LQ", None)
|
||||
model = comfy.ldm.supir.supir_modules.SUPIR(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
||||
if denoise_encoder_sd:
|
||||
model.denoise_encoder_sd = denoise_encoder_sd
|
||||
|
||||
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||
model.load_state_dict(sd, assign=model_patcher.is_dynamic())
|
||||
@@ -565,9 +589,89 @@ class MultiTalkModelPatch(torch.nn.Module):
|
||||
)
|
||||
|
||||
|
||||
class SUPIRApply(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="SUPIRApply",
|
||||
category="model_patches/supir",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.ModelPatch.Input("model_patch"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Image.Input("image"),
|
||||
io.Float.Input("strength_start", default=1.0, min=0.0, max=10.0, step=0.01,
|
||||
tooltip="Control strength at the start of sampling (high sigma)."),
|
||||
io.Float.Input("strength_end", default=1.0, min=0.0, max=10.0, step=0.01,
|
||||
tooltip="Control strength at the end of sampling (low sigma). Linearly interpolated from start."),
|
||||
io.Float.Input("restore_cfg", default=4.0, min=0.0, max=20.0, step=0.1, advanced=True,
|
||||
tooltip="Pulls denoised output toward the input latent. Higher = stronger fidelity to input. 0 to disable."),
|
||||
io.Float.Input("restore_cfg_s_tmin", default=0.05, min=0.0, max=1.0, step=0.01, advanced=True,
|
||||
tooltip="Sigma threshold below which restore_cfg is disabled."),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _encode_with_denoise_encoder(cls, vae, model_patch, image):
|
||||
"""Encode using denoise_encoder weights from SUPIR checkpoint if available."""
|
||||
denoise_sd = getattr(model_patch.model, 'denoise_encoder_sd', None)
|
||||
if not denoise_sd:
|
||||
return vae.encode(image)
|
||||
|
||||
# Clone VAE patcher, apply denoise_encoder weights to clone, encode
|
||||
orig_patcher = vae.patcher
|
||||
vae.patcher = orig_patcher.clone()
|
||||
patches = {f"encoder.{k}": (v,) for k, v in denoise_sd.items()}
|
||||
vae.patcher.add_patches(patches, strength_patch=1.0, strength_model=0.0)
|
||||
try:
|
||||
return vae.encode(image)
|
||||
finally:
|
||||
vae.patcher = orig_patcher
|
||||
|
||||
@classmethod
|
||||
def execute(cls, *, model: io.Model.Type, model_patch: io.ModelPatch.Type, vae: io.Vae.Type, image: io.Image.Type,
|
||||
strength_start: float, strength_end: float, restore_cfg: float, restore_cfg_s_tmin: float) -> io.NodeOutput:
|
||||
model_patched = model.clone()
|
||||
hint_latent = model.get_model_object("latent_format").process_in(
|
||||
cls._encode_with_denoise_encoder(vae, model_patch, image[:, :, :, :3]))
|
||||
patch = SUPIRPatch(model_patch, model_patch.model.project_modules, hint_latent, strength_start, strength_end)
|
||||
patch.register(model_patched)
|
||||
|
||||
if restore_cfg > 0.0:
|
||||
# Round-trip to match original pipeline: decode hint, re-encode with regular VAE
|
||||
latent_format = model.get_model_object("latent_format")
|
||||
decoded = vae.decode(latent_format.process_out(hint_latent))
|
||||
x_center = latent_format.process_in(vae.encode(decoded[:, :, :, :3]))
|
||||
sigma_max = 14.6146
|
||||
|
||||
def restore_cfg_function(args):
|
||||
denoised = args["denoised"]
|
||||
sigma = args["sigma"]
|
||||
if sigma.dim() > 0:
|
||||
s = sigma[0].item()
|
||||
else:
|
||||
s = sigma.item()
|
||||
if s > restore_cfg_s_tmin:
|
||||
ref = x_center.to(device=denoised.device, dtype=denoised.dtype)
|
||||
b = denoised.shape[0]
|
||||
if ref.shape[0] != b:
|
||||
ref = ref.expand(b, -1, -1, -1) if ref.shape[0] == 1 else ref.repeat((b + ref.shape[0] - 1) // ref.shape[0], 1, 1, 1)[:b]
|
||||
sigma_val = sigma.view(-1, 1, 1, 1) if sigma.dim() > 0 else sigma
|
||||
d_center = denoised - ref
|
||||
denoised = denoised - d_center * ((sigma_val / sigma_max) ** restore_cfg)
|
||||
return denoised
|
||||
|
||||
model_patched.set_model_sampler_post_cfg_function(restore_cfg_function)
|
||||
|
||||
return io.NodeOutput(model_patched)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelPatchLoader": ModelPatchLoader,
|
||||
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
|
||||
"ZImageFunControlnet": ZImageFunControlnet,
|
||||
"USOStyleReference": USOStyleReference,
|
||||
"SUPIRApply": SUPIRApply,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user