Support the LTXAV 2.3 model. (#12773)

This commit is contained in:
comfyanonymous
2026-03-04 17:06:20 -08:00
committed by GitHub
parent ac4a943ff3
commit 43c64b6308
10 changed files with 959 additions and 133 deletions
+5 -2
View File
@@ -13,7 +13,7 @@ from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
CausalityAxis,
CausalAudioAutoencoder,
)
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder, VocoderWithBWE
LATENT_DOWNSAMPLE_FACTOR = 4
@@ -141,7 +141,10 @@ class AudioVAE(torch.nn.Module):
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
self.vocoder = Vocoder(config=component_config.vocoder)
if "bwe" in component_config.vocoder:
self.vocoder = VocoderWithBWE(config=component_config.vocoder)
else:
self.vocoder = Vocoder(config=component_config.vocoder)
self.autoencoder.load_state_dict(vae_sd, strict=False)
self.vocoder.load_state_dict(vocoder_sd, strict=False)
@@ -822,26 +822,23 @@ class CausalAudioAutoencoder(nn.Module):
super().__init__()
if config is None:
config = self._guess_config()
config = self.get_default_config()
# Extract encoder and decoder configs from the new format
model_config = config.get("model", {}).get("params", {})
variables_config = config.get("variables", {})
self.sampling_rate = variables_config.get(
"sampling_rate",
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
self.sampling_rate = model_config.get(
"sampling_rate", config.get("sampling_rate", 16000)
)
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
decoder_config = model_config.get("decoder", encoder_config)
# Load mel spectrogram parameters
self.mel_bins = encoder_config.get("mel_bins", 64)
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
self.mel_hop_length = config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
self.n_fft = config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
# Store causality configuration at VAE level (not just in encoder internals)
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.HEIGHT.value)
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
@@ -850,44 +847,38 @@ class CausalAudioAutoencoder(nn.Module):
self.per_channel_statistics = processor()
def _guess_config(self):
encoder_config = {
# Required parameters - based on ltx-video-av-1679000 model metadata
"ch": 128,
"out_ch": 8,
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
"num_res_blocks": 2,
"attn_resolutions": [], # Based on metadata: empty list, no attention
"dropout": 0.0,
"resamp_with_conv": True,
"in_channels": 2, # stereo
"resolution": 256,
"z_channels": 8,
def get_default_config(self):
ddconfig = {
"double_z": True,
"attn_type": "vanilla",
"mid_block_add_attention": False, # Based on metadata: false
"mel_bins": 64,
"z_channels": 8,
"resolution": 256,
"downsample_time": False,
"in_channels": 2,
"out_ch": 2,
"ch": 128,
"ch_mult": [1, 2, 4],
"num_res_blocks": 2,
"attn_resolutions": [],
"dropout": 0.0,
"mid_block_add_attention": False,
"norm_type": "pixel",
"causality_axis": "height", # Based on metadata
"mel_bins": 64, # Based on metadata: mel_bins = 64
}
decoder_config = {
# Inherits encoder config, can override specific params
**encoder_config,
"out_ch": 2, # Stereo audio output (2 channels)
"give_pre_end": False,
"tanh_out": False,
"causality_axis": "height",
}
config = {
"_class_name": "CausalAudioAutoencoder",
"sampling_rate": 16000,
"model": {
"params": {
"encoder": encoder_config,
"decoder": decoder_config,
"ddconfig": ddconfig,
"sampling_rate": 16000,
}
},
"preprocessing": {
"stft": {
"filter_length": 1024,
"hop_length": 160,
},
},
}
return config
@@ -15,6 +15,9 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init
def in_meta_context():
return torch.device("meta") == torch.empty(0).device
def mark_conv3d_ended(module):
tid = threading.get_ident()
for _, m in module.named_modules():
@@ -350,6 +353,10 @@ class Decoder(nn.Module):
output_channel = output_channel * block_params.get("multiplier", 2)
if block_name == "compress_all":
output_channel = output_channel * block_params.get("multiplier", 1)
if block_name == "compress_space":
output_channel = output_channel * block_params.get("multiplier", 1)
if block_name == "compress_time":
output_channel = output_channel * block_params.get("multiplier", 1)
self.conv_in = make_conv_nd(
dims,
@@ -395,17 +402,21 @@ class Decoder(nn.Module):
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
output_channel = output_channel // block_params.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(2, 1, 1),
out_channels_reduction_factor=block_params.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
output_channel = output_channel // block_params.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(1, 2, 2),
out_channels_reduction_factor=block_params.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
@@ -455,6 +466,15 @@ class Decoder(nn.Module):
output_channel * 2, 0, operations=ops,
)
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
else:
self.register_buffer(
"last_scale_shift_table",
torch.tensor(
[0.0, 0.0],
device="cpu" if in_meta_context() else None
).unsqueeze(1).expand(2, output_channel),
persistent=False,
)
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
@@ -883,6 +903,15 @@ class ResnetBlock3D(nn.Module):
self.scale_shift_table = nn.Parameter(
torch.randn(4, in_channels) / in_channels**0.5
)
else:
self.register_buffer(
"scale_shift_table",
torch.tensor(
[0.0, 0.0, 0.0, 0.0],
device="cpu" if in_meta_context() else None
).unsqueeze(1).expand(4, in_channels),
persistent=False,
)
self.temporal_cache_state={}
@@ -1012,9 +1041,6 @@ class processor(nn.Module):
super().__init__()
self.register_buffer("std-of-means", torch.empty(128))
self.register_buffer("mean-of-means", torch.empty(128))
self.register_buffer("mean-of-stds", torch.empty(128))
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
self.register_buffer("channel", torch.empty(128))
def un_normalize(self, x):
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
@@ -1027,9 +1053,12 @@ class VideoVAE(nn.Module):
super().__init__()
if config is None:
config = self.guess_config(version)
config = self.get_default_config(version)
self.config = config
self.timestep_conditioning = config.get("timestep_conditioning", False)
self.decode_noise_scale = config.get("decode_noise_scale", 0.025)
self.decode_timestep = config.get("decode_timestep", 0.05)
double_z = config.get("double_z", True)
latent_log_var = config.get(
"latent_log_var", "per_channel" if double_z else "none"
@@ -1044,6 +1073,7 @@ class VideoVAE(nn.Module):
latent_log_var=latent_log_var,
norm_layer=config.get("norm_layer", "group_norm"),
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
base_channels=config.get("encoder_base_channels", 128),
)
self.decoder = Decoder(
@@ -1051,6 +1081,7 @@ class VideoVAE(nn.Module):
in_channels=config["latent_channels"],
out_channels=config.get("out_channels", 3),
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
base_channels=config.get("decoder_base_channels", 128),
patch_size=config.get("patch_size", 1),
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
@@ -1060,7 +1091,7 @@ class VideoVAE(nn.Module):
self.per_channel_statistics = processor()
def guess_config(self, version):
def get_default_config(self, version):
if version == 0:
config = {
"_class_name": "CausalVideoAutoencoder",
@@ -1167,8 +1198,7 @@ class VideoVAE(nn.Module):
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)
def decode(self, x, timestep=0.05, noise_scale=0.025):
def decode(self, x):
if self.timestep_conditioning: #TODO: seed
x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)