Use torch RMSNorm for flux models and refactor hunyuan video code. (#12432)

This commit is contained in:
comfyanonymous
2026-02-13 12:35:13 -08:00
committed by GitHub
parent 8902907d7a
commit e1add563f9
10 changed files with 74 additions and 69 deletions
+5 -6
View File
@@ -241,7 +241,6 @@ class HunyuanVideo(nn.Module):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
flipped_img_txt=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@@ -378,14 +377,14 @@ class HunyuanVideo(nn.Module):
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
ids = torch.cat((img_ids, txt_ids), dim=1)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
img_len = img.shape[1]
if txt_mask is not None:
attn_mask_len = img_len + txt.shape[1]
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
attn_mask[:, 0, img_len:] = txt_mask
attn_mask[:, 0, :txt.shape[1]] = txt_mask
else:
attn_mask = None
@@ -413,7 +412,7 @@ class HunyuanVideo(nn.Module):
if add is not None:
img += add
img = torch.cat((img, txt), 1)
img = torch.cat((txt, img), 1)
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
@@ -435,9 +434,9 @@ class HunyuanVideo(nn.Module):
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, : img_len] += add
img[:, txt.shape[1]: img_len + txt.shape[1]] += add
img = img[:, : img_len]
img = img[:, txt.shape[1]: img_len + txt.shape[1]]
if ref_latent is not None:
img = img[:, ref_latent.shape[1]:]