Implement hunyuan image refiner model. (#9817)

This commit is contained in:
comfyanonymous
2025-09-11 21:43:20 -07:00
committed by GitHub
parent 18de0b2830
commit 33bd9ed9cb
9 changed files with 367 additions and 12 deletions

View File

@@ -278,6 +278,7 @@ class HunyuanVideo(nn.Module):
guidance: Tensor = None,
guiding_frame_index=None,
ref_latent=None,
disable_time_r=False,
control=None,
transformer_options={},
) -> Tensor:
@@ -288,7 +289,7 @@ class HunyuanVideo(nn.Module):
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
if self.time_r_in is not None:
if (self.time_r_in is not None) and (not disable_time_r):
w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved
if len(w) > 0:
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
@@ -428,14 +429,14 @@ class HunyuanVideo(nn.Module):
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
return repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
bs = x.shape[0]
if len(self.patch_size) == 3:
img_ids = self.img_ids(x)
@@ -443,5 +444,5 @@ class HunyuanVideo(nn.Module):
else:
img_ids = self.img_ids_2d(x)
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
return out