Support the LTXAV 2.3 model. (#12773)
This commit is contained in:
+158
-28
@@ -275,6 +275,30 @@ class PixArtAlphaTextProjection(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class NormSingleLinearTextProjection(nn.Module):
|
||||
"""Text projection for 20B models - single linear with RMSNorm (no activation)."""
|
||||
|
||||
def __init__(
|
||||
self, in_features, hidden_size, dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
if operations is None:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
self.in_norm = operations.RMSNorm(
|
||||
in_features, eps=1e-6, elementwise_affine=False
|
||||
)
|
||||
self.linear_1 = operations.Linear(
|
||||
in_features, hidden_size, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
self.hidden_size = hidden_size
|
||||
self.in_features = in_features
|
||||
|
||||
def forward(self, caption):
|
||||
caption = self.in_norm(caption)
|
||||
caption = caption * (self.hidden_size / self.in_features) ** 0.5
|
||||
return self.linear_1(caption)
|
||||
|
||||
|
||||
class GELU_approx(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
@@ -343,6 +367,7 @@ class CrossAttention(nn.Module):
|
||||
dim_head=64,
|
||||
dropout=0.0,
|
||||
attn_precision=None,
|
||||
apply_gated_attention=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -362,6 +387,12 @@ class CrossAttention(nn.Module):
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
# Optional per-head gating
|
||||
if apply_gated_attention:
|
||||
self.to_gate_logits = operations.Linear(query_dim, heads, bias=True, dtype=dtype, device=device)
|
||||
else:
|
||||
self.to_gate_logits = None
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
|
||||
)
|
||||
@@ -383,16 +414,30 @@ class CrossAttention(nn.Module):
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
else:
|
||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
|
||||
# Apply per-head gating if enabled
|
||||
if self.to_gate_logits is not None:
|
||||
gate_logits = self.to_gate_logits(x) # (B, T, H)
|
||||
b, t, _ = out.shape
|
||||
out = out.view(b, t, self.heads, self.dim_head)
|
||||
gates = 2.0 * torch.sigmoid(gate_logits) # zero-init -> identity
|
||||
out = out * gates.unsqueeze(-1)
|
||||
out = out.view(b, t, self.heads * self.dim_head)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
# 6 base ADaLN params (shift/scale/gate for MSA + MLP), +3 for cross-attention Q (shift/scale/gate)
|
||||
ADALN_BASE_PARAMS_COUNT = 6
|
||||
ADALN_CROSS_ATTN_PARAMS_COUNT = 9
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
|
||||
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, cross_attention_adaln=False, dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attn_precision = attn_precision
|
||||
self.cross_attention_adaln = cross_attention_adaln
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
@@ -416,18 +461,25 @@ class BasicTransformerBlock(nn.Module):
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||
num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||
if cross_attention_adaln:
|
||||
self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype))
|
||||
|
||||
attn1_input = comfy.ldm.common_dit.rms_norm(x)
|
||||
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
|
||||
attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options)
|
||||
x.addcmul_(attn1_input, gate_msa)
|
||||
del attn1_input
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None, prompt_timestep=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None, :6].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, :6, :]).unbind(dim=2)
|
||||
|
||||
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) * gate_msa
|
||||
|
||||
if self.cross_attention_adaln:
|
||||
shift_q_mca, scale_q_mca, gate_mca = (self.scale_shift_table[None, None, 6:9].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, 6:9, :]).unbind(dim=2)
|
||||
x += apply_cross_attention_adaln(
|
||||
x, context, self.attn2, shift_q_mca, scale_q_mca, gate_mca,
|
||||
self.prompt_scale_shift_table, prompt_timestep, attention_mask, transformer_options,
|
||||
)
|
||||
else:
|
||||
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
||||
|
||||
y = comfy.ldm.common_dit.rms_norm(x)
|
||||
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
|
||||
@@ -435,6 +487,47 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def compute_prompt_timestep(adaln_module, timestep_scaled, batch_size, hidden_dtype):
|
||||
"""Compute a single global prompt timestep for cross-attention ADaLN.
|
||||
|
||||
Uses the max across tokens (matching JAX max_per_segment) and broadcasts
|
||||
over text tokens. Returns None when *adaln_module* is None.
|
||||
"""
|
||||
if adaln_module is None:
|
||||
return None
|
||||
ts_input = (
|
||||
timestep_scaled.max(dim=1, keepdim=True).values.flatten()
|
||||
if timestep_scaled.dim() > 1
|
||||
else timestep_scaled.flatten()
|
||||
)
|
||||
prompt_ts, _ = adaln_module(
|
||||
ts_input,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
return prompt_ts.view(batch_size, 1, prompt_ts.shape[-1])
|
||||
|
||||
|
||||
def apply_cross_attention_adaln(
|
||||
x, context, attn, q_shift, q_scale, q_gate,
|
||||
prompt_scale_shift_table, prompt_timestep,
|
||||
attention_mask=None, transformer_options={},
|
||||
):
|
||||
"""Apply cross-attention with ADaLN modulation (shift/scale/gate on Q and KV).
|
||||
|
||||
Q params (q_shift, q_scale, q_gate) are pre-extracted by the caller so
|
||||
that both regular tensors and CompressedTimestep are supported.
|
||||
"""
|
||||
batch_size = x.shape[0]
|
||||
shift_kv, scale_kv = (
|
||||
prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
|
||||
+ prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1)
|
||||
).unbind(dim=2)
|
||||
attn_input = comfy.ldm.common_dit.rms_norm(x) * (1 + q_scale) + q_shift
|
||||
encoder_hidden_states = context * (1 + scale_kv) + shift_kv
|
||||
return attn(attn_input, context=encoder_hidden_states, mask=attention_mask, transformer_options=transformer_options) * q_gate
|
||||
|
||||
def get_fractional_positions(indices_grid, max_pos):
|
||||
n_pos_dims = indices_grid.shape[1]
|
||||
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
|
||||
@@ -556,6 +649,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
vae_scale_factors: tuple = (8, 32, 32),
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier = 1000.0,
|
||||
caption_proj_before_connector=False,
|
||||
cross_attention_adaln=False,
|
||||
caption_projection_first_linear=True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -582,6 +678,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
self.operations = operations
|
||||
self.timestep_scale_multiplier = timestep_scale_multiplier
|
||||
self.caption_proj_before_connector = caption_proj_before_connector
|
||||
self.cross_attention_adaln = cross_attention_adaln
|
||||
self.caption_projection_first_linear = caption_projection_first_linear
|
||||
|
||||
# Common dimensions
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
@@ -609,17 +708,37 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||
self.adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||
self.inner_dim, embedding_coefficient=embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||
)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
if self.cross_attention_adaln:
|
||||
self.prompt_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, embedding_coefficient=2, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||
)
|
||||
else:
|
||||
self.prompt_adaln_single = None
|
||||
|
||||
if self.caption_proj_before_connector:
|
||||
if self.caption_projection_first_linear:
|
||||
self.caption_projection = NormSingleLinearTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
else:
|
||||
self.caption_projection = lambda a: a
|
||||
else:
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _init_model_components(self, device, dtype, **kwargs):
|
||||
@@ -665,9 +784,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
if grid_mask is not None:
|
||||
timestep = timestep[:, grid_mask]
|
||||
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep_scaled = timestep * self.timestep_scale_multiplier
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep.flatten(),
|
||||
timestep_scaled.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
@@ -677,14 +796,18 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
|
||||
|
||||
return timestep, embedded_timestep
|
||||
prompt_timestep = compute_prompt_timestep(
|
||||
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
|
||||
)
|
||||
|
||||
return timestep, embedded_timestep, prompt_timestep
|
||||
|
||||
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||
"""Prepare context for transformer blocks."""
|
||||
if self.caption_projection is not None:
|
||||
if self.caption_proj_before_connector is False:
|
||||
context = self.caption_projection(context)
|
||||
context = context.view(batch_size, -1, x.shape[-1])
|
||||
|
||||
context = context.view(batch_size, -1, x.shape[-1])
|
||||
return context, attention_mask
|
||||
|
||||
def _precompute_freqs_cis(
|
||||
@@ -792,7 +915,8 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
merged_args.update(additional_args)
|
||||
|
||||
# Prepare timestep and context
|
||||
timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
|
||||
timestep, embedded_timestep, prompt_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
|
||||
merged_args["prompt_timestep"] = prompt_timestep
|
||||
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
|
||||
|
||||
# Prepare attention mask and positional embeddings
|
||||
@@ -833,7 +957,9 @@ class LTXVModel(LTXBaseModel):
|
||||
causal_temporal_positioning=False,
|
||||
vae_scale_factors=(8, 32, 32),
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier = 1000.0,
|
||||
timestep_scale_multiplier=1000.0,
|
||||
caption_proj_before_connector=False,
|
||||
cross_attention_adaln=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -852,6 +978,8 @@ class LTXVModel(LTXBaseModel):
|
||||
vae_scale_factors=vae_scale_factors,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||
caption_proj_before_connector=caption_proj_before_connector,
|
||||
cross_attention_adaln=cross_attention_adaln,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -860,7 +988,6 @@ class LTXVModel(LTXBaseModel):
|
||||
|
||||
def _init_model_components(self, device, dtype, **kwargs):
|
||||
"""Initialize LTXV-specific components."""
|
||||
# No additional components needed for LTXV beyond base class
|
||||
pass
|
||||
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
@@ -872,6 +999,7 @@ class LTXVModel(LTXBaseModel):
|
||||
self.num_attention_heads,
|
||||
self.attention_head_dim,
|
||||
context_dim=self.cross_attention_dim,
|
||||
cross_attention_adaln=self.cross_attention_adaln,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
@@ -1149,16 +1277,17 @@ class LTXVModel(LTXBaseModel):
|
||||
"""Process transformer blocks for LTXV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
prompt_timestep = kwargs.get("prompt_timestep", None)
|
||||
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"))
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), prompt_timestep=args.get("prompt_timestep"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, "prompt_timestep": prompt_timestep}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(
|
||||
@@ -1169,6 +1298,7 @@ class LTXVModel(LTXBaseModel):
|
||||
pe=pe,
|
||||
transformer_options=transformer_options,
|
||||
self_attention_mask=self_attention_mask,
|
||||
prompt_timestep=prompt_timestep,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user