Support the LTXAV 2.3 model. (#12773)
This commit is contained in:
@@ -822,26 +822,23 @@ class CausalAudioAutoencoder(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = self._guess_config()
|
||||
config = self.get_default_config()
|
||||
|
||||
# Extract encoder and decoder configs from the new format
|
||||
model_config = config.get("model", {}).get("params", {})
|
||||
variables_config = config.get("variables", {})
|
||||
|
||||
self.sampling_rate = variables_config.get(
|
||||
"sampling_rate",
|
||||
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
|
||||
self.sampling_rate = model_config.get(
|
||||
"sampling_rate", config.get("sampling_rate", 16000)
|
||||
)
|
||||
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
|
||||
decoder_config = model_config.get("decoder", encoder_config)
|
||||
|
||||
# Load mel spectrogram parameters
|
||||
self.mel_bins = encoder_config.get("mel_bins", 64)
|
||||
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
|
||||
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
|
||||
self.mel_hop_length = config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
|
||||
self.n_fft = config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
|
||||
|
||||
# Store causality configuration at VAE level (not just in encoder internals)
|
||||
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
|
||||
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.HEIGHT.value)
|
||||
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
|
||||
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
|
||||
|
||||
@@ -850,44 +847,38 @@ class CausalAudioAutoencoder(nn.Module):
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
def _guess_config(self):
|
||||
encoder_config = {
|
||||
# Required parameters - based on ltx-video-av-1679000 model metadata
|
||||
"ch": 128,
|
||||
"out_ch": 8,
|
||||
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": [], # Based on metadata: empty list, no attention
|
||||
"dropout": 0.0,
|
||||
"resamp_with_conv": True,
|
||||
"in_channels": 2, # stereo
|
||||
"resolution": 256,
|
||||
"z_channels": 8,
|
||||
def get_default_config(self):
|
||||
ddconfig = {
|
||||
"double_z": True,
|
||||
"attn_type": "vanilla",
|
||||
"mid_block_add_attention": False, # Based on metadata: false
|
||||
"mel_bins": 64,
|
||||
"z_channels": 8,
|
||||
"resolution": 256,
|
||||
"downsample_time": False,
|
||||
"in_channels": 2,
|
||||
"out_ch": 2,
|
||||
"ch": 128,
|
||||
"ch_mult": [1, 2, 4],
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": [],
|
||||
"dropout": 0.0,
|
||||
"mid_block_add_attention": False,
|
||||
"norm_type": "pixel",
|
||||
"causality_axis": "height", # Based on metadata
|
||||
"mel_bins": 64, # Based on metadata: mel_bins = 64
|
||||
}
|
||||
|
||||
decoder_config = {
|
||||
# Inherits encoder config, can override specific params
|
||||
**encoder_config,
|
||||
"out_ch": 2, # Stereo audio output (2 channels)
|
||||
"give_pre_end": False,
|
||||
"tanh_out": False,
|
||||
"causality_axis": "height",
|
||||
}
|
||||
|
||||
config = {
|
||||
"_class_name": "CausalAudioAutoencoder",
|
||||
"sampling_rate": 16000,
|
||||
"model": {
|
||||
"params": {
|
||||
"encoder": encoder_config,
|
||||
"decoder": decoder_config,
|
||||
"ddconfig": ddconfig,
|
||||
"sampling_rate": 16000,
|
||||
}
|
||||
},
|
||||
"preprocessing": {
|
||||
"stft": {
|
||||
"filter_length": 1024,
|
||||
"hop_length": 160,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
Reference in New Issue
Block a user