Support the LTXAV 2.3 model. (#12773)
This commit is contained in:
@@ -15,6 +15,9 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def in_meta_context():
|
||||
return torch.device("meta") == torch.empty(0).device
|
||||
|
||||
def mark_conv3d_ended(module):
|
||||
tid = threading.get_ident()
|
||||
for _, m in module.named_modules():
|
||||
@@ -350,6 +353,10 @@ class Decoder(nn.Module):
|
||||
output_channel = output_channel * block_params.get("multiplier", 2)
|
||||
if block_name == "compress_all":
|
||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||
if block_name == "compress_space":
|
||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||
if block_name == "compress_time":
|
||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||
|
||||
self.conv_in = make_conv_nd(
|
||||
dims,
|
||||
@@ -395,17 +402,21 @@ class Decoder(nn.Module):
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
output_channel = output_channel // block_params.get("multiplier", 1)
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
stride=(2, 1, 1),
|
||||
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
output_channel = output_channel // block_params.get("multiplier", 1)
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
stride=(1, 2, 2),
|
||||
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
@@ -455,6 +466,15 @@ class Decoder(nn.Module):
|
||||
output_channel * 2, 0, operations=ops,
|
||||
)
|
||||
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
|
||||
else:
|
||||
self.register_buffer(
|
||||
"last_scale_shift_table",
|
||||
torch.tensor(
|
||||
[0.0, 0.0],
|
||||
device="cpu" if in_meta_context() else None
|
||||
).unsqueeze(1).expand(2, output_channel),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
|
||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||
@@ -883,6 +903,15 @@ class ResnetBlock3D(nn.Module):
|
||||
self.scale_shift_table = nn.Parameter(
|
||||
torch.randn(4, in_channels) / in_channels**0.5
|
||||
)
|
||||
else:
|
||||
self.register_buffer(
|
||||
"scale_shift_table",
|
||||
torch.tensor(
|
||||
[0.0, 0.0, 0.0, 0.0],
|
||||
device="cpu" if in_meta_context() else None
|
||||
).unsqueeze(1).expand(4, in_channels),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
self.temporal_cache_state={}
|
||||
|
||||
@@ -1012,9 +1041,6 @@ class processor(nn.Module):
|
||||
super().__init__()
|
||||
self.register_buffer("std-of-means", torch.empty(128))
|
||||
self.register_buffer("mean-of-means", torch.empty(128))
|
||||
self.register_buffer("mean-of-stds", torch.empty(128))
|
||||
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
|
||||
self.register_buffer("channel", torch.empty(128))
|
||||
|
||||
def un_normalize(self, x):
|
||||
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||
@@ -1027,9 +1053,12 @@ class VideoVAE(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = self.guess_config(version)
|
||||
config = self.get_default_config(version)
|
||||
|
||||
self.config = config
|
||||
self.timestep_conditioning = config.get("timestep_conditioning", False)
|
||||
self.decode_noise_scale = config.get("decode_noise_scale", 0.025)
|
||||
self.decode_timestep = config.get("decode_timestep", 0.05)
|
||||
double_z = config.get("double_z", True)
|
||||
latent_log_var = config.get(
|
||||
"latent_log_var", "per_channel" if double_z else "none"
|
||||
@@ -1044,6 +1073,7 @@ class VideoVAE(nn.Module):
|
||||
latent_log_var=latent_log_var,
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
|
||||
base_channels=config.get("encoder_base_channels", 128),
|
||||
)
|
||||
|
||||
self.decoder = Decoder(
|
||||
@@ -1051,6 +1081,7 @@ class VideoVAE(nn.Module):
|
||||
in_channels=config["latent_channels"],
|
||||
out_channels=config.get("out_channels", 3),
|
||||
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
|
||||
base_channels=config.get("decoder_base_channels", 128),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
causal=config.get("causal_decoder", False),
|
||||
@@ -1060,7 +1091,7 @@ class VideoVAE(nn.Module):
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
def guess_config(self, version):
|
||||
def get_default_config(self, version):
|
||||
if version == 0:
|
||||
config = {
|
||||
"_class_name": "CausalVideoAutoencoder",
|
||||
@@ -1167,8 +1198,7 @@ class VideoVAE(nn.Module):
|
||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
def decode(self, x, timestep=0.05, noise_scale=0.025):
|
||||
def decode(self, x):
|
||||
if self.timestep_conditioning: #TODO: seed
|
||||
x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)
|
||||
|
||||
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
|
||||
|
||||
Reference in New Issue
Block a user