Stable Cascade Stage A.
comfy/ldm/cascade/stage_a.py (new file, 254 lines added)
@@ -0,0 +1,254 @@
"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Stability AI

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program. If not, see <https://www.gnu.org/licenses/>.
"""

import torch
from torch import nn
from torch.autograd import Function


class vector_quantize(Function):
    @staticmethod
    def forward(ctx, x, codebook):
        with torch.no_grad():
            codebook_sqr = torch.sum(codebook ** 2, dim=1)
            x_sqr = torch.sum(x ** 2, dim=1, keepdim=True)

            # squared euclidean distance from every input row to every codebook row
            dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0)
            _, indices = dist.min(dim=1)

            ctx.save_for_backward(indices, codebook)
            ctx.mark_non_differentiable(indices)

            # nearest-neighbour codebook entries for each input row
            nn_embeddings = torch.index_select(codebook, 0, indices)
            return nn_embeddings, indices

    @staticmethod
    def backward(ctx, grad_output, grad_indices):
        grad_inputs, grad_codebook = None, None

        if ctx.needs_input_grad[0]:
            # straight-through estimator: pass the gradient through unchanged
            grad_inputs = grad_output.clone()
        if ctx.needs_input_grad[1]:
            # Gradient wrt. the codebook
            indices, codebook = ctx.saved_tensors

            grad_codebook = torch.zeros_like(codebook)
            grad_codebook.index_add_(0, indices, grad_output)

        return (grad_inputs, grad_codebook)
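

# Illustrative sketch (not part of the original file): a minimal check of the
# straight-through behaviour of vector_quantize. The forward pass returns the
# nearest codebook rows and their indices; the backward pass copies the output
# gradient to the inputs unchanged and scatter-adds it into the selected
# codebook rows. The helper name and the sizes below are arbitrary examples.
def _demo_vector_quantize():
    codebook = torch.randn(16, 8, requires_grad=True)   # 16 codes of dimension 8
    x = torch.randn(4, 8, requires_grad=True)           # 4 query vectors
    quantized, indices = vector_quantize.apply(x, codebook)
    # Each output row equals the codebook row picked by the corresponding index.
    assert torch.equal(quantized, codebook[indices])
    # Straight-through estimator: gradients reach x as if quantization were the identity.
    quantized.sum().backward()
    assert torch.equal(x.grad, torch.ones_like(x))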


class VectorQuantize(nn.Module):
    def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
        """
        Takes an input of variable size (as long as the last dimension matches the embedding size).
        Returns a tensor containing the nearest-neighbour embeddings to each of the inputs,
        with the same size as the input, the vq and commitment components of the loss as a tuple
        in the second output, and the indices of the quantized vectors in the third:
        quantized, (vq_loss, commit_loss), indices
        """
        super(VectorQuantize, self).__init__()

        self.codebook = nn.Embedding(k, embedding_size)
        self.codebook.weight.data.uniform_(-1./k, 1./k)
        self.vq = vector_quantize.apply

        self.ema_decay = ema_decay
        self.ema_loss = ema_loss
        if ema_loss:
            self.register_buffer('ema_element_count', torch.ones(k))
            self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight))

    def _laplace_smoothing(self, x, epsilon):
        n = torch.sum(x)
        return ((x + epsilon) / (n + x.size(0) * epsilon) * n)

    def _updateEMA(self, z_e_x, indices):
        mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
        elem_count = mask.sum(dim=0)
        weight_sum = torch.mm(mask.t(), z_e_x)

        self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count)
        self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
        self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum)

        self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)

    def idx2vq(self, idx, dim=-1):
        q_idx = self.codebook(idx)
        if dim != -1:
            q_idx = q_idx.movedim(-1, dim)
        return q_idx

    def forward(self, x, get_losses=True, dim=-1):
        if dim != -1:
            x = x.movedim(dim, -1)
        z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x
        z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach())
        vq_loss, commit_loss = None, None
        if self.ema_loss and self.training:
            self._updateEMA(z_e_x.detach(), indices.detach())
        # pick the gradient-carrying embeddings after updating the codebook in order to have a more accurate commitment loss
        z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices)
        if get_losses:
            vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean()
            commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean()

        z_q_x = z_q_x.view(x.shape)
        if dim != -1:
            z_q_x = z_q_x.movedim(-1, dim)
        return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1])
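

# Illustrative sketch (not part of the original file): how VectorQuantize is
# meant to be called, following its docstring. Any input whose last dimension
# matches embedding_size works; the dim argument selects which axis holds the
# embeddings. The sizes below are arbitrary examples.
def _demo_vector_quantize_module():
    vq = VectorQuantize(embedding_size=4, k=8192)
    latent = torch.randn(1, 4, 32, 32)                        # (B, C, H, W) with C == embedding_size
    quantized, (vq_loss, commit_loss), indices = vq(latent, dim=1)
    assert quantized.shape == latent.shape                    # same shape as the input
    assert indices.shape == (1, 32, 32)                       # one codebook index per spatial position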


class ResBlock(nn.Module):
    def __init__(self, c, c_hidden):
        super().__init__()
        # depthwise/attention
        self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.depthwise = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(c, c, kernel_size=3, groups=c)
        )

        # channelwise
        self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.channelwise = nn.Sequential(
            nn.Linear(c, c_hidden),
            nn.GELU(),
            nn.Linear(c_hidden, c),
        )

        self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)

        # Init weights
        def _basic_init(module):
            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

    def _norm(self, x, norm):
        return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x):
        mods = self.gammas

        x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
        x = x + self.depthwise(x_temp) * mods[2]

        x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
        x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]

        return x
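

# Illustrative sketch (not part of the original file): ResBlock is
# shape-preserving, so it can be stacked freely in the encoder and decoder. It
# applies a modulated depthwise 3x3 convolution and a channelwise MLP, each
# behind a residual connection scaled by the learned gammas. Example sizes only.
def _demo_res_block():
    block = ResBlock(c=192, c_hidden=192 * 4)
    x = torch.randn(1, 192, 64, 64)
    assert block(x).shape == x.shape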


class StageA(nn.Module):
    def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192,
                 scale_factor=0.43):  # 0.3764
        super().__init__()
        self.c_latent = c_latent
        self.scale_factor = scale_factor
        c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]

        # Encoder blocks
        self.in_block = nn.Sequential(
            nn.PixelUnshuffle(2),
            nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
        )
        down_blocks = []
        for i in range(levels):
            if i > 0:
                down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
            block = ResBlock(c_levels[i], c_levels[i] * 4)
            down_blocks.append(block)
        down_blocks.append(nn.Sequential(
            nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_latent),  # then normalize them to have mean 0 and std 1
        ))
        self.down_blocks = nn.Sequential(*down_blocks)

        self.codebook_size = codebook_size
        self.vquantizer = VectorQuantize(c_latent, k=codebook_size)

        # Decoder blocks
        up_blocks = [nn.Sequential(
            nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
        )]
        for i in range(levels):
            for j in range(bottleneck_blocks if i == 0 else 1):
                block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
                up_blocks.append(block)
            if i < levels - 1:
                up_blocks.append(
                    nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
                                       padding=1))
        self.up_blocks = nn.Sequential(*up_blocks)
        self.out_block = nn.Sequential(
            nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
            nn.PixelShuffle(2),
        )

    def encode(self, x, quantize=False):
        x = self.in_block(x)
        x = self.down_blocks(x)
        if quantize:
            qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
            return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25
        else:
            return x / self.scale_factor

    def decode(self, x):
        x = x * self.scale_factor
        x = self.up_blocks(x)
        x = self.out_block(x)
        return x

    def forward(self, x, quantize=False):
        # NOTE: encode() only returns the 4-tuple unpacked here when quantize=True
        qe, x, _, vq_loss = self.encode(x, quantize)
        x = self.decode(qe)
        return x, vq_loss
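

# Illustrative sketch (not part of the original file): an encode/decode round
# trip through Stage A with the default configuration. The encoder
# pixel-unshuffles by 2 and then downsamples once more, so a 256x256 RGB image
# maps to a 4-channel 64x64 latent; the decoder inverts this. With the default
# quantize=False, encode() returns just the scaled latent, as used here.
# The helper name and the sizes are arbitrary examples.
def _demo_stage_a():
    ae = StageA()
    image = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        latent = ae.encode(image)                  # (1, 4, 64, 64), already divided by scale_factor
        recon = ae.decode(latent)                  # (1, 3, 256, 256)
    assert latent.shape == (1, 4, 64, 64)
    assert recon.shape == image.shape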


class Discriminator(nn.Module):
    def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
        super().__init__()
        d = max(depth - 3, 3)
        layers = [
            nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
            nn.LeakyReLU(0.2),
        ]
        for i in range(depth - 1):
            c_in = c_hidden // (2 ** max((d - i), 0))
            c_out = c_hidden // (2 ** max((d - 1 - i), 0))
            layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
            layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))
        self.encoder = nn.Sequential(*layers)
        self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
        self.logits = nn.Sigmoid()

    def forward(self, x, cond=None):
        x = self.encoder(x)
        if cond is not None:
            cond = cond.view(cond.size(0), cond.size(1), 1, 1).expand(-1, -1, x.size(-2), x.size(-1))
            x = torch.cat([x, cond], dim=1)
        x = self.shuffle(x)
        x = self.logits(x)
        return x
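

# Illustrative sketch (not part of the original file): the discriminator is a
# strided convolutional encoder with spectral norm, followed by a 1x1 "shuffle"
# convolution and a sigmoid, so it produces a patch-wise probability map. With
# the default depth of 6 the input is downsampled by a factor of 64. Example
# sizes only.
def _demo_discriminator():
    disc = Discriminator()
    image = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        score_map = disc(image)
    assert score_map.shape == (1, 1, 4, 4)
    assert float(score_map.min()) >= 0.0 and float(score_map.max()) <= 1.0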