Support the LTXAV 2.3 model. (#12773)
This commit is contained in:
@@ -2,11 +2,16 @@ from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.ldm.lightricks.model import (
|
||||
ADALN_BASE_PARAMS_COUNT,
|
||||
ADALN_CROSS_ATTN_PARAMS_COUNT,
|
||||
CrossAttention,
|
||||
FeedForward,
|
||||
AdaLayerNormSingle,
|
||||
PixArtAlphaTextProjection,
|
||||
NormSingleLinearTextProjection,
|
||||
LTXVModel,
|
||||
apply_cross_attention_adaln,
|
||||
compute_prompt_timestep,
|
||||
)
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||
@@ -87,6 +92,8 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
v_context_dim=None,
|
||||
a_context_dim=None,
|
||||
attn_precision=None,
|
||||
apply_gated_attention=False,
|
||||
cross_attention_adaln=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -94,6 +101,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.attn_precision = attn_precision
|
||||
self.cross_attention_adaln = cross_attention_adaln
|
||||
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=v_dim,
|
||||
@@ -101,6 +109,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
dim_head=vd_head,
|
||||
context_dim=None,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -111,6 +120,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
dim_head=ad_head,
|
||||
context_dim=None,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -122,6 +132,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
heads=v_heads,
|
||||
dim_head=vd_head,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -132,6 +143,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -144,6 +156,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -156,6 +169,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -168,11 +182,16 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
|
||||
num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, v_dim, device=device, dtype=dtype))
|
||||
self.audio_scale_shift_table = nn.Parameter(
|
||||
torch.empty(6, a_dim, device=device, dtype=dtype)
|
||||
torch.empty(num_ada_params, a_dim, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
if cross_attention_adaln:
|
||||
self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, v_dim, device=device, dtype=dtype))
|
||||
self.audio_prompt_scale_shift_table = nn.Parameter(torch.empty(2, a_dim, device=device, dtype=dtype))
|
||||
|
||||
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
|
||||
torch.empty(5, a_dim, device=device, dtype=dtype)
|
||||
)
|
||||
@@ -215,10 +234,30 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
|
||||
return (*scale_shift_ada_values, *gate_ada_values)
|
||||
|
||||
def _apply_text_cross_attention(
|
||||
self, x, context, attn, scale_shift_table, prompt_scale_shift_table,
|
||||
timestep, prompt_timestep, attention_mask, transformer_options,
|
||||
):
|
||||
"""Apply text cross-attention, with optional ADaLN modulation."""
|
||||
if self.cross_attention_adaln:
|
||||
shift_q, scale_q, gate = self.get_ada_values(
|
||||
scale_shift_table, x.shape[0], timestep, slice(6, 9)
|
||||
)
|
||||
return apply_cross_attention_adaln(
|
||||
x, context, attn, shift_q, scale_q, gate,
|
||||
prompt_scale_shift_table, prompt_timestep,
|
||||
attention_mask, transformer_options,
|
||||
)
|
||||
return attn(
|
||||
comfy.ldm.common_dit.rms_norm(x), context=context,
|
||||
mask=attention_mask, transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
|
||||
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
|
||||
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
|
||||
v_prompt_timestep=None, a_prompt_timestep=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
run_vx = transformer_options.get("run_vx", True)
|
||||
run_ax = transformer_options.get("run_ax", True)
|
||||
@@ -240,7 +279,11 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
|
||||
vx.addcmul_(attn1_out, vgate_msa)
|
||||
del vgate_msa, attn1_out
|
||||
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
|
||||
vx.add_(self._apply_text_cross_attention(
|
||||
vx, v_context, self.attn2, self.scale_shift_table,
|
||||
getattr(self, 'prompt_scale_shift_table', None),
|
||||
v_timestep, v_prompt_timestep, attention_mask, transformer_options,)
|
||||
)
|
||||
|
||||
# audio
|
||||
if run_ax:
|
||||
@@ -254,7 +297,11 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
|
||||
ax.addcmul_(attn1_out, agate_msa)
|
||||
del agate_msa, attn1_out
|
||||
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
|
||||
ax.add_(self._apply_text_cross_attention(
|
||||
ax, a_context, self.audio_attn2, self.audio_scale_shift_table,
|
||||
getattr(self, 'audio_prompt_scale_shift_table', None),
|
||||
a_timestep, a_prompt_timestep, attention_mask, transformer_options,)
|
||||
)
|
||||
|
||||
# video - audio cross attention.
|
||||
if run_a2v or run_v2a:
|
||||
@@ -351,6 +398,9 @@ class LTXAVModel(LTXVModel):
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier=1000.0,
|
||||
av_ca_timestep_scale_multiplier=1.0,
|
||||
apply_gated_attention=False,
|
||||
caption_proj_before_connector=False,
|
||||
cross_attention_adaln=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -362,6 +412,7 @@ class LTXAVModel(LTXVModel):
|
||||
self.audio_attention_head_dim = audio_attention_head_dim
|
||||
self.audio_num_attention_heads = audio_num_attention_heads
|
||||
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
|
||||
self.apply_gated_attention = apply_gated_attention
|
||||
|
||||
# Calculate audio dimensions
|
||||
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
|
||||
@@ -386,6 +437,8 @@ class LTXAVModel(LTXVModel):
|
||||
vae_scale_factors=vae_scale_factors,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||
caption_proj_before_connector=caption_proj_before_connector,
|
||||
cross_attention_adaln=cross_attention_adaln,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -400,14 +453,28 @@ class LTXAVModel(LTXVModel):
|
||||
)
|
||||
|
||||
# Audio-specific AdaLN
|
||||
audio_embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||
self.audio_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
embedding_coefficient=audio_embedding_coefficient,
|
||||
use_additional_conditions=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
if self.cross_attention_adaln:
|
||||
self.audio_prompt_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
embedding_coefficient=2,
|
||||
use_additional_conditions=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
else:
|
||||
self.audio_prompt_adaln_single = None
|
||||
|
||||
num_scale_shift_values = 4
|
||||
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim,
|
||||
@@ -443,35 +510,73 @@ class LTXAVModel(LTXVModel):
|
||||
)
|
||||
|
||||
# Audio caption projection
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
if self.caption_proj_before_connector:
|
||||
if self.caption_projection_first_linear:
|
||||
self.audio_caption_projection = NormSingleLinearTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
else:
|
||||
self.audio_caption_projection = lambda a: a
|
||||
else:
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
connector_split_rope = kwargs.get("rope_type", "split") == "split"
|
||||
connector_gated_attention = kwargs.get("connector_apply_gated_attention", False)
|
||||
attention_head_dim = kwargs.get("connector_attention_head_dim", 128)
|
||||
num_attention_heads = kwargs.get("connector_num_attention_heads", 30)
|
||||
num_layers = kwargs.get("connector_num_layers", 2)
|
||||
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
split_rope=True,
|
||||
attention_head_dim=kwargs.get("audio_connector_attention_head_dim", attention_head_dim),
|
||||
num_attention_heads=kwargs.get("audio_connector_num_attention_heads", num_attention_heads),
|
||||
num_layers=num_layers,
|
||||
split_rope=connector_split_rope,
|
||||
double_precision_rope=True,
|
||||
apply_gated_attention=connector_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
self.video_embeddings_connector = Embeddings1DConnector(
|
||||
split_rope=True,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_layers=num_layers,
|
||||
split_rope=connector_split_rope,
|
||||
double_precision_rope=True,
|
||||
apply_gated_attention=connector_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
def preprocess_text_embeds(self, context):
|
||||
if context.shape[-1] == self.caption_channels * 2:
|
||||
return context
|
||||
out_vid = self.video_embeddings_connector(context)[0]
|
||||
out_audio = self.audio_embeddings_connector(context)[0]
|
||||
def preprocess_text_embeds(self, context, unprocessed=False):
|
||||
# LTXv2 fully processed context has dimension of self.caption_channels * 2
|
||||
# LTXv2.3 fully processed context has dimension of self.cross_attention_dim + self.audio_cross_attention_dim
|
||||
if not unprocessed:
|
||||
if context.shape[-1] in (self.cross_attention_dim + self.audio_cross_attention_dim, self.caption_channels * 2):
|
||||
return context
|
||||
if context.shape[-1] == self.cross_attention_dim + self.audio_cross_attention_dim:
|
||||
context_vid = context[:, :, :self.cross_attention_dim]
|
||||
context_audio = context[:, :, self.cross_attention_dim:]
|
||||
else:
|
||||
context_vid = context
|
||||
context_audio = context
|
||||
if self.caption_proj_before_connector:
|
||||
context_vid = self.caption_projection(context_vid)
|
||||
context_audio = self.audio_caption_projection(context_audio)
|
||||
out_vid = self.video_embeddings_connector(context_vid)[0]
|
||||
out_audio = self.audio_embeddings_connector(context_audio)[0]
|
||||
return torch.concat((out_vid, out_audio), dim=-1)
|
||||
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
@@ -487,6 +592,8 @@ class LTXAVModel(LTXVModel):
|
||||
ad_head=self.audio_attention_head_dim,
|
||||
v_context_dim=self.cross_attention_dim,
|
||||
a_context_dim=self.audio_cross_attention_dim,
|
||||
apply_gated_attention=self.apply_gated_attention,
|
||||
cross_attention_adaln=self.cross_attention_adaln,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
@@ -608,6 +715,10 @@ class LTXAVModel(LTXVModel):
|
||||
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
|
||||
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
|
||||
|
||||
v_prompt_timestep = compute_prompt_timestep(
|
||||
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
|
||||
)
|
||||
|
||||
# Prepare audio timestep
|
||||
a_timestep = kwargs.get("a_timestep")
|
||||
if a_timestep is not None:
|
||||
@@ -618,25 +729,25 @@ class LTXAVModel(LTXVModel):
|
||||
|
||||
# Cross-attention timesteps - compress these too
|
||||
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
|
||||
a_timestep_flat,
|
||||
timestep.max().expand_as(a_timestep_flat),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
|
||||
timestep_flat,
|
||||
a_timestep.max().expand_as(timestep_flat),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
|
||||
timestep_flat * av_ca_factor,
|
||||
a_timestep.max().expand_as(timestep_flat) * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
|
||||
a_timestep_flat * av_ca_factor,
|
||||
timestep.max().expand_as(a_timestep_flat) * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
@@ -660,29 +771,40 @@ class LTXAVModel(LTXVModel):
|
||||
# Audio timesteps
|
||||
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
|
||||
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
|
||||
|
||||
a_prompt_timestep = compute_prompt_timestep(
|
||||
self.audio_prompt_adaln_single, a_timestep_scaled, batch_size, hidden_dtype
|
||||
)
|
||||
else:
|
||||
a_timestep = timestep_scaled
|
||||
a_embedded_timestep = kwargs.get("embedded_timestep")
|
||||
cross_av_timestep_ss = []
|
||||
a_prompt_timestep = None
|
||||
|
||||
return [v_timestep, a_timestep, cross_av_timestep_ss], [
|
||||
return [v_timestep, a_timestep, cross_av_timestep_ss, v_prompt_timestep, a_prompt_timestep], [
|
||||
v_embedded_timestep,
|
||||
a_embedded_timestep,
|
||||
]
|
||||
], None
|
||||
|
||||
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||
vx = x[0]
|
||||
ax = x[1]
|
||||
video_dim = vx.shape[-1]
|
||||
audio_dim = ax.shape[-1]
|
||||
|
||||
v_context_dim = self.caption_channels if self.caption_proj_before_connector is False else video_dim
|
||||
a_context_dim = self.caption_channels if self.caption_proj_before_connector is False else audio_dim
|
||||
|
||||
v_context, a_context = torch.split(
|
||||
context, int(context.shape[-1] / 2), len(context.shape) - 1
|
||||
context, [v_context_dim, a_context_dim], len(context.shape) - 1
|
||||
)
|
||||
|
||||
v_context, attention_mask = super()._prepare_context(
|
||||
v_context, batch_size, vx, attention_mask
|
||||
)
|
||||
if self.audio_caption_projection is not None:
|
||||
if self.caption_proj_before_connector is False:
|
||||
a_context = self.audio_caption_projection(a_context)
|
||||
a_context = a_context.view(batch_size, -1, ax.shape[-1])
|
||||
a_context = a_context.view(batch_size, -1, audio_dim)
|
||||
|
||||
return [v_context, a_context], attention_mask
|
||||
|
||||
@@ -744,6 +866,9 @@ class LTXAVModel(LTXVModel):
|
||||
av_ca_v2a_gate_noise_timestep,
|
||||
) = timestep[2]
|
||||
|
||||
v_prompt_timestep = timestep[3]
|
||||
a_prompt_timestep = timestep[4]
|
||||
|
||||
"""Process transformer blocks for LTXAV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
@@ -771,6 +896,8 @@ class LTXAVModel(LTXVModel):
|
||||
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
||||
transformer_options=args["transformer_options"],
|
||||
self_attention_mask=args.get("self_attention_mask"),
|
||||
v_prompt_timestep=args.get("v_prompt_timestep"),
|
||||
a_prompt_timestep=args.get("a_prompt_timestep"),
|
||||
)
|
||||
return out
|
||||
|
||||
@@ -792,6 +919,8 @@ class LTXAVModel(LTXVModel):
|
||||
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
||||
"transformer_options": transformer_options,
|
||||
"self_attention_mask": self_attention_mask,
|
||||
"v_prompt_timestep": v_prompt_timestep,
|
||||
"a_prompt_timestep": a_prompt_timestep,
|
||||
},
|
||||
{"original_block": block_wrap},
|
||||
)
|
||||
@@ -814,6 +943,8 @@ class LTXAVModel(LTXVModel):
|
||||
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
||||
transformer_options=transformer_options,
|
||||
self_attention_mask=self_attention_mask,
|
||||
v_prompt_timestep=v_prompt_timestep,
|
||||
a_prompt_timestep=a_prompt_timestep,
|
||||
)
|
||||
|
||||
return [vx, ax]
|
||||
|
||||
Reference in New Issue
Block a user