Support the LTXAV 2.3 model. (#12773)

This commit is contained in:
comfyanonymous
2026-03-04 17:06:20 -08:00
committed by GitHub
parent ac4a943ff3
commit 43c64b6308
10 changed files with 959 additions and 133 deletions
@@ -15,6 +15,9 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init
def in_meta_context():
    """Report whether default tensor creation currently targets the meta device.

    Allocates an empty tensor without an explicit device and compares the
    device it landed on against ``torch.device("meta")``.
    """
    meta_device = torch.device("meta")
    probe_device = torch.empty(0).device
    return meta_device == probe_device
def mark_conv3d_ended(module):
tid = threading.get_ident()
for _, m in module.named_modules():
@@ -350,6 +353,10 @@ class Decoder(nn.Module):
output_channel = output_channel * block_params.get("multiplier", 2)
if block_name == "compress_all":
output_channel = output_channel * block_params.get("multiplier", 1)
if block_name == "compress_space":
output_channel = output_channel * block_params.get("multiplier", 1)
if block_name == "compress_time":
output_channel = output_channel * block_params.get("multiplier", 1)
self.conv_in = make_conv_nd(
dims,
@@ -395,17 +402,21 @@ class Decoder(nn.Module):
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
output_channel = output_channel // block_params.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(2, 1, 1),
out_channels_reduction_factor=block_params.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
output_channel = output_channel // block_params.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(1, 2, 2),
out_channels_reduction_factor=block_params.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
@@ -455,6 +466,15 @@ class Decoder(nn.Module):
output_channel * 2, 0, operations=ops,
)
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
else:
self.register_buffer(
"last_scale_shift_table",
torch.tensor(
[0.0, 0.0],
device="cpu" if in_meta_context() else None
).unsqueeze(1).expand(2, output_channel),
persistent=False,
)
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
@@ -883,6 +903,15 @@ class ResnetBlock3D(nn.Module):
self.scale_shift_table = nn.Parameter(
torch.randn(4, in_channels) / in_channels**0.5
)
else:
self.register_buffer(
"scale_shift_table",
torch.tensor(
[0.0, 0.0, 0.0, 0.0],
device="cpu" if in_meta_context() else None
).unsqueeze(1).expand(4, in_channels),
persistent=False,
)
self.temporal_cache_state={}
@@ -1012,9 +1041,6 @@ class processor(nn.Module):
super().__init__()
self.register_buffer("std-of-means", torch.empty(128))
self.register_buffer("mean-of-means", torch.empty(128))
self.register_buffer("mean-of-stds", torch.empty(128))
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
self.register_buffer("channel", torch.empty(128))
def un_normalize(self, x):
    """Scale ``x`` by the "std-of-means" buffer and shift it by "mean-of-means".

    Both buffers are reshaped to ``(1, -1, 1, 1, 1)`` so they broadcast along
    dimension 1 of ``x``, and are converted with ``.to(x)`` before use.
    """
    broadcast_shape = (1, -1, 1, 1, 1)
    scale = self.get_buffer("std-of-means").view(*broadcast_shape).to(x)
    offset = self.get_buffer("mean-of-means").view(*broadcast_shape).to(x)
    return (x * scale) + offset
@@ -1027,9 +1053,12 @@ class VideoVAE(nn.Module):
super().__init__()
if config is None:
config = self.guess_config(version)
config = self.get_default_config(version)
self.config = config
self.timestep_conditioning = config.get("timestep_conditioning", False)
self.decode_noise_scale = config.get("decode_noise_scale", 0.025)
self.decode_timestep = config.get("decode_timestep", 0.05)
double_z = config.get("double_z", True)
latent_log_var = config.get(
"latent_log_var", "per_channel" if double_z else "none"
@@ -1044,6 +1073,7 @@ class VideoVAE(nn.Module):
latent_log_var=latent_log_var,
norm_layer=config.get("norm_layer", "group_norm"),
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
base_channels=config.get("encoder_base_channels", 128),
)
self.decoder = Decoder(
@@ -1051,6 +1081,7 @@ class VideoVAE(nn.Module):
in_channels=config["latent_channels"],
out_channels=config.get("out_channels", 3),
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
base_channels=config.get("decoder_base_channels", 128),
patch_size=config.get("patch_size", 1),
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
@@ -1060,7 +1091,7 @@ class VideoVAE(nn.Module):
self.per_channel_statistics = processor()
def guess_config(self, version):
def get_default_config(self, version):
if version == 0:
config = {
"_class_name": "CausalVideoAutoencoder",
@@ -1167,8 +1198,7 @@ class VideoVAE(nn.Module):
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)
def decode(self, x, timestep=None, noise_scale=None):
    """Un-normalize the latent ``x`` and run it through the decoder.

    Args:
        x: Latent tensor to decode.
        timestep: Optional override for the timestep passed to the decoder.
            When ``None``, ``self.decode_timestep`` is used.
        noise_scale: Optional override for the noise blend factor. When
            ``None``, ``self.decode_noise_scale`` is used.

    Returns:
        Whatever ``self.decoder`` returns for the un-normalized latent.

    The two optional arguments keep callers that still pass ``timestep`` /
    ``noise_scale`` explicitly working, while the defaults come from the
    instance attributes.
    """
    if timestep is None:
        timestep = self.decode_timestep
    if noise_scale is None:
        noise_scale = self.decode_noise_scale

    if self.timestep_conditioning:
        # TODO: seed -- the noise below is drawn from the global RNG.
        x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x

    latents = self.per_channel_statistics.un_normalize(x)
    return self.decoder(latents, timestep=timestep)