HunyuanVideo 1.5 (#10819)
* init * update * Update model.py * Update model.py * remove print * Fix text encoding * Prevent empty negative prompt Really doesn't work otherwise * fp16 works * I2V * Update model_base.py * Update nodes_hunyuan.py * Better latent rgb factors * Use the correct sigclip output... * Support HunyuanVideo1.5 SR model * whitespaces... * Proper latent channel count * SR model fixes This also still needs timesteps scheduling based on the noise scale, can be used with two samplers too already * vae_refiner: roll the convolution through temporal Work in progress. Roll the convolution through time using 2-latent-frame chunks and a FIFO queue for the convolution seams. * Support HunyuanVideo15 latent resampler * fix * Some cleanup Co-Authored-By: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> * Proper hyvid15 I2V channels Co-Authored-By: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> * Fix TokenRefiner for fp16 Otherwise x.sum has infs, just in case only casting if input is fp16, I don't know if necessary. * Bugfix for the HunyuanVideo15 SR model * vae_refiner: roll the convolution through temporal II Roll the convolution through time using 2-latent-frame chunks and a FIFO queue for the convolution seams. Added support for encoder, lowered to 1 latent frame to save more VRAM, made work for Hunyuan Image 3.0 (as code shared). Fixed names, cleaned up code. * Allow any number of input frames in VAE. * Better VAE encode mem estimation. * Lowvram fix. * Fix hunyuan image 2.1 refiner. * Fix mistake. * Name changes. * Rename. * Whitespace. * Fix. * Fix. --------- Co-authored-by: kijai <40791699+kijai@users.noreply.github.com> Co-authored-by: Rattus <rattus128@gmail.com>
This commit is contained in:
@@ -6,7 +6,6 @@ import comfy.ldm.flux.layers
|
||||
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from einops import repeat
|
||||
|
||||
@@ -42,6 +41,8 @@ class HunyuanVideoParams:
|
||||
guidance_embed: bool
|
||||
byt5: bool
|
||||
meanflow: bool
|
||||
use_cond_type_embedding: bool
|
||||
vision_in_dim: int
|
||||
|
||||
|
||||
class SelfAttentionRef(nn.Module):
|
||||
@@ -157,7 +158,10 @@ class TokenRefiner(nn.Module):
|
||||
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
|
||||
# m = mask.float().unsqueeze(-1)
|
||||
# c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
|
||||
c = x.sum(dim=1) / x.shape[1]
|
||||
if x.dtype == torch.float16:
|
||||
c = x.float().sum(dim=1) / x.shape[1]
|
||||
else:
|
||||
c = x.sum(dim=1) / x.shape[1]
|
||||
|
||||
c = t + self.c_embedder(c.to(x.dtype))
|
||||
x = self.input_embedder(x)
|
||||
@@ -196,11 +200,15 @@ class HunyuanVideo(nn.Module):
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
params = HunyuanVideoParams(**kwargs)
|
||||
self.params = params
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = params.out_channels
|
||||
self.use_cond_type_embedding = params.use_cond_type_embedding
|
||||
self.vision_in_dim = params.vision_in_dim
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
@@ -266,6 +274,18 @@ class HunyuanVideo(nn.Module):
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
# HunyuanVideo 1.5 specific modules
|
||||
if self.vision_in_dim is not None:
|
||||
from comfy.ldm.wan.model import MLPProj
|
||||
self.vision_in = MLPProj(in_dim=self.vision_in_dim, out_dim=self.hidden_size, operation_settings=operation_settings)
|
||||
else:
|
||||
self.vision_in = None
|
||||
if self.use_cond_type_embedding:
|
||||
# 0: text_encoder feature 1: byt5 feature 2: vision_encoder feature
|
||||
self.cond_type_embedding = nn.Embedding(3, self.hidden_size)
|
||||
else:
|
||||
self.cond_type_embedding = None
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
img: Tensor,
|
||||
@@ -276,6 +296,7 @@ class HunyuanVideo(nn.Module):
|
||||
timesteps: Tensor,
|
||||
y: Tensor = None,
|
||||
txt_byt5=None,
|
||||
clip_fea=None,
|
||||
guidance: Tensor = None,
|
||||
guiding_frame_index=None,
|
||||
ref_latent=None,
|
||||
@@ -331,12 +352,31 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)
|
||||
|
||||
if self.cond_type_embedding is not None:
|
||||
self.cond_type_embedding.to(txt.device)
|
||||
cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long))
|
||||
txt = txt + cond_emb.to(txt.dtype)
|
||||
|
||||
if self.byt5_in is not None and txt_byt5 is not None:
|
||||
txt_byt5 = self.byt5_in(txt_byt5)
|
||||
if self.cond_type_embedding is not None:
|
||||
cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long))
|
||||
txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype)
|
||||
txt = torch.cat((txt_byt5, txt), dim=1) # byt5 first for HunyuanVideo1.5
|
||||
else:
|
||||
txt = torch.cat((txt, txt_byt5), dim=1)
|
||||
txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
txt = torch.cat((txt, txt_byt5), dim=1)
|
||||
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
|
||||
|
||||
if clip_fea is not None:
|
||||
txt_vision_states = self.vision_in(clip_fea)
|
||||
if self.cond_type_embedding is not None:
|
||||
cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device))
|
||||
txt_vision_states = txt_vision_states + cond_emb
|
||||
txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1)
|
||||
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
|
||||
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
@@ -430,14 +470,14 @@ class HunyuanVideo(nn.Module):
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
return repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
def forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
|
||||
).execute(x, timestep, context, y, txt_byt5, clip_fea, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
def _forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
bs = x.shape[0]
|
||||
if len(self.patch_size) == 3:
|
||||
img_ids = self.img_ids(x)
|
||||
@@ -445,5 +485,5 @@ class HunyuanVideo(nn.Module):
|
||||
else:
|
||||
img_ids = self.img_ids_2d(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, clip_fea, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user