Make the ltx audio vae more native. (#13486)
This commit is contained in:
@@ -104,7 +104,7 @@ def vae_decode_audio(vae, samples, tile=None, overlap=None):
|
||||
std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
|
||||
std[std < 1.0] = 1.0
|
||||
audio /= std
|
||||
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
|
||||
vae_sample_rate = getattr(vae, "audio_sample_rate_output", getattr(vae, "audio_sample_rate", 44100))
|
||||
return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}
|
||||
|
||||
|
||||
|
||||
@@ -3,9 +3,8 @@ import comfy.utils
|
||||
import comfy.model_management
|
||||
import torch
|
||||
|
||||
from comfy.ldm.lightricks.vae.audio_vae import AudioVAE
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
from comfy_extras.nodes_audio import VAEEncodeAudio
|
||||
|
||||
class LTXVAudioVAELoader(io.ComfyNode):
|
||||
@classmethod
|
||||
@@ -28,10 +27,14 @@ class LTXVAudioVAELoader(io.ComfyNode):
|
||||
def execute(cls, ckpt_name: str) -> io.NodeOutput:
|
||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||
return io.NodeOutput(AudioVAE(sd, metadata))
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder.", "vocoder.": "vocoder."}, filter_keys=True)
|
||||
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
|
||||
vae.throw_exception_if_invalid()
|
||||
|
||||
return io.NodeOutput(vae)
|
||||
|
||||
|
||||
class LTXVAudioVAEEncode(io.ComfyNode):
|
||||
class LTXVAudioVAEEncode(VAEEncodeAudio):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
@@ -50,15 +53,8 @@ class LTXVAudioVAEEncode(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput:
|
||||
audio_latents = audio_vae.encode(audio)
|
||||
return io.NodeOutput(
|
||||
{
|
||||
"samples": audio_latents,
|
||||
"sample_rate": int(audio_vae.sample_rate),
|
||||
"type": "audio",
|
||||
}
|
||||
)
|
||||
def execute(cls, audio, audio_vae) -> io.NodeOutput:
|
||||
return super().execute(audio_vae, audio)
|
||||
|
||||
|
||||
class LTXVAudioVAEDecode(io.ComfyNode):
|
||||
@@ -80,12 +76,12 @@ class LTXVAudioVAEDecode(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput:
|
||||
def execute(cls, samples, audio_vae) -> io.NodeOutput:
|
||||
audio_latent = samples["samples"]
|
||||
if audio_latent.is_nested:
|
||||
audio_latent = audio_latent.unbind()[-1]
|
||||
audio = audio_vae.decode(audio_latent).to(audio_latent.device)
|
||||
output_audio_sample_rate = audio_vae.output_sample_rate
|
||||
audio = audio_vae.decode(audio_latent).movedim(-1, 1).to(audio_latent.device)
|
||||
output_audio_sample_rate = audio_vae.first_stage_model.output_sample_rate
|
||||
return io.NodeOutput(
|
||||
{
|
||||
"waveform": audio,
|
||||
@@ -143,17 +139,17 @@ class LTXVEmptyLatentAudio(io.ComfyNode):
|
||||
frames_number: int,
|
||||
frame_rate: int,
|
||||
batch_size: int,
|
||||
audio_vae: AudioVAE,
|
||||
audio_vae,
|
||||
) -> io.NodeOutput:
|
||||
"""Generate empty audio latents matching the reference pipeline structure."""
|
||||
|
||||
assert audio_vae is not None, "Audio VAE model is required"
|
||||
|
||||
z_channels = audio_vae.latent_channels
|
||||
audio_freq = audio_vae.latent_frequency_bins
|
||||
sampling_rate = int(audio_vae.sample_rate)
|
||||
audio_freq = audio_vae.first_stage_model.latent_frequency_bins
|
||||
sampling_rate = int(audio_vae.first_stage_model.sample_rate)
|
||||
|
||||
num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate)
|
||||
num_audio_latents = audio_vae.first_stage_model.num_of_latents_from_frames(frames_number, frame_rate)
|
||||
|
||||
audio_latents = torch.zeros(
|
||||
(batch_size, z_channels, num_audio_latents, audio_freq),
|
||||
|
||||
Reference in New Issue
Block a user