Basic WIP support for the wan animate model. (#9939)

This commit is contained in:
comfyanonymous
2025-09-19 00:07:17 -07:00
committed by GitHub
parent 711bcf33ee
commit dc95b6acc0
5 changed files with 666 additions and 1 deletions

View File

@@ -1108,6 +1108,89 @@ class WanHuMoImageToVideo(io.ComfyNode):
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent)
class WanAnimateToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanAnimateToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
io.Image.Input("reference_image", optional=True),
io.Image.Input("face_video", optional=True),
io.Image.Input("pose_video", optional=True),
io.Int.Input("continue_motion_max_frames", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Image.Input("continue_motion", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
io.Int.Output(display_name="trim_latent"),
],
is_experimental=True,
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, continue_motion_max_frames, reference_image=None, clip_vision_output=None, face_video=None, pose_video=None, continue_motion=None) -> io.NodeOutput:
latent_length = ((length - 1) // 4) + 1
latent_width = width // 8
latent_height = height // 8
trim_latent = 0
if reference_image is None:
reference_image = torch.zeros((1, height, width, 3))
image = comfy.utils.common_upscale(reference_image[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
concat_latent_image = vae.encode(image[:, :, :, :3])
mask = torch.zeros((1, 1, concat_latent_image.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=concat_latent_image.device, dtype=concat_latent_image.dtype)
trim_latent += concat_latent_image.shape[2]
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
if face_video is not None:
face_video = comfy.utils.common_upscale(face_video[:length].movedim(-1, 1), 512, 512, "area", "center") * 2.0 - 1.0
face_video = face_video.movedim(0, 1).unsqueeze(0)
positive = node_helpers.conditioning_set_values(positive, {"face_video_pixels": face_video})
negative = node_helpers.conditioning_set_values(negative, {"face_video_pixels": face_video * 0.0 - 1.0})
if pose_video is not None:
pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
pose_video_latent = vae.encode(pose_video[:, :, :, :3])
positive = node_helpers.conditioning_set_values(positive, {"pose_video_latent": pose_video_latent})
negative = node_helpers.conditioning_set_values(negative, {"pose_video_latent": pose_video_latent})
if continue_motion is None:
image = torch.ones((length, height, width, 3)) * 0.5
else:
continue_motion = continue_motion[-continue_motion_max_frames:]
continue_motion = comfy.utils.common_upscale(continue_motion[-length:].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
image = torch.ones((length, height, width, continue_motion.shape[-1]), device=continue_motion.device, dtype=continue_motion.dtype) * 0.5
image[:continue_motion.shape[0]] = continue_motion
concat_latent_image = torch.cat((concat_latent_image, vae.encode(image[:, :, :, :3])), dim=2)
mask_refmotion = torch.ones((1, 1, latent_length, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=mask.device, dtype=mask.dtype)
if continue_motion is not None:
mask_refmotion[:, :, :((continue_motion.shape[0] - 1) // 4) + 1] = 0.0
mask = torch.cat((mask, mask_refmotion), dim=2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
latent = torch.zeros([batch_size, 16, latent_length + trim_latent, latent_height, latent_width], device=comfy.model_management.intermediate_device())
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent, trim_latent)
class Wan22ImageToVideoLatent(io.ComfyNode):
@classmethod
def define_schema(cls):
@@ -1169,6 +1252,7 @@ class WanExtension(ComfyExtension):
WanSoundImageToVideo,
WanSoundImageToVideoExtend,
WanHuMoImageToVideo,
WanAnimateToVideo,
Wan22ImageToVideoLatent,
]