Support hunyuan image distilled model. (#9807)

This commit is contained in:
comfyanonymous
2025-09-10 20:17:34 -07:00
committed by GitHub
parent 72212fef66
commit e01e99d075
2 changed files with 24 additions and 2 deletions

View File

@@ -142,12 +142,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels
dit_config["patch_size"] = list(in_w.shape[2:])
dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"])
if '{}vector_in.in_layer.weight'.format(key_prefix) in state_dict:
if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys):
dit_config["vec_in_dim"] = 768
dit_config["axes_dim"] = [16, 56, 56]
else:
dit_config["vec_in_dim"] = None
if len(dit_config["patch_size"]) == 2:
dit_config["axes_dim"] = [64, 64]
else:
dit_config["axes_dim"] = [16, 56, 56]
if any(s.startswith('{}time_r_in.'.format(key_prefix)) for s in state_dict_keys):
dit_config["meanflow"] = True
else:
dit_config["meanflow"] = False
dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1]
dit_config["hidden_size"] = in_w.shape[0]