Implement the mmaudio VAE. (#10300)
This commit is contained in:
24
comfy/sd.py
24
comfy/sd.py
@@ -18,6 +18,7 @@ import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import comfy.ldm.hunyuan_video.vae
|
||||
import comfy.ldm.mmaudio.vae.autoencoder
|
||||
import comfy.pixel_space_convert
|
||||
import yaml
|
||||
import math
|
||||
@@ -291,6 +292,7 @@ class VAE:
|
||||
self.downscale_index_formula = None
|
||||
self.upscale_index_formula = None
|
||||
self.extra_1d_channel = None
|
||||
self.crop_input = True
|
||||
|
||||
if config is None:
|
||||
if "decoder.mid.block_1.mix_factor" in sd:
|
||||
@@ -542,6 +544,25 @@ class VAE:
|
||||
self.latent_channels = 3
|
||||
self.latent_dim = 2
|
||||
self.output_channels = 3
|
||||
elif "vocoder.activation_post.downsample.lowpass.filter" in sd: #MMAudio VAE
|
||||
sample_rate = 16000
|
||||
if sample_rate == 16000:
|
||||
mode = '16k'
|
||||
else:
|
||||
mode = '44k'
|
||||
|
||||
self.first_stage_model = comfy.ldm.mmaudio.vae.autoencoder.AudioAutoencoder(mode=mode)
|
||||
self.memory_used_encode = lambda shape, dtype: (30 * shape[2]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (90 * shape[2] * 1411.2) * model_management.dtype_size(dtype)
|
||||
self.latent_channels = 20
|
||||
self.output_channels = 2
|
||||
self.upscale_ratio = 512 * (44100 / sample_rate)
|
||||
self.downscale_ratio = 512 * (44100 / sample_rate)
|
||||
self.latent_dim = 1
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
self.working_dtypes = [torch.float32]
|
||||
self.crop_input = False
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
@@ -575,6 +596,9 @@ class VAE:
|
||||
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
|
||||
|
||||
def vae_encode_crop_pixels(self, pixels):
|
||||
if not self.crop_input:
|
||||
return pixels
|
||||
|
||||
downscale_ratio = self.spacial_compression_encode()
|
||||
|
||||
dims = pixels.shape[1:-1]
|
||||
|
||||
Reference in New Issue
Block a user