feat: LTX2: Support reference audio (ID-LoRA) (#13111)

This commit is contained in:
Jukka Seppänen
2026-03-24 00:22:24 +02:00
committed by GitHub
parent da6edb5a4e
commit e87858e974
3 changed files with 126 additions and 0 deletions
+80
View File
@@ -3,6 +3,7 @@ import node_helpers
import torch
import comfy.model_management
import comfy.model_sampling
import comfy.samplers
import comfy.utils
import math
import numpy as np
@@ -682,6 +683,84 @@ class LTXVSeparateAVLatent(io.ComfyNode):
return io.NodeOutput(video_latent, audio_latent)
class LTXVReferenceAudio(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="LTXVReferenceAudio",
display_name="LTXV Reference Audio (ID-LoRA)",
category="conditioning/audio",
description="Set reference audio for ID-LoRA speaker identity transfer. Encodes a reference audio clip into the conditioning and optionally patches the model with identity guidance (extra forward pass without reference, amplifying the speaker identity effect).",
inputs=[
io.Model.Input("model"),
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Audio.Input("reference_audio", tooltip="Reference audio clip whose speaker identity to transfer. ~5 seconds recommended (training duration). Shorter or longer clips may degrade voice identity transfer."),
io.Vae.Input(id="audio_vae", display_name="Audio VAE", tooltip="LTXV Audio VAE for encoding."),
io.Float.Input("identity_guidance_scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01, tooltip="Strength of identity guidance. Runs an extra forward pass without reference each step to amplify speaker identity. Set to 0 to disable (no extra pass)."),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="Start of the sigma range where identity guidance is active."),
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="End of the sigma range where identity guidance is active."),
],
outputs=[
io.Model.Output(),
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
],
)
@classmethod
def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput:
# Encode reference audio to latents and patchify
audio_latents = audio_vae.encode(reference_audio)
b, c, t, f = audio_latents.shape
ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f)
ref_audio = {"tokens": ref_tokens}
positive = node_helpers.conditioning_set_values(positive, {"ref_audio": ref_audio})
negative = node_helpers.conditioning_set_values(negative, {"ref_audio": ref_audio})
# Patch model with identity guidance
m = model.clone()
scale = identity_guidance_scale
model_sampling = m.get_model_object("model_sampling")
sigma_start = model_sampling.percent_to_sigma(start_percent)
sigma_end = model_sampling.percent_to_sigma(end_percent)
def post_cfg_function(args):
if scale == 0:
return args["denoised"]
sigma = args["sigma"]
sigma_ = sigma[0].item()
if sigma_ > sigma_start or sigma_ < sigma_end:
return args["denoised"]
cond_pred = args["cond_denoised"]
cond = args["cond"]
cfg_result = args["denoised"]
model_options = args["model_options"].copy()
x = args["input"]
# Strip ref_audio from conditioning for the no-reference pass
noref_cond = []
for entry in cond:
new_entry = entry.copy()
mc = new_entry.get("model_conds", {}).copy()
mc.pop("ref_audio", None)
new_entry["model_conds"] = mc
noref_cond.append(new_entry)
(pred_noref,) = comfy.samplers.calc_cond_batch(
args["model"], [noref_cond], x, sigma, model_options
)
return cfg_result + (cond_pred - pred_noref) * scale
m.set_model_sampler_post_cfg_function(post_cfg_function)
return io.NodeOutput(m, positive, negative)
class LtxvExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
@@ -697,6 +776,7 @@ class LtxvExtension(ComfyExtension):
LTXVCropGuides,
LTXVConcatAVLatent,
LTXVSeparateAVLatent,
LTXVReferenceAudio,
]