feat: LTX2: Support reference audio (ID-LoRA) (#13111)
This commit is contained in:
@@ -3,6 +3,7 @@ import node_helpers
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.model_sampling
|
||||
import comfy.samplers
|
||||
import comfy.utils
|
||||
import math
|
||||
import numpy as np
|
||||
@@ -682,6 +683,84 @@ class LTXVSeparateAVLatent(io.ComfyNode):
|
||||
return io.NodeOutput(video_latent, audio_latent)
|
||||
|
||||
|
||||
class LTXVReferenceAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="LTXVReferenceAudio",
|
||||
display_name="LTXV Reference Audio (ID-LoRA)",
|
||||
category="conditioning/audio",
|
||||
description="Set reference audio for ID-LoRA speaker identity transfer. Encodes a reference audio clip into the conditioning and optionally patches the model with identity guidance (extra forward pass without reference, amplifying the speaker identity effect).",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Audio.Input("reference_audio", tooltip="Reference audio clip whose speaker identity to transfer. ~5 seconds recommended (training duration). Shorter or longer clips may degrade voice identity transfer."),
|
||||
io.Vae.Input(id="audio_vae", display_name="Audio VAE", tooltip="LTXV Audio VAE for encoding."),
|
||||
io.Float.Input("identity_guidance_scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01, tooltip="Strength of identity guidance. Runs an extra forward pass without reference each step to amplify speaker identity. Set to 0 to disable (no extra pass)."),
|
||||
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="Start of the sigma range where identity guidance is active."),
|
||||
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="End of the sigma range where identity guidance is active."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput:
|
||||
# Encode reference audio to latents and patchify
|
||||
audio_latents = audio_vae.encode(reference_audio)
|
||||
b, c, t, f = audio_latents.shape
|
||||
ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f)
|
||||
ref_audio = {"tokens": ref_tokens}
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"ref_audio": ref_audio})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"ref_audio": ref_audio})
|
||||
|
||||
# Patch model with identity guidance
|
||||
m = model.clone()
|
||||
scale = identity_guidance_scale
|
||||
model_sampling = m.get_model_object("model_sampling")
|
||||
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
||||
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
||||
|
||||
def post_cfg_function(args):
|
||||
if scale == 0:
|
||||
return args["denoised"]
|
||||
|
||||
sigma = args["sigma"]
|
||||
sigma_ = sigma[0].item()
|
||||
if sigma_ > sigma_start or sigma_ < sigma_end:
|
||||
return args["denoised"]
|
||||
|
||||
cond_pred = args["cond_denoised"]
|
||||
cond = args["cond"]
|
||||
cfg_result = args["denoised"]
|
||||
model_options = args["model_options"].copy()
|
||||
x = args["input"]
|
||||
|
||||
# Strip ref_audio from conditioning for the no-reference pass
|
||||
noref_cond = []
|
||||
for entry in cond:
|
||||
new_entry = entry.copy()
|
||||
mc = new_entry.get("model_conds", {}).copy()
|
||||
mc.pop("ref_audio", None)
|
||||
new_entry["model_conds"] = mc
|
||||
noref_cond.append(new_entry)
|
||||
|
||||
(pred_noref,) = comfy.samplers.calc_cond_batch(
|
||||
args["model"], [noref_cond], x, sigma, model_options
|
||||
)
|
||||
|
||||
return cfg_result + (cond_pred - pred_noref) * scale
|
||||
|
||||
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
||||
|
||||
return io.NodeOutput(m, positive, negative)
|
||||
|
||||
|
||||
class LtxvExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@@ -697,6 +776,7 @@ class LtxvExtension(ComfyExtension):
|
||||
LTXVCropGuides,
|
||||
LTXVConcatAVLatent,
|
||||
LTXVSeparateAVLatent,
|
||||
LTXVReferenceAudio,
|
||||
]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user