Support the LTXAV 2.3 model. (#12773)

This commit is contained in:
comfyanonymous
2026-03-04 17:06:20 -08:00
committed by GitHub
parent ac4a943ff3
commit 43c64b6308
10 changed files with 959 additions and 133 deletions
+158 -27
View File
@@ -2,11 +2,16 @@ from typing import Tuple
import torch
import torch.nn as nn
from comfy.ldm.lightricks.model import (
ADALN_BASE_PARAMS_COUNT,
ADALN_CROSS_ATTN_PARAMS_COUNT,
CrossAttention,
FeedForward,
AdaLayerNormSingle,
PixArtAlphaTextProjection,
NormSingleLinearTextProjection,
LTXVModel,
apply_cross_attention_adaln,
compute_prompt_timestep,
)
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
@@ -87,6 +92,8 @@ class BasicAVTransformerBlock(nn.Module):
v_context_dim=None,
a_context_dim=None,
attn_precision=None,
apply_gated_attention=False,
cross_attention_adaln=False,
dtype=None,
device=None,
operations=None,
@@ -94,6 +101,7 @@ class BasicAVTransformerBlock(nn.Module):
super().__init__()
self.attn_precision = attn_precision
self.cross_attention_adaln = cross_attention_adaln
self.attn1 = CrossAttention(
query_dim=v_dim,
@@ -101,6 +109,7 @@ class BasicAVTransformerBlock(nn.Module):
dim_head=vd_head,
context_dim=None,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -111,6 +120,7 @@ class BasicAVTransformerBlock(nn.Module):
dim_head=ad_head,
context_dim=None,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -122,6 +132,7 @@ class BasicAVTransformerBlock(nn.Module):
heads=v_heads,
dim_head=vd_head,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -132,6 +143,7 @@ class BasicAVTransformerBlock(nn.Module):
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -144,6 +156,7 @@ class BasicAVTransformerBlock(nn.Module):
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -156,6 +169,7 @@ class BasicAVTransformerBlock(nn.Module):
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -168,11 +182,16 @@ class BasicAVTransformerBlock(nn.Module):
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
)
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, v_dim, device=device, dtype=dtype))
self.audio_scale_shift_table = nn.Parameter(
torch.empty(6, a_dim, device=device, dtype=dtype)
torch.empty(num_ada_params, a_dim, device=device, dtype=dtype)
)
if cross_attention_adaln:
self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, v_dim, device=device, dtype=dtype))
self.audio_prompt_scale_shift_table = nn.Parameter(torch.empty(2, a_dim, device=device, dtype=dtype))
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
torch.empty(5, a_dim, device=device, dtype=dtype)
)
@@ -215,10 +234,30 @@ class BasicAVTransformerBlock(nn.Module):
return (*scale_shift_ada_values, *gate_ada_values)
def _apply_text_cross_attention(
self, x, context, attn, scale_shift_table, prompt_scale_shift_table,
timestep, prompt_timestep, attention_mask, transformer_options,
):
"""Apply text cross-attention, with optional ADaLN modulation."""
if self.cross_attention_adaln:
shift_q, scale_q, gate = self.get_ada_values(
scale_shift_table, x.shape[0], timestep, slice(6, 9)
)
return apply_cross_attention_adaln(
x, context, attn, shift_q, scale_q, gate,
prompt_scale_shift_table, prompt_timestep,
attention_mask, transformer_options,
)
return attn(
comfy.ldm.common_dit.rms_norm(x), context=context,
mask=attention_mask, transformer_options=transformer_options,
)
def forward(
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
v_prompt_timestep=None, a_prompt_timestep=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
run_vx = transformer_options.get("run_vx", True)
run_ax = transformer_options.get("run_ax", True)
@@ -240,7 +279,11 @@ class BasicAVTransformerBlock(nn.Module):
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
vx.addcmul_(attn1_out, vgate_msa)
del vgate_msa, attn1_out
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
vx.add_(self._apply_text_cross_attention(
vx, v_context, self.attn2, self.scale_shift_table,
getattr(self, 'prompt_scale_shift_table', None),
v_timestep, v_prompt_timestep, attention_mask, transformer_options,)
)
# audio
if run_ax:
@@ -254,7 +297,11 @@ class BasicAVTransformerBlock(nn.Module):
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
ax.addcmul_(attn1_out, agate_msa)
del agate_msa, attn1_out
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
ax.add_(self._apply_text_cross_attention(
ax, a_context, self.audio_attn2, self.audio_scale_shift_table,
getattr(self, 'audio_prompt_scale_shift_table', None),
a_timestep, a_prompt_timestep, attention_mask, transformer_options,)
)
# video - audio cross attention.
if run_a2v or run_v2a:
@@ -351,6 +398,9 @@ class LTXAVModel(LTXVModel):
use_middle_indices_grid=False,
timestep_scale_multiplier=1000.0,
av_ca_timestep_scale_multiplier=1.0,
apply_gated_attention=False,
caption_proj_before_connector=False,
cross_attention_adaln=False,
dtype=None,
device=None,
operations=None,
@@ -362,6 +412,7 @@ class LTXAVModel(LTXVModel):
self.audio_attention_head_dim = audio_attention_head_dim
self.audio_num_attention_heads = audio_num_attention_heads
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
self.apply_gated_attention = apply_gated_attention
# Calculate audio dimensions
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
@@ -386,6 +437,8 @@ class LTXAVModel(LTXVModel):
vae_scale_factors=vae_scale_factors,
use_middle_indices_grid=use_middle_indices_grid,
timestep_scale_multiplier=timestep_scale_multiplier,
caption_proj_before_connector=caption_proj_before_connector,
cross_attention_adaln=cross_attention_adaln,
dtype=dtype,
device=device,
operations=operations,
@@ -400,14 +453,28 @@ class LTXAVModel(LTXVModel):
)
# Audio-specific AdaLN
audio_embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
self.audio_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
embedding_coefficient=audio_embedding_coefficient,
use_additional_conditions=False,
dtype=dtype,
device=device,
operations=self.operations,
)
if self.cross_attention_adaln:
self.audio_prompt_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
embedding_coefficient=2,
use_additional_conditions=False,
dtype=dtype,
device=device,
operations=self.operations,
)
else:
self.audio_prompt_adaln_single = None
num_scale_shift_values = 4
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
self.inner_dim,
@@ -443,35 +510,73 @@ class LTXAVModel(LTXVModel):
)
# Audio caption projection
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.audio_inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
if self.caption_proj_before_connector:
if self.caption_projection_first_linear:
self.audio_caption_projection = NormSingleLinearTextProjection(
in_features=self.caption_channels,
hidden_size=self.audio_inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
else:
self.audio_caption_projection = lambda a: a
else:
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.audio_inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
connector_split_rope = kwargs.get("rope_type", "split") == "split"
connector_gated_attention = kwargs.get("connector_apply_gated_attention", False)
attention_head_dim = kwargs.get("connector_attention_head_dim", 128)
num_attention_heads = kwargs.get("connector_num_attention_heads", 30)
num_layers = kwargs.get("connector_num_layers", 2)
self.audio_embeddings_connector = Embeddings1DConnector(
split_rope=True,
attention_head_dim=kwargs.get("audio_connector_attention_head_dim", attention_head_dim),
num_attention_heads=kwargs.get("audio_connector_num_attention_heads", num_attention_heads),
num_layers=num_layers,
split_rope=connector_split_rope,
double_precision_rope=True,
apply_gated_attention=connector_gated_attention,
dtype=dtype,
device=device,
operations=self.operations,
)
self.video_embeddings_connector = Embeddings1DConnector(
split_rope=True,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
num_layers=num_layers,
split_rope=connector_split_rope,
double_precision_rope=True,
apply_gated_attention=connector_gated_attention,
dtype=dtype,
device=device,
operations=self.operations,
)
def preprocess_text_embeds(self, context):
if context.shape[-1] == self.caption_channels * 2:
return context
out_vid = self.video_embeddings_connector(context)[0]
out_audio = self.audio_embeddings_connector(context)[0]
def preprocess_text_embeds(self, context, unprocessed=False):
# LTXv2 fully processed context has dimension of self.caption_channels * 2
# LTXv2.3 fully processed context has dimension of self.cross_attention_dim + self.audio_cross_attention_dim
if not unprocessed:
if context.shape[-1] in (self.cross_attention_dim + self.audio_cross_attention_dim, self.caption_channels * 2):
return context
if context.shape[-1] == self.cross_attention_dim + self.audio_cross_attention_dim:
context_vid = context[:, :, :self.cross_attention_dim]
context_audio = context[:, :, self.cross_attention_dim:]
else:
context_vid = context
context_audio = context
if self.caption_proj_before_connector:
context_vid = self.caption_projection(context_vid)
context_audio = self.audio_caption_projection(context_audio)
out_vid = self.video_embeddings_connector(context_vid)[0]
out_audio = self.audio_embeddings_connector(context_audio)[0]
return torch.concat((out_vid, out_audio), dim=-1)
def _init_transformer_blocks(self, device, dtype, **kwargs):
@@ -487,6 +592,8 @@ class LTXAVModel(LTXVModel):
ad_head=self.audio_attention_head_dim,
v_context_dim=self.cross_attention_dim,
a_context_dim=self.audio_cross_attention_dim,
apply_gated_attention=self.apply_gated_attention,
cross_attention_adaln=self.cross_attention_adaln,
dtype=dtype,
device=device,
operations=self.operations,
@@ -608,6 +715,10 @@ class LTXAVModel(LTXVModel):
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
v_prompt_timestep = compute_prompt_timestep(
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
)
# Prepare audio timestep
a_timestep = kwargs.get("a_timestep")
if a_timestep is not None:
@@ -618,25 +729,25 @@ class LTXAVModel(LTXVModel):
# Cross-attention timesteps - compress these too
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
a_timestep_flat,
timestep.max().expand_as(a_timestep_flat),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
timestep_flat,
a_timestep.max().expand_as(timestep_flat),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
timestep_flat * av_ca_factor,
a_timestep.max().expand_as(timestep_flat) * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
a_timestep_flat * av_ca_factor,
timestep.max().expand_as(a_timestep_flat) * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
@@ -660,29 +771,40 @@ class LTXAVModel(LTXVModel):
# Audio timesteps
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
a_prompt_timestep = compute_prompt_timestep(
self.audio_prompt_adaln_single, a_timestep_scaled, batch_size, hidden_dtype
)
else:
a_timestep = timestep_scaled
a_embedded_timestep = kwargs.get("embedded_timestep")
cross_av_timestep_ss = []
a_prompt_timestep = None
return [v_timestep, a_timestep, cross_av_timestep_ss], [
return [v_timestep, a_timestep, cross_av_timestep_ss, v_prompt_timestep, a_prompt_timestep], [
v_embedded_timestep,
a_embedded_timestep,
]
], None
def _prepare_context(self, context, batch_size, x, attention_mask=None):
vx = x[0]
ax = x[1]
video_dim = vx.shape[-1]
audio_dim = ax.shape[-1]
v_context_dim = self.caption_channels if self.caption_proj_before_connector is False else video_dim
a_context_dim = self.caption_channels if self.caption_proj_before_connector is False else audio_dim
v_context, a_context = torch.split(
context, int(context.shape[-1] / 2), len(context.shape) - 1
context, [v_context_dim, a_context_dim], len(context.shape) - 1
)
v_context, attention_mask = super()._prepare_context(
v_context, batch_size, vx, attention_mask
)
if self.audio_caption_projection is not None:
if self.caption_proj_before_connector is False:
a_context = self.audio_caption_projection(a_context)
a_context = a_context.view(batch_size, -1, ax.shape[-1])
a_context = a_context.view(batch_size, -1, audio_dim)
return [v_context, a_context], attention_mask
@@ -744,6 +866,9 @@ class LTXAVModel(LTXVModel):
av_ca_v2a_gate_noise_timestep,
) = timestep[2]
v_prompt_timestep = timestep[3]
a_prompt_timestep = timestep[4]
"""Process transformer blocks for LTXAV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
@@ -771,6 +896,8 @@ class LTXAVModel(LTXVModel):
a_cross_gate_timestep=args["a_cross_gate_timestep"],
transformer_options=args["transformer_options"],
self_attention_mask=args.get("self_attention_mask"),
v_prompt_timestep=args.get("v_prompt_timestep"),
a_prompt_timestep=args.get("a_prompt_timestep"),
)
return out
@@ -792,6 +919,8 @@ class LTXAVModel(LTXVModel):
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
"transformer_options": transformer_options,
"self_attention_mask": self_attention_mask,
"v_prompt_timestep": v_prompt_timestep,
"a_prompt_timestep": a_prompt_timestep,
},
{"original_block": block_wrap},
)
@@ -814,6 +943,8 @@ class LTXAVModel(LTXVModel):
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
transformer_options=transformer_options,
self_attention_mask=self_attention_mask,
v_prompt_timestep=v_prompt_timestep,
a_prompt_timestep=a_prompt_timestep,
)
return [vx, ax]