Support Hunyuan image 2.1 regular model. (#9792)

Author: comfyanonymous
Date: 2025-09-09 23:05:07 -07:00
Committed by: GitHub
Parent: 5c33872e2f
Commit: 85e34643f8
14 changed files with 1906 additions and 30 deletions


@@ -40,6 +40,7 @@ class HunyuanVideoParams:
     patch_size: list
     qkv_bias: bool
     guidance_embed: bool
+    byt5: bool


 class SelfAttentionRef(nn.Module):
@@ -161,6 +162,30 @@ class TokenRefiner(nn.Module):
         x = self.individual_token_refiner(x, c, mask)
         return x


+class ByT5Mapper(nn.Module):
+    def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_res=False, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.layernorm = operations.LayerNorm(in_dim, dtype=dtype, device=device)
+        self.fc1 = operations.Linear(in_dim, hidden_dim, dtype=dtype, device=device)
+        self.fc2 = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
+        self.fc3 = operations.Linear(out_dim, out_dim1, dtype=dtype, device=device)
+        self.use_res = use_res
+        self.act_fn = nn.GELU()
+
+    def forward(self, x):
+        if self.use_res:
+            res = x
+        x = self.layernorm(x)
+        x = self.fc1(x)
+        x = self.act_fn(x)
+        x = self.fc2(x)
+        x2 = self.act_fn(x)
+        x2 = self.fc3(x2)
+        if self.use_res:
+            x2 = x2 + res
+        return x2
+
+
 class HunyuanVideo(nn.Module):
     """
     Transformer model for flow matching on sequences.
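Note: a minimal standalone exercise of the new mapper, using plain torch.nn as a stand-in for ComfyUI's operations namespace (both expose LayerNorm/Linear) and assuming ByT5Mapper is importable; the out_dim1 of 3072 is only an illustrative hidden size, not a value taken from this hunk:

import torch
import torch.nn as nn

# Shape contract: (B, L, in_dim) -> (B, L, out_dim1). With use_res=True the input
# is added back onto the fc3 output, which only broadcasts when in_dim == out_dim1.
mapper = ByT5Mapper(in_dim=1472, out_dim=2048, hidden_dim=2048,
                    out_dim1=3072, use_res=False, operations=nn)
glyph_tokens = torch.randn(1, 32, 1472)  # hypothetical ByT5 feature sequence
print(mapper(glyph_tokens).shape)        # torch.Size([1, 32, 3072])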
@@ -185,9 +210,13 @@ class HunyuanVideo(nn.Module):
         self.num_heads = params.num_heads
         self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
-        self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
+        self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=len(self.patch_size) == 3, dtype=dtype, device=device, operations=operations)
         self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
-        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+        if params.vec_in_dim is not None:
+            self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+        else:
+            self.vector_in = None
         self.guidance_in = (
             MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
         )
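Note: with conv3d now keyed off len(self.patch_size), a two-element patch size routes PatchEmbed onto a Conv2d. A rough sketch of what that 2D projection does to an image latent; the channel count (16) and hidden size (3072) are assumptions, not values from this diff:

import torch
import torch.nn as nn

patch_size = (2, 2)                           # illustrative 2D patch size
proj = nn.Conv2d(16, 3072, kernel_size=patch_size, stride=patch_size)
x = torch.randn(1, 16, 64, 64)                # (B, C, H, W) image latent
tokens = proj(x).flatten(2).transpose(1, 2)   # non-overlapping patches -> tokens
print(tokens.shape)                           # torch.Size([1, 1024, 3072]), a 32x32 grid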
@@ -215,6 +244,18 @@ class HunyuanVideo(nn.Module):
             ]
         )

+        if params.byt5:
+            self.byt5_in = ByT5Mapper(
+                in_dim=1472,
+                out_dim=2048,
+                hidden_dim=2048,
+                out_dim1=self.hidden_size,
+                use_res=False,
+                dtype=dtype, device=device, operations=operations
+            )
+        else:
+            self.byt5_in = None
+
         if final_layer:
             self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
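Note: in_dim=1472 matches the encoder width of ByT5-small, which is presumably where the glyph stream originates (the text-encoder side lives in other files of this commit). A hedged sketch of producing such features with Hugging Face transformers, not how ComfyUI itself loads the model:

import torch
from transformers import AutoTokenizer, T5EncoderModel

tok = AutoTokenizer.from_pretrained("google/byt5-small")
enc = T5EncoderModel.from_pretrained("google/byt5-small")
ids = tok(["some text to render"], return_tensors="pt").input_ids
with torch.no_grad():
    feats = enc(input_ids=ids).last_hidden_state
print(feats.shape)                            # (1, seq_len, 1472) == byt5_in's in_dim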
@@ -226,7 +267,8 @@ class HunyuanVideo(nn.Module):
         txt_ids: Tensor,
         txt_mask: Tensor,
         timesteps: Tensor,
-        y: Tensor,
+        y: Tensor = None,
+        txt_byt5=None,
         guidance: Tensor = None,
         guiding_frame_index=None,
         ref_latent=None,
@@ -250,13 +292,17 @@ class HunyuanVideo(nn.Module):
         if guiding_frame_index is not None:
             token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
-            vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
-            vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
+            if self.vector_in is not None:
+                vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
+                vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
+            else:
+                vec = torch.cat([(token_replace_vec).unsqueeze(1), (vec).unsqueeze(1)], dim=1)
             frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
             modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
             modulation_dims_txt = [(0, None, 1)]
         else:
-            vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+            if self.vector_in is not None:
+                vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
             modulation_dims = None
             modulation_dims_txt = None
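Note: both branches now tolerate a model without vector_in; the image checkpoint carries no pooled text vector, so y may be None and the modulation vector falls back to the timestep embedding alone. The new guard in miniature, with assumed sizes:

import torch

vec = torch.randn(2, 3072)     # time_in(timestep_embedding(timesteps, 256)), width assumed
vector_in, y = None, None      # Hunyuan image 2.1 style: no pooled conditioning
if vector_in is not None:      # the same None check the diff adds
    vec = vec + vector_in(y[:, :768])   # 768 is a hypothetical vec_in_dim
print(vec.shape)               # torch.Size([2, 3072]) either way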
@@ -269,6 +315,12 @@ class HunyuanVideo(nn.Module):
         txt = self.txt_in(txt, timesteps, txt_mask)

+        if self.byt5_in is not None and txt_byt5 is not None:
+            txt_byt5 = self.byt5_in(txt_byt5)
+            txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
+            txt = torch.cat((txt, txt_byt5), dim=1)
+            txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
+
         ids = torch.cat((img_ids, txt_ids), dim=1)
         pe = self.pe_embedder(ids)
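Note: the mapped glyph tokens are simply appended to the refined text stream, with all-zero positional ids so they share the "text" position under RoPE. A shape-only sketch (all widths hypothetical):

import torch

txt = torch.randn(1, 77, 3072)        # refined text tokens (B, L, hidden)
txt_ids = torch.zeros(1, 77, 2)       # 2-component ids on the image-model path
txt_byt5 = torch.randn(1, 32, 3072)   # byt5_in output, already at hidden size
txt_byt5_ids = torch.zeros((1, txt_byt5.shape[1], txt_ids.shape[-1]))
txt = torch.cat((txt, txt_byt5), dim=1)              # (1, 109, 3072)
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)  # (1, 109, 2)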
@@ -328,12 +380,16 @@ class HunyuanVideo(nn.Module):
         img = self.final_layer(img, vec, modulation_dims=modulation_dims)  # (N, T, patch_size ** 2 * out_channels)

-        shape = initial_shape[-3:]
+        shape = initial_shape[-len(self.patch_size):]
         for i in range(len(shape)):
             shape[i] = shape[i] // self.patch_size[i]
         img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
-        img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
-        img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
+        if img.ndim == 8:
+            img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
+            img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
+        else:
+            img = img.permute(0, 3, 1, 4, 2, 5)
+            img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3])
         return img

     def img_ids(self, x):
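Note: the reshape yields 8 dims on the video path ([B] + 3 grid dims + [C] + 3 patch dims) and 6 on the image path, which is what the ndim check keys on. The new 2D unpatchify branch in isolation, with illustrative sizes:

import torch

B, C, H, W, ph, pw = 1, 16, 64, 64, 2, 2
img = torch.randn(B, (H // ph) * (W // pw), C * ph * pw)  # final_layer output
img = img.reshape(B, H // ph, W // pw, C, ph, pw)
assert img.ndim == 6                    # 2D path; the 3D path gives ndim == 8
img = img.permute(0, 3, 1, 4, 2, 5)     # (B, C, H//ph, ph, W//pw, pw)
print(img.reshape(B, C, H, W).shape)    # torch.Size([1, 16, 64, 64])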
@@ -348,16 +404,30 @@ class HunyuanVideo(nn.Module):
         img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
         return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

-    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
+    def img_ids_2d(self, x):
+        bs, c, h, w = x.shape
+        patch_size = self.patch_size
+        h_len = ((h + (patch_size[0] // 2)) // patch_size[0])
+        w_len = ((w + (patch_size[1] // 2)) // patch_size[1])
+        img_ids = torch.zeros((h_len, w_len, 2), device=x.device, dtype=x.dtype)
+        img_ids[:, :, 0] = img_ids[:, :, 0] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+        return repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+    def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
             self._forward,
             self,
             comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
-        ).execute(x, timestep, context, y, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
+        ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)

-    def _forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
-        bs, c, t, h, w = x.shape
-        img_ids = self.img_ids(x)
-        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
+    def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
+        bs = x.shape[0]
+        if len(self.patch_size) == 3:
+            img_ids = self.img_ids(x)
+            txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
+        else:
+            img_ids = self.img_ids_2d(x)
+            txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
+        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
         return out
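Note: tying it together, _forward now dispatches on patch_size length instead of assuming a 5D latent: a 4D (B, C, H, W) input gets 2-component rope ids from img_ids_2d plus matching 2-component text ids before both are concatenated for pe_embedder. A shape trace under assumed sizes:

import torch

x = torch.randn(1, 16, 64, 64)          # hypothetical image latent, len(patch_size) == 2
h_len, w_len = 64 // 2, 64 // 2         # token grid from a (2, 2) patch size
img_ids = torch.zeros(1, h_len * w_len, 2)
txt_ids = torch.zeros(1, 77, 2)         # context length 77 assumed
ids = torch.cat((img_ids, txt_ids), dim=1)
print(ids.shape)                        # torch.Size([1, 1101, 2]) fed to pe_embedder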