Support the LTXAV 2.3 model. (#12773)
This commit is contained in:
@@ -3,6 +3,7 @@ import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
import comfy.ops
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
@@ -12,6 +13,307 @@ def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
|
||||
# Adopted from https://github.com/NVIDIA/BigVGAN
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _sinc(x: torch.Tensor):
|
||||
return torch.where(
|
||||
x == 0,
|
||||
torch.tensor(1.0, device=x.device, dtype=x.dtype),
|
||||
torch.sin(math.pi * x) / math.pi / x,
|
||||
)
|
||||
|
||||
|
||||
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
|
||||
even = kernel_size % 2 == 0
|
||||
half_size = kernel_size // 2
|
||||
delta_f = 4 * half_width
|
||||
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||
if A > 50.0:
|
||||
beta = 0.1102 * (A - 8.7)
|
||||
elif A >= 21.0:
|
||||
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
|
||||
else:
|
||||
beta = 0.0
|
||||
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||
if even:
|
||||
time = torch.arange(-half_size, half_size) + 0.5
|
||||
else:
|
||||
time = torch.arange(kernel_size) - half_size
|
||||
if cutoff == 0:
|
||||
filter_ = torch.zeros_like(time)
|
||||
else:
|
||||
filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
|
||||
filter_ /= filter_.sum()
|
||||
filter = filter_.view(1, 1, kernel_size)
|
||||
return filter
|
||||
|
||||
|
||||
class LowPassFilter1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cutoff=0.5,
|
||||
half_width=0.6,
|
||||
stride=1,
|
||||
padding=True,
|
||||
padding_mode="replicate",
|
||||
kernel_size=12,
|
||||
):
|
||||
super().__init__()
|
||||
if cutoff < -0.0:
|
||||
raise ValueError("Minimum cutoff must be larger than zero.")
|
||||
if cutoff > 0.5:
|
||||
raise ValueError("A cutoff above 0.5 does not make sense.")
|
||||
self.kernel_size = kernel_size
|
||||
self.even = kernel_size % 2 == 0
|
||||
self.pad_left = kernel_size // 2 - int(self.even)
|
||||
self.pad_right = kernel_size // 2
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.padding_mode = padding_mode
|
||||
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
||||
self.register_buffer("filter", filter)
|
||||
|
||||
def forward(self, x):
|
||||
_, C, _ = x.shape
|
||||
if self.padding:
|
||||
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
||||
return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None, persistent=True, window_type="kaiser"):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.stride = ratio
|
||||
|
||||
if window_type == "hann":
|
||||
# Hann-windowed sinc filter — identical to torchaudio.functional.resample
|
||||
# with its default parameters (rolloff=0.99, lowpass_filter_width=6).
|
||||
# Uses replicate boundary padding, matching the reference resampler exactly.
|
||||
rolloff = 0.99
|
||||
lowpass_filter_width = 6
|
||||
width = math.ceil(lowpass_filter_width / rolloff)
|
||||
self.kernel_size = 2 * width * ratio + 1
|
||||
self.pad = width
|
||||
self.pad_left = 2 * width * ratio
|
||||
self.pad_right = self.kernel_size - ratio
|
||||
t = (torch.arange(self.kernel_size) / ratio - width) * rolloff
|
||||
t_clamped = t.clamp(-lowpass_filter_width, lowpass_filter_width)
|
||||
window = torch.cos(t_clamped * math.pi / lowpass_filter_width / 2) ** 2
|
||||
filter = (torch.sinc(t) * window * rolloff / ratio).view(1, 1, -1)
|
||||
else:
|
||||
# Kaiser-windowed sinc filter (BigVGAN default).
|
||||
self.kernel_size = (
|
||||
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
)
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
self.pad_right = (
|
||||
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
)
|
||||
filter = kaiser_sinc_filter1d(
|
||||
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
|
||||
)
|
||||
|
||||
self.register_buffer("filter", filter, persistent=persistent)
|
||||
|
||||
def forward(self, x):
|
||||
_, C, _ = x.shape
|
||||
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
||||
x = self.ratio * F.conv_transpose1d(
|
||||
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
|
||||
)
|
||||
x = x[..., self.pad_left : -self.pad_right]
|
||||
return x
|
||||
|
||||
|
||||
class DownSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = (
|
||||
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
)
|
||||
self.lowpass = LowPassFilter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
stride=ratio,
|
||||
kernel_size=self.kernel_size,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.lowpass(x)
|
||||
|
||||
|
||||
class Activation1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
activation,
|
||||
up_ratio=2,
|
||||
down_ratio=2,
|
||||
up_kernel_size=12,
|
||||
down_kernel_size=12,
|
||||
):
|
||||
super().__init__()
|
||||
self.act = activation
|
||||
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.upsample(x)
|
||||
x = self.act(x)
|
||||
x = self.downsample(x)
|
||||
return x
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BigVGAN v2 activations (Snake / SnakeBeta)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Snake(nn.Module):
|
||||
def __init__(
|
||||
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
|
||||
):
|
||||
super().__init__()
|
||||
self.alpha_logscale = alpha_logscale
|
||||
self.alpha = nn.Parameter(
|
||||
torch.zeros(in_features)
|
||||
if alpha_logscale
|
||||
else torch.ones(in_features) * alpha
|
||||
)
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
self.eps = 1e-9
|
||||
|
||||
def forward(self, x):
|
||||
a = self.alpha.unsqueeze(0).unsqueeze(-1)
|
||||
if self.alpha_logscale:
|
||||
a = torch.exp(a)
|
||||
return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2)
|
||||
|
||||
|
||||
class SnakeBeta(nn.Module):
|
||||
def __init__(
|
||||
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
|
||||
):
|
||||
super().__init__()
|
||||
self.alpha_logscale = alpha_logscale
|
||||
self.alpha = nn.Parameter(
|
||||
torch.zeros(in_features)
|
||||
if alpha_logscale
|
||||
else torch.ones(in_features) * alpha
|
||||
)
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
self.beta = nn.Parameter(
|
||||
torch.zeros(in_features)
|
||||
if alpha_logscale
|
||||
else torch.ones(in_features) * alpha
|
||||
)
|
||||
self.beta.requires_grad = alpha_trainable
|
||||
self.eps = 1e-9
|
||||
|
||||
def forward(self, x):
|
||||
a = self.alpha.unsqueeze(0).unsqueeze(-1)
|
||||
b = self.beta.unsqueeze(0).unsqueeze(-1)
|
||||
if self.alpha_logscale:
|
||||
a = torch.exp(a)
|
||||
b = torch.exp(b)
|
||||
return x + (1.0 / (b + self.eps)) * torch.sin(x * a).pow(2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BigVGAN v2 AMPBlock (Anti-aliased Multi-Periodicity)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AMPBlock1(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation="snake"):
|
||||
super().__init__()
|
||||
act_cls = SnakeBeta if activation == "snakebeta" else Snake
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.acts1 = nn.ModuleList(
|
||||
[Activation1d(act_cls(channels)) for _ in range(len(self.convs1))]
|
||||
)
|
||||
self.acts2 = nn.ModuleList(
|
||||
[Activation1d(act_cls(channels)) for _ in range(len(self.convs2))]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2):
|
||||
xt = a1(x)
|
||||
xt = c1(xt)
|
||||
xt = a2(xt)
|
||||
xt = c2(xt)
|
||||
x = x + xt
|
||||
return x
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HiFi-GAN residual blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super(ResBlock1, self).__init__()
|
||||
@@ -119,6 +421,7 @@ class Vocoder(torch.nn.Module):
|
||||
"""
|
||||
Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.
|
||||
|
||||
Supports both HiFi-GAN (resblock "1"/"2") and BigVGAN v2 (resblock "AMP1").
|
||||
"""
|
||||
|
||||
def __init__(self, config=None):
|
||||
@@ -128,19 +431,39 @@ class Vocoder(torch.nn.Module):
|
||||
config = self.get_default_config()
|
||||
|
||||
resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
|
||||
upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2])
|
||||
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4])
|
||||
upsample_rates = config.get("upsample_rates", [5, 4, 2, 2, 2])
|
||||
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 16, 8, 4, 4])
|
||||
resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
|
||||
upsample_initial_channel = config.get("upsample_initial_channel", 1024)
|
||||
stereo = config.get("stereo", True)
|
||||
resblock = config.get("resblock", "1")
|
||||
activation = config.get("activation", "snake")
|
||||
use_bias_at_final = config.get("use_bias_at_final", True)
|
||||
|
||||
|
||||
# "output_sample_rate" is not present in recent checkpoint configs.
|
||||
# When absent (None), AudioVAE.output_sample_rate computes it as:
|
||||
# sample_rate * vocoder.upsample_factor / mel_hop_length
|
||||
# where upsample_factor = product of all upsample stride lengths,
|
||||
# and mel_hop_length is loaded from the autoencoder config at
|
||||
# preprocessing.stft.hop_length (see CausalAudioAutoencoder).
|
||||
self.output_sample_rate = config.get("output_sample_rate")
|
||||
self.resblock = config.get("resblock", "1")
|
||||
self.use_tanh_at_final = config.get("use_tanh_at_final", True)
|
||||
self.apply_final_activation = config.get("apply_final_activation", True)
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
|
||||
in_channels = 128 if stereo else 64
|
||||
self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
|
||||
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
|
||||
|
||||
if self.resblock == "1":
|
||||
resblock_cls = ResBlock1
|
||||
elif self.resblock == "2":
|
||||
resblock_cls = ResBlock2
|
||||
elif self.resblock == "AMP1":
|
||||
resblock_cls = AMPBlock1
|
||||
else:
|
||||
raise ValueError(f"Unknown resblock type: {self.resblock}")
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
@@ -157,25 +480,40 @@ class Vocoder(torch.nn.Module):
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock_class(ch, k, d))
|
||||
for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
|
||||
if self.resblock == "AMP1":
|
||||
self.resblocks.append(resblock_cls(ch, k, d, activation=activation))
|
||||
else:
|
||||
self.resblocks.append(resblock_cls(ch, k, d))
|
||||
|
||||
out_channels = 2 if stereo else 1
|
||||
self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3)
|
||||
if self.resblock == "AMP1":
|
||||
act_cls = SnakeBeta if activation == "snakebeta" else Snake
|
||||
self.act_post = Activation1d(act_cls(ch))
|
||||
else:
|
||||
self.act_post = nn.LeakyReLU()
|
||||
|
||||
self.conv_post = ops.Conv1d(
|
||||
ch, out_channels, 7, 1, padding=3, bias=use_bias_at_final
|
||||
)
|
||||
|
||||
self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])
|
||||
|
||||
|
||||
def get_default_config(self):
|
||||
"""Generate default configuration for the vocoder."""
|
||||
|
||||
config = {
|
||||
"resblock_kernel_sizes": [3, 7, 11],
|
||||
"upsample_rates": [6, 5, 2, 2, 2],
|
||||
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
|
||||
"upsample_rates": [5, 4, 2, 2, 2],
|
||||
"upsample_kernel_sizes": [16, 16, 8, 4, 4],
|
||||
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"upsample_initial_channel": 1024,
|
||||
"stereo": True,
|
||||
"resblock": "1",
|
||||
"activation": "snake",
|
||||
"use_bias_at_final": True,
|
||||
"use_tanh_at_final": True,
|
||||
}
|
||||
|
||||
return config
|
||||
@@ -196,8 +534,10 @@ class Vocoder(torch.nn.Module):
|
||||
assert x.shape[1] == 2, "Input must have 2 channels for stereo"
|
||||
x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
|
||||
x = self.conv_pre(x)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
if self.resblock != "AMP1":
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
@@ -206,8 +546,167 @@ class Vocoder(torch.nn.Module):
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
x = F.leaky_relu(x)
|
||||
|
||||
x = self.act_post(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
if self.apply_final_activation:
|
||||
if self.use_tanh_at_final:
|
||||
x = torch.tanh(x)
|
||||
else:
|
||||
x = torch.clamp(x, -1, 1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class _STFTFn(nn.Module):
|
||||
"""Implements STFT as a convolution with precomputed DFT × Hann-window bases.
|
||||
|
||||
The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
|
||||
Hann window are stored as buffers and loaded from the checkpoint. Using the exact
|
||||
bfloat16 bases from training ensures the mel values fed to the BWE generator are
|
||||
bit-identical to what it was trained on.
|
||||
"""
|
||||
|
||||
def __init__(self, filter_length: int, hop_length: int, win_length: int):
|
||||
super().__init__()
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
n_freqs = filter_length // 2 + 1
|
||||
self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
|
||||
self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))
|
||||
|
||||
def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute magnitude and phase spectrogram from a batch of waveforms.
|
||||
|
||||
Applies causal (left-only) padding of win_length - hop_length samples so that
|
||||
each output frame depends only on past and present input — no lookahead.
|
||||
The STFT is computed by convolving the padded signal with forward_basis.
|
||||
|
||||
Args:
|
||||
y: Waveform tensor of shape (B, T).
|
||||
|
||||
Returns:
|
||||
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
|
||||
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
|
||||
Computed in float32 for numerical stability, then cast back to
|
||||
the input dtype.
|
||||
"""
|
||||
if y.dim() == 2:
|
||||
y = y.unsqueeze(1) # (B, 1, T)
|
||||
left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
|
||||
y = F.pad(y, (left_pad, 0))
|
||||
spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
|
||||
n_freqs = spec.shape[1] // 2
|
||||
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
|
||||
magnitude = torch.sqrt(real ** 2 + imag ** 2)
|
||||
phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
|
||||
return magnitude, phase
|
||||
|
||||
|
||||
class MelSTFT(nn.Module):
|
||||
"""Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.
|
||||
|
||||
Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input
|
||||
waveform and projecting the linear magnitude spectrum onto the mel filterbank.
|
||||
|
||||
The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint
|
||||
(mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter_length: int,
|
||||
hop_length: int,
|
||||
win_length: int,
|
||||
n_mel_channels: int,
|
||||
sampling_rate: int,
|
||||
mel_fmin: float,
|
||||
mel_fmax: float,
|
||||
):
|
||||
super().__init__()
|
||||
self.stft_fn = _STFTFn(filter_length, hop_length, win_length)
|
||||
|
||||
n_freqs = filter_length // 2 + 1
|
||||
self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs))
|
||||
|
||||
def mel_spectrogram(
|
||||
self, y: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Compute log-mel spectrogram and auxiliary spectral quantities.
|
||||
|
||||
Args:
|
||||
y: Waveform tensor of shape (B, T).
|
||||
|
||||
Returns:
|
||||
log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
|
||||
Computed as log(clamp(mel_basis @ magnitude, min=1e-5)).
|
||||
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
|
||||
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
|
||||
energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
|
||||
"""
|
||||
magnitude, phase = self.stft_fn(y)
|
||||
energy = torch.norm(magnitude, dim=1)
|
||||
mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
|
||||
log_mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||
return log_mel, magnitude, phase, energy
|
||||
|
||||
|
||||
class VocoderWithBWE(torch.nn.Module):
|
||||
"""Vocoder with bandwidth extension (BWE) for higher sample rate output.
|
||||
|
||||
Chains a base vocoder (mel → low-rate waveform) with a BWE stage that upsamples
|
||||
to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
vocoder_config = config["vocoder"]
|
||||
bwe_config = config["bwe"]
|
||||
|
||||
self.vocoder = Vocoder(config=vocoder_config)
|
||||
self.bwe_generator = Vocoder(
|
||||
config={**bwe_config, "apply_final_activation": False}
|
||||
)
|
||||
|
||||
self.input_sample_rate = bwe_config["input_sampling_rate"]
|
||||
self.output_sample_rate = bwe_config["output_sampling_rate"]
|
||||
self.hop_length = bwe_config["hop_length"]
|
||||
|
||||
self.mel_stft = MelSTFT(
|
||||
filter_length=bwe_config["n_fft"],
|
||||
hop_length=bwe_config["hop_length"],
|
||||
win_length=bwe_config["n_fft"],
|
||||
n_mel_channels=bwe_config["num_mels"],
|
||||
sampling_rate=bwe_config["input_sampling_rate"],
|
||||
mel_fmin=0.0,
|
||||
mel_fmax=bwe_config["input_sampling_rate"] / 2.0,
|
||||
)
|
||||
self.resampler = UpSample1d(
|
||||
ratio=bwe_config["output_sampling_rate"] // bwe_config["input_sampling_rate"],
|
||||
persistent=False,
|
||||
window_type="hann",
|
||||
)
|
||||
|
||||
def _compute_mel(self, audio):
|
||||
"""Compute log-mel spectrogram from waveform using causal STFT bases."""
|
||||
B, C, T = audio.shape
|
||||
flat = audio.reshape(B * C, -1) # (B*C, T)
|
||||
mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames)
|
||||
return mel.reshape(B, C, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames)
|
||||
|
||||
def forward(self, mel_spec):
|
||||
x = self.vocoder(mel_spec)
|
||||
_, _, T_low = x.shape
|
||||
T_out = T_low * self.output_sample_rate // self.input_sample_rate
|
||||
|
||||
remainder = T_low % self.hop_length
|
||||
if remainder != 0:
|
||||
x = F.pad(x, (0, self.hop_length - remainder))
|
||||
|
||||
mel = self._compute_mel(x)
|
||||
residual = self.bwe_generator(mel)
|
||||
skip = self.resampler(x)
|
||||
assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"
|
||||
|
||||
return torch.clamp(residual + skip, -1, 1)[..., :T_out]
|
||||
|
||||
Reference in New Issue
Block a user