Support hunyuan image distilled model. (#9807)

This commit is contained in:
comfyanonymous
2025-09-10 20:17:34 -07:00
committed by GitHub
parent 72212fef66
commit e01e99d075
2 changed files with 24 additions and 2 deletions

View File

@@ -41,6 +41,7 @@ class HunyuanVideoParams:
qkv_bias: bool
guidance_embed: bool
byt5: bool
meanflow: bool
class SelfAttentionRef(nn.Module):
@@ -256,6 +257,11 @@ class HunyuanVideo(nn.Module):
else:
self.byt5_in = None
if params.meanflow:
self.time_r_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
else:
self.time_r_in = None
if final_layer:
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
@@ -282,6 +288,14 @@ class HunyuanVideo(nn.Module):
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
if self.time_r_in is not None:
w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved
if len(w) > 0:
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
vec = (vec + vec_r) / 2
if ref_latent is not None:
ref_latent_ids = self.img_ids(ref_latent)
ref_latent = self.img_in(ref_latent)