Initial commit.
This commit is contained in:
105
comfy/k_diffusion/augmentation.py
Normal file
105
comfy/k_diffusion/augmentation.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from functools import reduce
|
||||
import math
|
||||
import operator
|
||||
|
||||
import numpy as np
|
||||
from skimage import transform
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def translate2d(tx, ty):
|
||||
mat = [[1, 0, tx],
|
||||
[0, 1, ty],
|
||||
[0, 0, 1]]
|
||||
return torch.tensor(mat, dtype=torch.float32)
|
||||
|
||||
|
||||
def scale2d(sx, sy):
|
||||
mat = [[sx, 0, 0],
|
||||
[ 0, sy, 0],
|
||||
[ 0, 0, 1]]
|
||||
return torch.tensor(mat, dtype=torch.float32)
|
||||
|
||||
|
||||
def rotate2d(theta):
|
||||
mat = [[torch.cos(theta), torch.sin(-theta), 0],
|
||||
[torch.sin(theta), torch.cos(theta), 0],
|
||||
[ 0, 0, 1]]
|
||||
return torch.tensor(mat, dtype=torch.float32)
|
||||
|
||||
|
||||
class KarrasAugmentationPipeline:
|
||||
def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8):
|
||||
self.a_prob = a_prob
|
||||
self.a_scale = a_scale
|
||||
self.a_aniso = a_aniso
|
||||
self.a_trans = a_trans
|
||||
|
||||
def __call__(self, image):
|
||||
h, w = image.size
|
||||
mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)]
|
||||
|
||||
# x-flip
|
||||
a0 = torch.randint(2, []).float()
|
||||
mats.append(scale2d(1 - 2 * a0, 1))
|
||||
# y-flip
|
||||
do = (torch.rand([]) < self.a_prob).float()
|
||||
a1 = torch.randint(2, []).float() * do
|
||||
mats.append(scale2d(1, 1 - 2 * a1))
|
||||
# scaling
|
||||
do = (torch.rand([]) < self.a_prob).float()
|
||||
a2 = torch.randn([]) * do
|
||||
mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2))
|
||||
# rotation
|
||||
do = (torch.rand([]) < self.a_prob).float()
|
||||
a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do
|
||||
mats.append(rotate2d(-a3))
|
||||
# anisotropy
|
||||
do = (torch.rand([]) < self.a_prob).float()
|
||||
a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do
|
||||
a5 = torch.randn([]) * do
|
||||
mats.append(rotate2d(a4))
|
||||
mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5))
|
||||
mats.append(rotate2d(-a4))
|
||||
# translation
|
||||
do = (torch.rand([]) < self.a_prob).float()
|
||||
a6 = torch.randn([]) * do
|
||||
a7 = torch.randn([]) * do
|
||||
mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7))
|
||||
|
||||
# form the transformation matrix and conditioning vector
|
||||
mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5))
|
||||
mat = reduce(operator.matmul, mats)
|
||||
cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7])
|
||||
|
||||
# apply the transformation
|
||||
image_orig = np.array(image, dtype=np.float32) / 255
|
||||
if image_orig.ndim == 2:
|
||||
image_orig = image_orig[..., None]
|
||||
tf = transform.AffineTransform(mat.numpy())
|
||||
image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True)
|
||||
image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1
|
||||
image = torch.as_tensor(image).movedim(2, 0) * 2 - 1
|
||||
return image, image_orig, cond
|
||||
|
||||
|
||||
class KarrasAugmentWrapper(nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
|
||||
def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs):
|
||||
if aug_cond is None:
|
||||
aug_cond = input.new_zeros([input.shape[0], 9])
|
||||
if mapping_cond is None:
|
||||
mapping_cond = aug_cond
|
||||
else:
|
||||
mapping_cond = torch.cat([aug_cond, mapping_cond], dim=1)
|
||||
return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs)
|
||||
|
||||
def set_skip_stages(self, skip_stages):
|
||||
return self.inner_model.set_skip_stages(skip_stages)
|
||||
|
||||
def set_patch_size(self, patch_size):
|
||||
return self.inner_model.set_patch_size(patch_size)
|
||||
110
comfy/k_diffusion/config.py
Normal file
110
comfy/k_diffusion/config.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from functools import partial
|
||||
import json
|
||||
import math
|
||||
import warnings
|
||||
|
||||
from jsonmerge import merge
|
||||
|
||||
from . import augmentation, layers, models, utils
|
||||
|
||||
|
||||
def load_config(file):
|
||||
defaults = {
|
||||
'model': {
|
||||
'sigma_data': 1.,
|
||||
'patch_size': 1,
|
||||
'dropout_rate': 0.,
|
||||
'augment_wrapper': True,
|
||||
'augment_prob': 0.,
|
||||
'mapping_cond_dim': 0,
|
||||
'unet_cond_dim': 0,
|
||||
'cross_cond_dim': 0,
|
||||
'cross_attn_depths': None,
|
||||
'skip_stages': 0,
|
||||
'has_variance': False,
|
||||
},
|
||||
'dataset': {
|
||||
'type': 'imagefolder',
|
||||
},
|
||||
'optimizer': {
|
||||
'type': 'adamw',
|
||||
'lr': 1e-4,
|
||||
'betas': [0.95, 0.999],
|
||||
'eps': 1e-6,
|
||||
'weight_decay': 1e-3,
|
||||
},
|
||||
'lr_sched': {
|
||||
'type': 'inverse',
|
||||
'inv_gamma': 20000.,
|
||||
'power': 1.,
|
||||
'warmup': 0.99,
|
||||
},
|
||||
'ema_sched': {
|
||||
'type': 'inverse',
|
||||
'power': 0.6667,
|
||||
'max_value': 0.9999
|
||||
},
|
||||
}
|
||||
config = json.load(file)
|
||||
return merge(defaults, config)
|
||||
|
||||
|
||||
def make_model(config):
|
||||
config = config['model']
|
||||
assert config['type'] == 'image_v1'
|
||||
model = models.ImageDenoiserModelV1(
|
||||
config['input_channels'],
|
||||
config['mapping_out'],
|
||||
config['depths'],
|
||||
config['channels'],
|
||||
config['self_attn_depths'],
|
||||
config['cross_attn_depths'],
|
||||
patch_size=config['patch_size'],
|
||||
dropout_rate=config['dropout_rate'],
|
||||
mapping_cond_dim=config['mapping_cond_dim'] + (9 if config['augment_wrapper'] else 0),
|
||||
unet_cond_dim=config['unet_cond_dim'],
|
||||
cross_cond_dim=config['cross_cond_dim'],
|
||||
skip_stages=config['skip_stages'],
|
||||
has_variance=config['has_variance'],
|
||||
)
|
||||
if config['augment_wrapper']:
|
||||
model = augmentation.KarrasAugmentWrapper(model)
|
||||
return model
|
||||
|
||||
|
||||
def make_denoiser_wrapper(config):
|
||||
config = config['model']
|
||||
sigma_data = config.get('sigma_data', 1.)
|
||||
has_variance = config.get('has_variance', False)
|
||||
if not has_variance:
|
||||
return partial(layers.Denoiser, sigma_data=sigma_data)
|
||||
return partial(layers.DenoiserWithVariance, sigma_data=sigma_data)
|
||||
|
||||
|
||||
def make_sample_density(config):
|
||||
sd_config = config['sigma_sample_density']
|
||||
sigma_data = config['sigma_data']
|
||||
if sd_config['type'] == 'lognormal':
|
||||
loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
|
||||
scale = sd_config['std'] if 'std' in sd_config else sd_config['scale']
|
||||
return partial(utils.rand_log_normal, loc=loc, scale=scale)
|
||||
if sd_config['type'] == 'loglogistic':
|
||||
loc = sd_config['loc'] if 'loc' in sd_config else math.log(sigma_data)
|
||||
scale = sd_config['scale'] if 'scale' in sd_config else 0.5
|
||||
min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
|
||||
max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
|
||||
return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value)
|
||||
if sd_config['type'] == 'loguniform':
|
||||
min_value = sd_config['min_value'] if 'min_value' in sd_config else config['sigma_min']
|
||||
max_value = sd_config['max_value'] if 'max_value' in sd_config else config['sigma_max']
|
||||
return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value)
|
||||
if sd_config['type'] == 'v-diffusion':
|
||||
min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
|
||||
max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
|
||||
return partial(utils.rand_v_diffusion, sigma_data=sigma_data, min_value=min_value, max_value=max_value)
|
||||
if sd_config['type'] == 'split-lognormal':
|
||||
loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
|
||||
scale_1 = sd_config['std_1'] if 'std_1' in sd_config else sd_config['scale_1']
|
||||
scale_2 = sd_config['std_2'] if 'std_2' in sd_config else sd_config['scale_2']
|
||||
return partial(utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2)
|
||||
raise ValueError('Unknown sample density type')
|
||||
134
comfy/k_diffusion/evaluation.py
Normal file
134
comfy/k_diffusion/evaluation.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from cleanfid.inception_torchscript import InceptionV3W
|
||||
import clip
|
||||
from resize_right import resize
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import trange
|
||||
|
||||
from . import utils
|
||||
|
||||
|
||||
class InceptionV3FeatureExtractor(nn.Module):
|
||||
def __init__(self, device='cpu'):
|
||||
super().__init__()
|
||||
path = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache')) / 'k-diffusion'
|
||||
url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
|
||||
digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4'
|
||||
utils.download_file(path / 'inception-2015-12-05.pt', url, digest)
|
||||
self.model = InceptionV3W(str(path), resize_inside=False).to(device)
|
||||
self.size = (299, 299)
|
||||
|
||||
def forward(self, x):
|
||||
if x.shape[2:4] != self.size:
|
||||
x = resize(x, out_shape=self.size, pad_mode='reflect')
|
||||
if x.shape[1] == 1:
|
||||
x = torch.cat([x] * 3, dim=1)
|
||||
x = (x * 127.5 + 127.5).clamp(0, 255)
|
||||
return self.model(x)
|
||||
|
||||
|
||||
class CLIPFeatureExtractor(nn.Module):
|
||||
def __init__(self, name='ViT-L/14@336px', device='cpu'):
|
||||
super().__init__()
|
||||
self.model = clip.load(name, device=device)[0].eval().requires_grad_(False)
|
||||
self.normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
|
||||
std=(0.26862954, 0.26130258, 0.27577711))
|
||||
self.size = (self.model.visual.input_resolution, self.model.visual.input_resolution)
|
||||
|
||||
def forward(self, x):
|
||||
if x.shape[2:4] != self.size:
|
||||
x = resize(x.add(1).div(2), out_shape=self.size, pad_mode='reflect').clamp(0, 1)
|
||||
x = self.normalize(x)
|
||||
x = self.model.encode_image(x).float()
|
||||
x = F.normalize(x) * x.shape[1] ** 0.5
|
||||
return x
|
||||
|
||||
|
||||
def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size):
|
||||
n_per_proc = math.ceil(n / accelerator.num_processes)
|
||||
feats_all = []
|
||||
try:
|
||||
for i in trange(0, n_per_proc, batch_size, disable=not accelerator.is_main_process):
|
||||
cur_batch_size = min(n - i, batch_size)
|
||||
samples = sample_fn(cur_batch_size)[:cur_batch_size]
|
||||
feats_all.append(accelerator.gather(extractor_fn(samples)))
|
||||
except StopIteration:
|
||||
pass
|
||||
return torch.cat(feats_all)[:n]
|
||||
|
||||
|
||||
def polynomial_kernel(x, y):
|
||||
d = x.shape[-1]
|
||||
dot = x @ y.transpose(-2, -1)
|
||||
return (dot / d + 1) ** 3
|
||||
|
||||
|
||||
def squared_mmd(x, y, kernel=polynomial_kernel):
|
||||
m = x.shape[-2]
|
||||
n = y.shape[-2]
|
||||
kxx = kernel(x, x)
|
||||
kyy = kernel(y, y)
|
||||
kxy = kernel(x, y)
|
||||
kxx_sum = kxx.sum([-1, -2]) - kxx.diagonal(dim1=-1, dim2=-2).sum(-1)
|
||||
kyy_sum = kyy.sum([-1, -2]) - kyy.diagonal(dim1=-1, dim2=-2).sum(-1)
|
||||
kxy_sum = kxy.sum([-1, -2])
|
||||
term_1 = kxx_sum / m / (m - 1)
|
||||
term_2 = kyy_sum / n / (n - 1)
|
||||
term_3 = kxy_sum * 2 / m / n
|
||||
return term_1 + term_2 - term_3
|
||||
|
||||
|
||||
@utils.tf32_mode(matmul=False)
|
||||
def kid(x, y, max_size=5000):
|
||||
x_size, y_size = x.shape[0], y.shape[0]
|
||||
n_partitions = math.ceil(max(x_size / max_size, y_size / max_size))
|
||||
total_mmd = x.new_zeros([])
|
||||
for i in range(n_partitions):
|
||||
cur_x = x[round(i * x_size / n_partitions):round((i + 1) * x_size / n_partitions)]
|
||||
cur_y = y[round(i * y_size / n_partitions):round((i + 1) * y_size / n_partitions)]
|
||||
total_mmd = total_mmd + squared_mmd(cur_x, cur_y)
|
||||
return total_mmd / n_partitions
|
||||
|
||||
|
||||
class _MatrixSquareRootEig(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, a):
|
||||
vals, vecs = torch.linalg.eigh(a)
|
||||
ctx.save_for_backward(vals, vecs)
|
||||
return vecs @ vals.abs().sqrt().diag_embed() @ vecs.transpose(-2, -1)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
vals, vecs = ctx.saved_tensors
|
||||
d = vals.abs().sqrt().unsqueeze(-1).repeat_interleave(vals.shape[-1], -1)
|
||||
vecs_t = vecs.transpose(-2, -1)
|
||||
return vecs @ (vecs_t @ grad_output @ vecs / (d + d.transpose(-2, -1))) @ vecs_t
|
||||
|
||||
|
||||
def sqrtm_eig(a):
|
||||
if a.ndim < 2:
|
||||
raise RuntimeError('tensor of matrices must have at least 2 dimensions')
|
||||
if a.shape[-2] != a.shape[-1]:
|
||||
raise RuntimeError('tensor must be batches of square matrices')
|
||||
return _MatrixSquareRootEig.apply(a)
|
||||
|
||||
|
||||
@utils.tf32_mode(matmul=False)
|
||||
def fid(x, y, eps=1e-8):
|
||||
x_mean = x.mean(dim=0)
|
||||
y_mean = y.mean(dim=0)
|
||||
mean_term = (x_mean - y_mean).pow(2).sum()
|
||||
x_cov = torch.cov(x.T)
|
||||
y_cov = torch.cov(y.T)
|
||||
eps_eye = torch.eye(x_cov.shape[0], device=x_cov.device, dtype=x_cov.dtype) * eps
|
||||
x_cov = x_cov + eps_eye
|
||||
y_cov = y_cov + eps_eye
|
||||
x_cov_sqrt = sqrtm_eig(x_cov)
|
||||
cov_term = torch.trace(x_cov + y_cov - 2 * sqrtm_eig(x_cov_sqrt @ y_cov @ x_cov_sqrt))
|
||||
return mean_term + cov_term
|
||||
179
comfy/k_diffusion/external.py
Normal file
179
comfy/k_diffusion/external.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from . import sampling, utils
|
||||
|
||||
|
||||
class VDenoiser(nn.Module):
|
||||
"""A v-diffusion-pytorch model wrapper for k-diffusion."""
|
||||
|
||||
def __init__(self, inner_model):
|
||||
super().__init__()
|
||||
self.inner_model = inner_model
|
||||
self.sigma_data = 1.
|
||||
|
||||
def get_scalings(self, sigma):
|
||||
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
|
||||
c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
return c_skip, c_out, c_in
|
||||
|
||||
def sigma_to_t(self, sigma):
|
||||
return sigma.atan() / math.pi * 2
|
||||
|
||||
def t_to_sigma(self, t):
|
||||
return (t * math.pi / 2).tan()
|
||||
|
||||
def loss(self, input, noise, sigma, **kwargs):
|
||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
|
||||
model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
|
||||
target = (input - c_skip * noised_input) / c_out
|
||||
return (model_output - target).pow(2).flatten(1).mean(1)
|
||||
|
||||
def forward(self, input, sigma, **kwargs):
|
||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
|
||||
|
||||
|
||||
class DiscreteSchedule(nn.Module):
|
||||
"""A mapping between continuous noise levels (sigmas) and a list of discrete noise
|
||||
levels."""
|
||||
|
||||
def __init__(self, sigmas, quantize):
|
||||
super().__init__()
|
||||
self.register_buffer('sigmas', sigmas)
|
||||
self.register_buffer('log_sigmas', sigmas.log())
|
||||
self.quantize = quantize
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
||||
@property
|
||||
def sigma_max(self):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def get_sigmas(self, n=None):
|
||||
if n is None:
|
||||
return sampling.append_zero(self.sigmas.flip(0))
|
||||
t_max = len(self.sigmas) - 1
|
||||
t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
|
||||
return sampling.append_zero(self.t_to_sigma(t))
|
||||
|
||||
def sigma_to_t(self, sigma, quantize=None):
|
||||
quantize = self.quantize if quantize is None else quantize
|
||||
log_sigma = sigma.log()
|
||||
dists = log_sigma - self.log_sigmas[:, None]
|
||||
if quantize:
|
||||
return dists.abs().argmin(dim=0).view(sigma.shape)
|
||||
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
|
||||
high_idx = low_idx + 1
|
||||
low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
|
||||
w = (low - log_sigma) / (low - high)
|
||||
w = w.clamp(0, 1)
|
||||
t = (1 - w) * low_idx + w * high_idx
|
||||
return t.view(sigma.shape)
|
||||
|
||||
def t_to_sigma(self, t):
|
||||
t = t.float()
|
||||
low_idx = t.floor().long()
|
||||
high_idx = t.ceil().long()
|
||||
w = t-low_idx if t.device.type == 'mps' else t.frac()
|
||||
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
||||
return log_sigma.exp()
|
||||
|
||||
|
||||
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
|
||||
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
|
||||
noise)."""
|
||||
|
||||
def __init__(self, model, alphas_cumprod, quantize):
|
||||
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
|
||||
self.inner_model = model
|
||||
self.sigma_data = 1.
|
||||
|
||||
def get_scalings(self, sigma):
|
||||
c_out = -sigma
|
||||
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
return c_out, c_in
|
||||
|
||||
def get_eps(self, *args, **kwargs):
|
||||
return self.inner_model(*args, **kwargs)
|
||||
|
||||
def loss(self, input, noise, sigma, **kwargs):
|
||||
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
|
||||
eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
|
||||
return (eps - noise).pow(2).flatten(1).mean(1)
|
||||
|
||||
def forward(self, input, sigma, **kwargs):
|
||||
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
|
||||
return input + eps * c_out
|
||||
|
||||
|
||||
class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
|
||||
"""A wrapper for OpenAI diffusion models."""
|
||||
|
||||
def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
|
||||
alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32)
|
||||
super().__init__(model, alphas_cumprod, quantize=quantize)
|
||||
self.has_learned_sigmas = has_learned_sigmas
|
||||
|
||||
def get_eps(self, *args, **kwargs):
|
||||
model_output = self.inner_model(*args, **kwargs)
|
||||
if self.has_learned_sigmas:
|
||||
return model_output.chunk(2, dim=1)[0]
|
||||
return model_output
|
||||
|
||||
|
||||
class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
|
||||
"""A wrapper for CompVis diffusion models."""
|
||||
|
||||
def __init__(self, model, quantize=False, device='cpu'):
|
||||
super().__init__(model, model.alphas_cumprod, quantize=quantize)
|
||||
|
||||
def get_eps(self, *args, **kwargs):
|
||||
return self.inner_model.apply_model(*args, **kwargs)
|
||||
|
||||
|
||||
class DiscreteVDDPMDenoiser(DiscreteSchedule):
|
||||
"""A wrapper for discrete schedule DDPM models that output v."""
|
||||
|
||||
def __init__(self, model, alphas_cumprod, quantize):
|
||||
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
|
||||
self.inner_model = model
|
||||
self.sigma_data = 1.
|
||||
|
||||
def get_scalings(self, sigma):
|
||||
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
|
||||
c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
return c_skip, c_out, c_in
|
||||
|
||||
def get_v(self, *args, **kwargs):
|
||||
return self.inner_model(*args, **kwargs)
|
||||
|
||||
def loss(self, input, noise, sigma, **kwargs):
|
||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
|
||||
model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
|
||||
target = (input - c_skip * noised_input) / c_out
|
||||
return (model_output - target).pow(2).flatten(1).mean(1)
|
||||
|
||||
def forward(self, input, sigma, **kwargs):
|
||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
|
||||
|
||||
|
||||
class CompVisVDenoiser(DiscreteVDDPMDenoiser):
|
||||
"""A wrapper for CompVis diffusion models that output v."""
|
||||
|
||||
def __init__(self, model, quantize=False, device='cpu'):
|
||||
super().__init__(model, model.alphas_cumprod, quantize=quantize)
|
||||
|
||||
def get_v(self, x, t, cond, **kwargs):
|
||||
return self.inner_model.apply_model(x, t, cond)
|
||||
99
comfy/k_diffusion/gns.py
Normal file
99
comfy/k_diffusion/gns.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class DDPGradientStatsHook:
|
||||
def __init__(self, ddp_module):
|
||||
try:
|
||||
ddp_module.register_comm_hook(self, self._hook_fn)
|
||||
except AttributeError:
|
||||
raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules')
|
||||
self._clear_state()
|
||||
|
||||
def _clear_state(self):
|
||||
self.bucket_sq_norms_small_batch = []
|
||||
self.bucket_sq_norms_large_batch = []
|
||||
|
||||
@staticmethod
|
||||
def _hook_fn(self, bucket):
|
||||
buf = bucket.buffer()
|
||||
self.bucket_sq_norms_small_batch.append(buf.pow(2).sum())
|
||||
fut = torch.distributed.all_reduce(buf, op=torch.distributed.ReduceOp.AVG, async_op=True).get_future()
|
||||
def callback(fut):
|
||||
buf = fut.value()[0]
|
||||
self.bucket_sq_norms_large_batch.append(buf.pow(2).sum())
|
||||
return buf
|
||||
return fut.then(callback)
|
||||
|
||||
def get_stats(self):
|
||||
sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch)
|
||||
sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch)
|
||||
self._clear_state()
|
||||
stats = torch.stack([sq_norm_small_batch, sq_norm_large_batch])
|
||||
torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.AVG)
|
||||
return stats[0].item(), stats[1].item()
|
||||
|
||||
|
||||
class GradientNoiseScale:
|
||||
"""Calculates the gradient noise scale (1 / SNR), or critical batch size,
|
||||
from _An Empirical Model of Large-Batch Training_,
|
||||
https://arxiv.org/abs/1812.06162).
|
||||
|
||||
Args:
|
||||
beta (float): The decay factor for the exponential moving averages used to
|
||||
calculate the gradient noise scale.
|
||||
Default: 0.9998
|
||||
eps (float): Added for numerical stability.
|
||||
Default: 1e-8
|
||||
"""
|
||||
|
||||
def __init__(self, beta=0.9998, eps=1e-8):
|
||||
self.beta = beta
|
||||
self.eps = eps
|
||||
self.ema_sq_norm = 0.
|
||||
self.ema_var = 0.
|
||||
self.beta_cumprod = 1.
|
||||
self.gradient_noise_scale = float('nan')
|
||||
|
||||
def state_dict(self):
|
||||
"""Returns the state of the object as a :class:`dict`."""
|
||||
return dict(self.__dict__.items())
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Loads the object's state.
|
||||
Args:
|
||||
state_dict (dict): object state. Should be an object returned
|
||||
from a call to :meth:`state_dict`.
|
||||
"""
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch):
|
||||
"""Updates the state with a new batch's gradient statistics, and returns the
|
||||
current gradient noise scale.
|
||||
|
||||
Args:
|
||||
sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or
|
||||
per sample gradients.
|
||||
sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or
|
||||
per sample gradients.
|
||||
n_small_batch (int): The batch size of the individual microbatch or per sample
|
||||
gradients (1 if per sample).
|
||||
n_large_batch (int): The total batch size of the mean of the microbatch or
|
||||
per sample gradients.
|
||||
"""
|
||||
est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch)
|
||||
est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch)
|
||||
self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm
|
||||
self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var
|
||||
self.beta_cumprod *= self.beta
|
||||
self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps)
|
||||
return self.gradient_noise_scale
|
||||
|
||||
def get_gns(self):
|
||||
"""Returns the current gradient noise scale."""
|
||||
return self.gradient_noise_scale
|
||||
|
||||
def get_stats(self):
|
||||
"""Returns the current (debiased) estimates of the squared mean gradient
|
||||
and gradient variance."""
|
||||
return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod)
|
||||
246
comfy/k_diffusion/layers.py
Normal file
246
comfy/k_diffusion/layers.py
Normal file
@@ -0,0 +1,246 @@
|
||||
import math
|
||||
|
||||
from einops import rearrange, repeat
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from . import utils
|
||||
|
||||
# Karras et al. preconditioned denoiser
|
||||
|
||||
class Denoiser(nn.Module):
|
||||
"""A Karras et al. preconditioner for denoising diffusion models."""
|
||||
|
||||
def __init__(self, inner_model, sigma_data=1.):
|
||||
super().__init__()
|
||||
self.inner_model = inner_model
|
||||
self.sigma_data = sigma_data
|
||||
|
||||
def get_scalings(self, sigma):
|
||||
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
|
||||
c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
return c_skip, c_out, c_in
|
||||
|
||||
def loss(self, input, noise, sigma, **kwargs):
|
||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
|
||||
model_output = self.inner_model(noised_input * c_in, sigma, **kwargs)
|
||||
target = (input - c_skip * noised_input) / c_out
|
||||
return (model_output - target).pow(2).flatten(1).mean(1)
|
||||
|
||||
def forward(self, input, sigma, **kwargs):
|
||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
return self.inner_model(input * c_in, sigma, **kwargs) * c_out + input * c_skip
|
||||
|
||||
|
||||
class DenoiserWithVariance(Denoiser):
|
||||
def loss(self, input, noise, sigma, **kwargs):
|
||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
|
||||
model_output, logvar = self.inner_model(noised_input * c_in, sigma, return_variance=True, **kwargs)
|
||||
logvar = utils.append_dims(logvar, model_output.ndim)
|
||||
target = (input - c_skip * noised_input) / c_out
|
||||
losses = ((model_output - target) ** 2 / logvar.exp() + logvar) / 2
|
||||
return losses.flatten(1).mean(1)
|
||||
|
||||
|
||||
# Residual blocks
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, *main, skip=None):
|
||||
super().__init__()
|
||||
self.main = nn.Sequential(*main)
|
||||
self.skip = skip if skip else nn.Identity()
|
||||
|
||||
def forward(self, input):
|
||||
return self.main(input) + self.skip(input)
|
||||
|
||||
|
||||
# Noise level (and other) conditioning
|
||||
|
||||
class ConditionedModule(nn.Module):
|
||||
pass
|
||||
|
||||
|
||||
class UnconditionedModule(ConditionedModule):
|
||||
def __init__(self, module):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
|
||||
def forward(self, input, cond=None):
|
||||
return self.module(input)
|
||||
|
||||
|
||||
class ConditionedSequential(nn.Sequential, ConditionedModule):
|
||||
def forward(self, input, cond):
|
||||
for module in self:
|
||||
if isinstance(module, ConditionedModule):
|
||||
input = module(input, cond)
|
||||
else:
|
||||
input = module(input)
|
||||
return input
|
||||
|
||||
|
||||
class ConditionedResidualBlock(ConditionedModule):
|
||||
def __init__(self, *main, skip=None):
|
||||
super().__init__()
|
||||
self.main = ConditionedSequential(*main)
|
||||
self.skip = skip if skip else nn.Identity()
|
||||
|
||||
def forward(self, input, cond):
|
||||
skip = self.skip(input, cond) if isinstance(self.skip, ConditionedModule) else self.skip(input)
|
||||
return self.main(input, cond) + skip
|
||||
|
||||
|
||||
class AdaGN(ConditionedModule):
|
||||
def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key='cond'):
|
||||
super().__init__()
|
||||
self.num_groups = num_groups
|
||||
self.eps = eps
|
||||
self.cond_key = cond_key
|
||||
self.mapper = nn.Linear(feats_in, c_out * 2)
|
||||
|
||||
def forward(self, input, cond):
|
||||
weight, bias = self.mapper(cond[self.cond_key]).chunk(2, dim=-1)
|
||||
input = F.group_norm(input, self.num_groups, eps=self.eps)
|
||||
return torch.addcmul(utils.append_dims(bias, input.ndim), input, utils.append_dims(weight, input.ndim) + 1)
|
||||
|
||||
|
||||
# Attention
|
||||
|
||||
class SelfAttention2d(ConditionedModule):
|
||||
def __init__(self, c_in, n_head, norm, dropout_rate=0.):
|
||||
super().__init__()
|
||||
assert c_in % n_head == 0
|
||||
self.norm_in = norm(c_in)
|
||||
self.n_head = n_head
|
||||
self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1)
|
||||
self.out_proj = nn.Conv2d(c_in, c_in, 1)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(self, input, cond):
|
||||
n, c, h, w = input.shape
|
||||
qkv = self.qkv_proj(self.norm_in(input, cond))
|
||||
qkv = qkv.view([n, self.n_head * 3, c // self.n_head, h * w]).transpose(2, 3)
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
scale = k.shape[3] ** -0.25
|
||||
att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
|
||||
att = self.dropout(att)
|
||||
y = (att @ v).transpose(2, 3).contiguous().view([n, c, h, w])
|
||||
return input + self.out_proj(y)
|
||||
|
||||
|
||||
class CrossAttention2d(ConditionedModule):
|
||||
def __init__(self, c_dec, c_enc, n_head, norm_dec, dropout_rate=0.,
|
||||
cond_key='cross', cond_key_padding='cross_padding'):
|
||||
super().__init__()
|
||||
assert c_dec % n_head == 0
|
||||
self.cond_key = cond_key
|
||||
self.cond_key_padding = cond_key_padding
|
||||
self.norm_enc = nn.LayerNorm(c_enc)
|
||||
self.norm_dec = norm_dec(c_dec)
|
||||
self.n_head = n_head
|
||||
self.q_proj = nn.Conv2d(c_dec, c_dec, 1)
|
||||
self.kv_proj = nn.Linear(c_enc, c_dec * 2)
|
||||
self.out_proj = nn.Conv2d(c_dec, c_dec, 1)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(self, input, cond):
|
||||
n, c, h, w = input.shape
|
||||
q = self.q_proj(self.norm_dec(input, cond))
|
||||
q = q.view([n, self.n_head, c // self.n_head, h * w]).transpose(2, 3)
|
||||
kv = self.kv_proj(self.norm_enc(cond[self.cond_key]))
|
||||
kv = kv.view([n, -1, self.n_head * 2, c // self.n_head]).transpose(1, 2)
|
||||
k, v = kv.chunk(2, dim=1)
|
||||
scale = k.shape[3] ** -0.25
|
||||
att = ((q * scale) @ (k.transpose(2, 3) * scale))
|
||||
att = att - (cond[self.cond_key_padding][:, None, None, :]) * 10000
|
||||
att = att.softmax(3)
|
||||
att = self.dropout(att)
|
||||
y = (att @ v).transpose(2, 3)
|
||||
y = y.contiguous().view([n, c, h, w])
|
||||
return input + self.out_proj(y)
|
||||
|
||||
|
||||
# Downsampling/upsampling
|
||||
|
||||
_kernels = {
|
||||
'linear':
|
||||
[1 / 8, 3 / 8, 3 / 8, 1 / 8],
|
||||
'cubic':
|
||||
[-0.01171875, -0.03515625, 0.11328125, 0.43359375,
|
||||
0.43359375, 0.11328125, -0.03515625, -0.01171875],
|
||||
'lanczos3':
|
||||
[0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
|
||||
-0.066637322306633, 0.13550527393817902, 0.44638532400131226,
|
||||
0.44638532400131226, 0.13550527393817902, -0.066637322306633,
|
||||
-0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
|
||||
}
|
||||
_kernels['bilinear'] = _kernels['linear']
|
||||
_kernels['bicubic'] = _kernels['cubic']
|
||||
|
||||
|
||||
class Downsample2d(nn.Module):
|
||||
def __init__(self, kernel='linear', pad_mode='reflect'):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor([_kernels[kernel]])
|
||||
self.pad = kernel_1d.shape[1] // 2 - 1
|
||||
self.register_buffer('kernel', kernel_1d.T @ kernel_1d)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.pad(x, (self.pad,) * 4, self.pad_mode)
|
||||
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
|
||||
indices = torch.arange(x.shape[1], device=x.device)
|
||||
weight[indices, indices] = self.kernel.to(weight)
|
||||
return F.conv2d(x, weight, stride=2)
|
||||
|
||||
|
||||
class Upsample2d(nn.Module):
|
||||
def __init__(self, kernel='linear', pad_mode='reflect'):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor([_kernels[kernel]]) * 2
|
||||
self.pad = kernel_1d.shape[1] // 2 - 1
|
||||
self.register_buffer('kernel', kernel_1d.T @ kernel_1d)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
|
||||
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
|
||||
indices = torch.arange(x.shape[1], device=x.device)
|
||||
weight[indices, indices] = self.kernel.to(weight)
|
||||
return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
|
||||
|
||||
|
||||
# Embeddings
|
||||
|
||||
class FourierFeatures(nn.Module):
|
||||
def __init__(self, in_features, out_features, std=1.):
|
||||
super().__init__()
|
||||
assert out_features % 2 == 0
|
||||
self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std)
|
||||
|
||||
def forward(self, input):
|
||||
f = 2 * math.pi * input @ self.weight.T
|
||||
return torch.cat([f.cos(), f.sin()], dim=-1)
|
||||
|
||||
|
||||
# U-Nets
|
||||
|
||||
class UNet(ConditionedModule):
|
||||
def __init__(self, d_blocks, u_blocks, skip_stages=0):
|
||||
super().__init__()
|
||||
self.d_blocks = nn.ModuleList(d_blocks)
|
||||
self.u_blocks = nn.ModuleList(u_blocks)
|
||||
self.skip_stages = skip_stages
|
||||
|
||||
def forward(self, input, cond):
|
||||
skips = []
|
||||
for block in self.d_blocks[self.skip_stages:]:
|
||||
input = block(input, cond)
|
||||
skips.append(input)
|
||||
for i, (block, skip) in enumerate(zip(self.u_blocks, reversed(skips))):
|
||||
input = block(input, cond, skip if i > 0 else None)
|
||||
return input
|
||||
1
comfy/k_diffusion/models/__init__.py
Normal file
1
comfy/k_diffusion/models/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .image_v1 import ImageDenoiserModelV1
|
||||
156
comfy/k_diffusion/models/image_v1.py
Normal file
156
comfy/k_diffusion/models/image_v1.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .. import layers, utils
|
||||
|
||||
|
||||
def orthogonal_(module):
|
||||
nn.init.orthogonal_(module.weight)
|
||||
return module
|
||||
|
||||
|
||||
class ResConvBlock(layers.ConditionedResidualBlock):
|
||||
def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.):
|
||||
skip = None if c_in == c_out else orthogonal_(nn.Conv2d(c_in, c_out, 1, bias=False))
|
||||
super().__init__(
|
||||
layers.AdaGN(feats_in, c_in, max(1, c_in // group_size)),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(c_in, c_mid, 3, padding=1),
|
||||
nn.Dropout2d(dropout_rate, inplace=True),
|
||||
layers.AdaGN(feats_in, c_mid, max(1, c_mid // group_size)),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(c_mid, c_out, 3, padding=1),
|
||||
nn.Dropout2d(dropout_rate, inplace=True),
|
||||
skip=skip)
|
||||
|
||||
|
||||
class DBlock(layers.ConditionedSequential):
|
||||
def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., downsample=False, self_attn=False, cross_attn=False, c_enc=0):
|
||||
modules = [nn.Identity()]
|
||||
for i in range(n_layers):
|
||||
my_c_in = c_in if i == 0 else c_mid
|
||||
my_c_out = c_mid if i < n_layers - 1 else c_out
|
||||
modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate))
|
||||
if self_attn:
|
||||
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
|
||||
modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
|
||||
if cross_attn:
|
||||
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
|
||||
modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate))
|
||||
super().__init__(*modules)
|
||||
self.set_downsample(downsample)
|
||||
|
||||
def set_downsample(self, downsample):
|
||||
self[0] = layers.Downsample2d() if downsample else nn.Identity()
|
||||
return self
|
||||
|
||||
|
||||
class UBlock(layers.ConditionedSequential):
|
||||
def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., upsample=False, self_attn=False, cross_attn=False, c_enc=0):
|
||||
modules = []
|
||||
for i in range(n_layers):
|
||||
my_c_in = c_in if i == 0 else c_mid
|
||||
my_c_out = c_mid if i < n_layers - 1 else c_out
|
||||
modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate))
|
||||
if self_attn:
|
||||
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
|
||||
modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
|
||||
if cross_attn:
|
||||
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
|
||||
modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate))
|
||||
modules.append(nn.Identity())
|
||||
super().__init__(*modules)
|
||||
self.set_upsample(upsample)
|
||||
|
||||
def forward(self, input, cond, skip=None):
|
||||
if skip is not None:
|
||||
input = torch.cat([input, skip], dim=1)
|
||||
return super().forward(input, cond)
|
||||
|
||||
def set_upsample(self, upsample):
|
||||
self[-1] = layers.Upsample2d() if upsample else nn.Identity()
|
||||
return self
|
||||
|
||||
|
||||
class MappingNet(nn.Sequential):
|
||||
def __init__(self, feats_in, feats_out, n_layers=2):
|
||||
layers = []
|
||||
for i in range(n_layers):
|
||||
layers.append(orthogonal_(nn.Linear(feats_in if i == 0 else feats_out, feats_out)))
|
||||
layers.append(nn.GELU())
|
||||
super().__init__(*layers)
|
||||
|
||||
|
||||
class ImageDenoiserModelV1(nn.Module):
|
||||
def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, cross_attn_depths=None, mapping_cond_dim=0, unet_cond_dim=0, cross_cond_dim=0, dropout_rate=0., patch_size=1, skip_stages=0, has_variance=False):
|
||||
super().__init__()
|
||||
self.c_in = c_in
|
||||
self.channels = channels
|
||||
self.unet_cond_dim = unet_cond_dim
|
||||
self.patch_size = patch_size
|
||||
self.has_variance = has_variance
|
||||
self.timestep_embed = layers.FourierFeatures(1, feats_in)
|
||||
if mapping_cond_dim > 0:
|
||||
self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False)
|
||||
self.mapping = MappingNet(feats_in, feats_in)
|
||||
self.proj_in = nn.Conv2d((c_in + unet_cond_dim) * self.patch_size ** 2, channels[max(0, skip_stages - 1)], 1)
|
||||
self.proj_out = nn.Conv2d(channels[max(0, skip_stages - 1)], c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1)
|
||||
nn.init.zeros_(self.proj_out.weight)
|
||||
nn.init.zeros_(self.proj_out.bias)
|
||||
if cross_cond_dim == 0:
|
||||
cross_attn_depths = [False] * len(self_attn_depths)
|
||||
d_blocks, u_blocks = [], []
|
||||
for i in range(len(depths)):
|
||||
my_c_in = channels[max(0, i - 1)]
|
||||
d_blocks.append(DBlock(depths[i], feats_in, my_c_in, channels[i], channels[i], downsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate))
|
||||
for i in range(len(depths)):
|
||||
my_c_in = channels[i] * 2 if i < len(depths) - 1 else channels[i]
|
||||
my_c_out = channels[max(0, i - 1)]
|
||||
u_blocks.append(UBlock(depths[i], feats_in, my_c_in, channels[i], my_c_out, upsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate))
|
||||
self.u_net = layers.UNet(d_blocks, reversed(u_blocks), skip_stages=skip_stages)
|
||||
|
||||
def forward(self, input, sigma, mapping_cond=None, unet_cond=None, cross_cond=None, cross_cond_padding=None, return_variance=False):
|
||||
c_noise = sigma.log() / 4
|
||||
timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2))
|
||||
mapping_cond_embed = torch.zeros_like(timestep_embed) if mapping_cond is None else self.mapping_cond(mapping_cond)
|
||||
mapping_out = self.mapping(timestep_embed + mapping_cond_embed)
|
||||
cond = {'cond': mapping_out}
|
||||
if unet_cond is not None:
|
||||
input = torch.cat([input, unet_cond], dim=1)
|
||||
if cross_cond is not None:
|
||||
cond['cross'] = cross_cond
|
||||
cond['cross_padding'] = cross_cond_padding
|
||||
if self.patch_size > 1:
|
||||
input = F.pixel_unshuffle(input, self.patch_size)
|
||||
input = self.proj_in(input)
|
||||
input = self.u_net(input, cond)
|
||||
input = self.proj_out(input)
|
||||
if self.has_variance:
|
||||
input, logvar = input[:, :-1], input[:, -1].flatten(1).mean(1)
|
||||
if self.patch_size > 1:
|
||||
input = F.pixel_shuffle(input, self.patch_size)
|
||||
if self.has_variance and return_variance:
|
||||
return input, logvar
|
||||
return input
|
||||
|
||||
def set_skip_stages(self, skip_stages):
|
||||
self.proj_in = nn.Conv2d(self.proj_in.in_channels, self.channels[max(0, skip_stages - 1)], 1)
|
||||
self.proj_out = nn.Conv2d(self.channels[max(0, skip_stages - 1)], self.proj_out.out_channels, 1)
|
||||
nn.init.zeros_(self.proj_out.weight)
|
||||
nn.init.zeros_(self.proj_out.bias)
|
||||
self.u_net.skip_stages = skip_stages
|
||||
for i, block in enumerate(self.u_net.d_blocks):
|
||||
block.set_downsample(i > skip_stages)
|
||||
for i, block in enumerate(reversed(self.u_net.u_blocks)):
|
||||
block.set_upsample(i > skip_stages)
|
||||
return self
|
||||
|
||||
def set_patch_size(self, patch_size):
|
||||
self.patch_size = patch_size
|
||||
self.proj_in = nn.Conv2d((self.c_in + self.unet_cond_dim) * self.patch_size ** 2, self.channels[max(0, self.u_net.skip_stages - 1)], 1)
|
||||
self.proj_out = nn.Conv2d(self.channels[max(0, self.u_net.skip_stages - 1)], self.c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1)
|
||||
nn.init.zeros_(self.proj_out.weight)
|
||||
nn.init.zeros_(self.proj_out.bias)
|
||||
607
comfy/k_diffusion/sampling.py
Normal file
607
comfy/k_diffusion/sampling.py
Normal file
@@ -0,0 +1,607 @@
|
||||
import math
|
||||
|
||||
from scipy import integrate
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchdiffeq import odeint
|
||||
import torchsde
|
||||
from tqdm.auto import trange, tqdm
|
||||
|
||||
from . import utils
|
||||
|
||||
|
||||
def append_zero(x):
|
||||
return torch.cat([x, x.new_zeros([1])])
|
||||
|
||||
|
||||
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
ramp = torch.linspace(0, 1, n, device=device)
|
||||
min_inv_rho = sigma_min ** (1 / rho)
|
||||
max_inv_rho = sigma_max ** (1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return append_zero(sigmas).to(device)
|
||||
|
||||
|
||||
def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
|
||||
"""Constructs an exponential noise schedule."""
|
||||
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp()
|
||||
return append_zero(sigmas)
|
||||
|
||||
|
||||
def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
|
||||
"""Constructs an polynomial in log sigma noise schedule."""
|
||||
ramp = torch.linspace(1, 0, n, device=device) ** rho
|
||||
sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))
|
||||
return append_zero(sigmas)
|
||||
|
||||
|
||||
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
|
||||
"""Constructs a continuous VP noise schedule."""
|
||||
t = torch.linspace(1, eps_s, n, device=device)
|
||||
sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
|
||||
return append_zero(sigmas)
|
||||
|
||||
|
||||
def to_d(x, sigma, denoised):
|
||||
"""Converts a denoiser output to a Karras ODE derivative."""
|
||||
return (x - denoised) / utils.append_dims(sigma, x.ndim)
|
||||
|
||||
|
||||
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
|
||||
"""Calculates the noise level (sigma_down) to step down to and the amount
|
||||
of noise to add (sigma_up) when doing an ancestral sampling step."""
|
||||
if not eta:
|
||||
return sigma_to, 0.
|
||||
sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
|
||||
sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
|
||||
return sigma_down, sigma_up
|
||||
|
||||
|
||||
def default_noise_sampler(x):
|
||||
return lambda sigma, sigma_next: torch.randn_like(x)
|
||||
|
||||
|
||||
class BatchedBrownianTree:
|
||||
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
|
||||
|
||||
def __init__(self, x, t0, t1, seed=None, **kwargs):
|
||||
t0, t1, self.sign = self.sort(t0, t1)
|
||||
w0 = kwargs.get('w0', torch.zeros_like(x))
|
||||
if seed is None:
|
||||
seed = torch.randint(0, 2 ** 63 - 1, []).item()
|
||||
self.batched = True
|
||||
try:
|
||||
assert len(seed) == x.shape[0]
|
||||
w0 = w0[0]
|
||||
except TypeError:
|
||||
seed = [seed]
|
||||
self.batched = False
|
||||
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
|
||||
|
||||
@staticmethod
|
||||
def sort(a, b):
|
||||
return (a, b, 1) if a < b else (b, a, -1)
|
||||
|
||||
def __call__(self, t0, t1):
|
||||
t0, t1, sign = self.sort(t0, t1)
|
||||
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
|
||||
return w if self.batched else w[0]
|
||||
|
||||
|
||||
class BrownianTreeNoiseSampler:
|
||||
"""A noise sampler backed by a torchsde.BrownianTree.
|
||||
|
||||
Args:
|
||||
x (Tensor): The tensor whose shape, device and dtype to use to generate
|
||||
random samples.
|
||||
sigma_min (float): The low end of the valid interval.
|
||||
sigma_max (float): The high end of the valid interval.
|
||||
seed (int or List[int]): The random seed. If a list of seeds is
|
||||
supplied instead of a single integer, then the noise sampler will
|
||||
use one BrownianTree per batch item, each with its own seed.
|
||||
transform (callable): A function that maps sigma to the sampler's
|
||||
internal timestep.
|
||||
"""
|
||||
|
||||
def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
|
||||
self.transform = transform
|
||||
t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
|
||||
self.tree = BatchedBrownianTree(x, t0, t1, seed)
|
||||
|
||||
def __call__(self, sigma, sigma_next):
|
||||
t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
|
||||
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
if gamma > 0:
|
||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
if gamma > 0:
|
||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
if sigmas[i + 1] == 0:
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
else:
|
||||
# Heun's method
|
||||
x_2 = x + d * dt
|
||||
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
|
||||
d_prime = (d + d_2) / 2
|
||||
x = x + d_prime * dt
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
if gamma > 0:
|
||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
# Euler method
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
x = x + d * dt
|
||||
else:
|
||||
# DPM-Solver-2
|
||||
sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
|
||||
dt_1 = sigma_mid - sigma_hat
|
||||
dt_2 = sigmas[i + 1] - sigma_hat
|
||||
x_2 = x + d * dt_1
|
||||
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, sigma_mid, denoised_2)
|
||||
x = x + d_2 * dt_2
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
if sigma_down == 0:
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
else:
|
||||
# DPM-Solver-2
|
||||
sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
|
||||
dt_1 = sigma_mid - sigmas[i]
|
||||
dt_2 = sigma_down - sigmas[i]
|
||||
x_2 = x + d * dt_1
|
||||
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, sigma_mid, denoised_2)
|
||||
x = x + d_2 * dt_2
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
|
||||
def linear_multistep_coeff(order, t, i, j):
|
||||
if order - 1 > i:
|
||||
raise ValueError(f'Order {order} too high for step {i}')
|
||||
def fn(tau):
|
||||
prod = 1.
|
||||
for k in range(order):
|
||||
if j == k:
|
||||
continue
|
||||
prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
|
||||
return prod
|
||||
return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigmas_cpu = sigmas.detach().cpu().numpy()
|
||||
ds = []
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
ds.append(d)
|
||||
if len(ds) > order:
|
||||
ds.pop(0)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
cur_order = min(i + 1, order)
|
||||
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
|
||||
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
v = torch.randint_like(x, 2) * 2 - 1
|
||||
fevals = 0
|
||||
def ode_fn(sigma, x):
|
||||
nonlocal fevals
|
||||
with torch.enable_grad():
|
||||
x = x[0].detach().requires_grad_()
|
||||
denoised = model(x, sigma * s_in, **extra_args)
|
||||
d = to_d(x, sigma, denoised)
|
||||
fevals += 1
|
||||
grad = torch.autograd.grad((d * v).sum(), x)[0]
|
||||
d_ll = (v * grad).flatten(1).sum(1)
|
||||
return d.detach(), d_ll
|
||||
x_min = x, x.new_zeros([x.shape[0]])
|
||||
t = x.new_tensor([sigma_min, sigma_max])
|
||||
sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
|
||||
latent, delta_ll = sol[0][-1], sol[1][-1]
|
||||
ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
|
||||
return ll_prior + delta_ll, {'fevals': fevals}
|
||||
|
||||
|
||||
class PIDStepSizeController:
|
||||
"""A PID controller for ODE adaptive step size control."""
|
||||
def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
|
||||
self.h = h
|
||||
self.b1 = (pcoeff + icoeff + dcoeff) / order
|
||||
self.b2 = -(pcoeff + 2 * dcoeff) / order
|
||||
self.b3 = dcoeff / order
|
||||
self.accept_safety = accept_safety
|
||||
self.eps = eps
|
||||
self.errs = []
|
||||
|
||||
def limiter(self, x):
|
||||
return 1 + math.atan(x - 1)
|
||||
|
||||
def propose_step(self, error):
|
||||
inv_error = 1 / (float(error) + self.eps)
|
||||
if not self.errs:
|
||||
self.errs = [inv_error, inv_error, inv_error]
|
||||
self.errs[0] = inv_error
|
||||
factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3
|
||||
factor = self.limiter(factor)
|
||||
accept = factor >= self.accept_safety
|
||||
if accept:
|
||||
self.errs[2] = self.errs[1]
|
||||
self.errs[1] = self.errs[0]
|
||||
self.h *= factor
|
||||
return accept
|
||||
|
||||
|
||||
class DPMSolver(nn.Module):
|
||||
"""DPM-Solver. See https://arxiv.org/abs/2206.00927."""
|
||||
|
||||
def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.extra_args = {} if extra_args is None else extra_args
|
||||
self.eps_callback = eps_callback
|
||||
self.info_callback = info_callback
|
||||
|
||||
def t(self, sigma):
|
||||
return -sigma.log()
|
||||
|
||||
def sigma(self, t):
|
||||
return t.neg().exp()
|
||||
|
||||
def eps(self, eps_cache, key, x, t, *args, **kwargs):
|
||||
if key in eps_cache:
|
||||
return eps_cache[key], eps_cache
|
||||
sigma = self.sigma(t) * x.new_ones([x.shape[0]])
|
||||
eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t)
|
||||
if self.eps_callback is not None:
|
||||
self.eps_callback()
|
||||
return eps, {key: eps, **eps_cache}
|
||||
|
||||
def dpm_solver_1_step(self, x, t, t_next, eps_cache=None):
|
||||
eps_cache = {} if eps_cache is None else eps_cache
|
||||
h = t_next - t
|
||||
eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
|
||||
x_1 = x - self.sigma(t_next) * h.expm1() * eps
|
||||
return x_1, eps_cache
|
||||
|
||||
def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None):
|
||||
eps_cache = {} if eps_cache is None else eps_cache
|
||||
h = t_next - t
|
||||
eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
|
||||
s1 = t + r1 * h
|
||||
u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
|
||||
eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
|
||||
x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps)
|
||||
return x_2, eps_cache
|
||||
|
||||
def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
|
||||
eps_cache = {} if eps_cache is None else eps_cache
|
||||
h = t_next - t
|
||||
eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
|
||||
s1 = t + r1 * h
|
||||
s2 = t + r2 * h
|
||||
u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
|
||||
eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
|
||||
u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps)
|
||||
eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2)
|
||||
x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps)
|
||||
return x_3, eps_cache
|
||||
|
||||
def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
if not t_end > t_start and eta:
|
||||
raise ValueError('eta must be 0 for reverse sampling')
|
||||
|
||||
m = math.floor(nfe / 3) + 1
|
||||
ts = torch.linspace(t_start, t_end, m + 1, device=x.device)
|
||||
|
||||
if nfe % 3 == 0:
|
||||
orders = [3] * (m - 2) + [2, 1]
|
||||
else:
|
||||
orders = [3] * (m - 1) + [nfe % 3]
|
||||
|
||||
for i in range(len(orders)):
|
||||
eps_cache = {}
|
||||
t, t_next = ts[i], ts[i + 1]
|
||||
if eta:
|
||||
sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
|
||||
t_next_ = torch.minimum(t_end, self.t(sd))
|
||||
su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
|
||||
else:
|
||||
t_next_, su = t_next, 0.
|
||||
|
||||
eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
|
||||
denoised = x - self.sigma(t) * eps
|
||||
if self.info_callback is not None:
|
||||
self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})
|
||||
|
||||
if orders[i] == 1:
|
||||
x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache)
|
||||
elif orders[i] == 2:
|
||||
x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache)
|
||||
else:
|
||||
x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)
|
||||
|
||||
x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))
|
||||
|
||||
return x
|
||||
|
||||
def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
if order not in {2, 3}:
|
||||
raise ValueError('order should be 2 or 3')
|
||||
forward = t_end > t_start
|
||||
if not forward and eta:
|
||||
raise ValueError('eta must be 0 for reverse sampling')
|
||||
h_init = abs(h_init) * (1 if forward else -1)
|
||||
atol = torch.tensor(atol)
|
||||
rtol = torch.tensor(rtol)
|
||||
s = t_start
|
||||
x_prev = x
|
||||
accept = True
|
||||
pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
|
||||
info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}
|
||||
|
||||
while s < t_end - 1e-5 if forward else s > t_end + 1e-5:
|
||||
eps_cache = {}
|
||||
t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h)
|
||||
if eta:
|
||||
sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
|
||||
t_ = torch.minimum(t_end, self.t(sd))
|
||||
su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
|
||||
else:
|
||||
t_, su = t, 0.
|
||||
|
||||
eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
|
||||
denoised = x - self.sigma(s) * eps
|
||||
|
||||
if order == 2:
|
||||
x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
|
||||
x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
|
||||
else:
|
||||
x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
|
||||
x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
|
||||
delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
|
||||
error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
|
||||
accept = pid.propose_step(error)
|
||||
if accept:
|
||||
x_prev = x_low
|
||||
x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
|
||||
s = t
|
||||
info['n_accept'] += 1
|
||||
else:
|
||||
info['n_reject'] += 1
|
||||
info['nfe'] += order
|
||||
info['steps'] += 1
|
||||
|
||||
if self.info_callback is not None:
|
||||
self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})
|
||||
|
||||
return x, info
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
|
||||
"""DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927."""
|
||||
if sigma_min <= 0 or sigma_max <= 0:
|
||||
raise ValueError('sigma_min and sigma_max must not be 0')
|
||||
with tqdm(total=n, disable=disable) as pbar:
|
||||
dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
|
||||
if callback is not None:
|
||||
dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
|
||||
return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
|
||||
"""DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927."""
|
||||
if sigma_min <= 0 or sigma_max <= 0:
|
||||
raise ValueError('sigma_min and sigma_max must not be 0')
|
||||
with tqdm(disable=disable) as pbar:
|
||||
dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
|
||||
if callback is not None:
|
||||
dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
|
||||
x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
|
||||
if return_info:
|
||||
return x, info
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigma_down == 0:
|
||||
# Euler method
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
else:
|
||||
# DPM-Solver++(2S)
|
||||
t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
|
||||
r = 1 / 2
|
||||
h = t_next - t
|
||||
s = t + r * h
|
||||
x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
|
||||
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
|
||||
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
|
||||
# Noise addition
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
"""DPM-Solver++ (stochastic)."""
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
# Euler method
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
dt = sigmas[i + 1] - sigmas[i]
|
||||
x = x + d * dt
|
||||
else:
|
||||
# DPM-Solver++
|
||||
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
||||
h = t_next - t
|
||||
s = t + h * r
|
||||
fac = 1 / (2 * r)
|
||||
|
||||
# Step 1
|
||||
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
|
||||
s_ = t_fn(sd)
|
||||
x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
|
||||
x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
|
||||
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
|
||||
t_next_ = t_fn(sd)
|
||||
denoised_d = (1 - fac) * denoised + fac * denoised_2
|
||||
x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
|
||||
x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
"""DPM-Solver++(2M)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
old_denoised = None
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
||||
h = t_next - t
|
||||
if old_denoised is None or sigmas[i + 1] == 0:
|
||||
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
|
||||
else:
|
||||
h_last = t - t_fn(sigmas[i - 1])
|
||||
r = h_last / h
|
||||
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
|
||||
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
|
||||
old_denoised = denoised
|
||||
return x
|
||||
332
comfy/k_diffusion/utils.py
Normal file
332
comfy/k_diffusion/utils.py
Normal file
@@ -0,0 +1,332 @@
|
||||
from contextlib import contextmanager
|
||||
import hashlib
|
||||
import math
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import urllib
|
||||
import warnings
|
||||
|
||||
from PIL import Image
|
||||
import torch
|
||||
from torch import nn, optim
|
||||
from torch.utils import data
|
||||
from torchvision.transforms import functional as TF
|
||||
|
||||
|
||||
def from_pil_image(x):
|
||||
"""Converts from a PIL image to a tensor."""
|
||||
x = TF.to_tensor(x)
|
||||
if x.ndim == 2:
|
||||
x = x[..., None]
|
||||
return x * 2 - 1
|
||||
|
||||
|
||||
def to_pil_image(x):
|
||||
"""Converts from a tensor to a PIL image."""
|
||||
if x.ndim == 4:
|
||||
assert x.shape[0] == 1
|
||||
x = x[0]
|
||||
if x.shape[0] == 1:
|
||||
x = x[0]
|
||||
return TF.to_pil_image((x.clamp(-1, 1) + 1) / 2)
|
||||
|
||||
|
||||
def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
|
||||
"""Apply passed in transforms for HuggingFace Datasets."""
|
||||
images = [transform(image.convert(mode)) for image in examples[image_key]]
|
||||
return {image_key: images}
|
||||
|
||||
|
||||
def append_dims(x, target_dims):
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
|
||||
expanded = x[(...,) + (None,) * dims_to_append]
|
||||
# MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
|
||||
# https://github.com/pytorch/pytorch/issues/84364
|
||||
return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
|
||||
|
||||
|
||||
def n_params(module):
|
||||
"""Returns the number of trainable parameters in a module."""
|
||||
return sum(p.numel() for p in module.parameters())
|
||||
|
||||
|
||||
def download_file(path, url, digest=None):
|
||||
"""Downloads a file if it does not exist, optionally checking its SHA-256 hash."""
|
||||
path = Path(path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if not path.exists():
|
||||
with urllib.request.urlopen(url) as response, open(path, 'wb') as f:
|
||||
shutil.copyfileobj(response, f)
|
||||
if digest is not None:
|
||||
file_digest = hashlib.sha256(open(path, 'rb').read()).hexdigest()
|
||||
if digest != file_digest:
|
||||
raise OSError(f'hash of {path} (url: {url}) failed to validate')
|
||||
return path
|
||||
|
||||
|
||||
@contextmanager
|
||||
def train_mode(model, mode=True):
|
||||
"""A context manager that places a model into training mode and restores
|
||||
the previous mode on exit."""
|
||||
modes = [module.training for module in model.modules()]
|
||||
try:
|
||||
yield model.train(mode)
|
||||
finally:
|
||||
for i, module in enumerate(model.modules()):
|
||||
module.training = modes[i]
|
||||
|
||||
|
||||
def eval_mode(model):
|
||||
"""A context manager that places a model into evaluation mode and restores
|
||||
the previous mode on exit."""
|
||||
return train_mode(model, False)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def ema_update(model, averaged_model, decay):
|
||||
"""Incorporates updated model parameters into an exponential moving averaged
|
||||
version of a model. It should be called after each optimizer step."""
|
||||
model_params = dict(model.named_parameters())
|
||||
averaged_params = dict(averaged_model.named_parameters())
|
||||
assert model_params.keys() == averaged_params.keys()
|
||||
|
||||
for name, param in model_params.items():
|
||||
averaged_params[name].mul_(decay).add_(param, alpha=1 - decay)
|
||||
|
||||
model_buffers = dict(model.named_buffers())
|
||||
averaged_buffers = dict(averaged_model.named_buffers())
|
||||
assert model_buffers.keys() == averaged_buffers.keys()
|
||||
|
||||
for name, buf in model_buffers.items():
|
||||
averaged_buffers[name].copy_(buf)
|
||||
|
||||
|
||||
class EMAWarmup:
|
||||
"""Implements an EMA warmup using an inverse decay schedule.
|
||||
If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
|
||||
good values for models you plan to train for a million or more steps (reaches decay
|
||||
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
|
||||
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
|
||||
215.4k steps).
|
||||
Args:
|
||||
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
|
||||
power (float): Exponential factor of EMA warmup. Default: 1.
|
||||
min_value (float): The minimum EMA decay rate. Default: 0.
|
||||
max_value (float): The maximum EMA decay rate. Default: 1.
|
||||
start_at (int): The epoch to start averaging at. Default: 0.
|
||||
last_epoch (int): The index of last epoch. Default: 0.
|
||||
"""
|
||||
|
||||
def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0,
|
||||
last_epoch=0):
|
||||
self.inv_gamma = inv_gamma
|
||||
self.power = power
|
||||
self.min_value = min_value
|
||||
self.max_value = max_value
|
||||
self.start_at = start_at
|
||||
self.last_epoch = last_epoch
|
||||
|
||||
def state_dict(self):
|
||||
"""Returns the state of the class as a :class:`dict`."""
|
||||
return dict(self.__dict__.items())
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Loads the class's state.
|
||||
Args:
|
||||
state_dict (dict): scaler state. Should be an object returned
|
||||
from a call to :meth:`state_dict`.
|
||||
"""
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
def get_value(self):
|
||||
"""Gets the current EMA decay rate."""
|
||||
epoch = max(0, self.last_epoch - self.start_at)
|
||||
value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
|
||||
return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value))
|
||||
|
||||
def step(self):
|
||||
"""Updates the step count."""
|
||||
self.last_epoch += 1
|
||||
|
||||
|
||||
class InverseLR(optim.lr_scheduler._LRScheduler):
|
||||
"""Implements an inverse decay learning rate schedule with an optional exponential
|
||||
warmup. When last_epoch=-1, sets initial lr as lr.
|
||||
inv_gamma is the number of steps/epochs required for the learning rate to decay to
|
||||
(1 / 2)**power of its original value.
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
|
||||
power (float): Exponential factor of learning rate decay. Default: 1.
|
||||
warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
|
||||
Default: 0.
|
||||
min_lr (float): The minimum learning rate. Default: 0.
|
||||
last_epoch (int): The index of last epoch. Default: -1.
|
||||
verbose (bool): If ``True``, prints a message to stdout for
|
||||
each update. Default: ``False``.
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,
|
||||
last_epoch=-1, verbose=False):
|
||||
self.inv_gamma = inv_gamma
|
||||
self.power = power
|
||||
if not 0. <= warmup < 1:
|
||||
raise ValueError('Invalid value for warmup')
|
||||
self.warmup = warmup
|
||||
self.min_lr = min_lr
|
||||
super().__init__(optimizer, last_epoch, verbose)
|
||||
|
||||
def get_lr(self):
|
||||
if not self._get_lr_called_within_step:
|
||||
warnings.warn("To get the last learning rate computed by the scheduler, "
|
||||
"please use `get_last_lr()`.")
|
||||
|
||||
return self._get_closed_form_lr()
|
||||
|
||||
def _get_closed_form_lr(self):
|
||||
warmup = 1 - self.warmup ** (self.last_epoch + 1)
|
||||
lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
|
||||
return [warmup * max(self.min_lr, base_lr * lr_mult)
|
||||
for base_lr in self.base_lrs]
|
||||
|
||||
|
||||
class ExponentialLR(optim.lr_scheduler._LRScheduler):
|
||||
"""Implements an exponential learning rate schedule with an optional exponential
|
||||
warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate
|
||||
continuously by decay (default 0.5) every num_steps steps.
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
num_steps (float): The number of steps to decay the learning rate by decay in.
|
||||
decay (float): The factor by which to decay the learning rate every num_steps
|
||||
steps. Default: 0.5.
|
||||
warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
|
||||
Default: 0.
|
||||
min_lr (float): The minimum learning rate. Default: 0.
|
||||
last_epoch (int): The index of last epoch. Default: -1.
|
||||
verbose (bool): If ``True``, prints a message to stdout for
|
||||
each update. Default: ``False``.
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,
|
||||
last_epoch=-1, verbose=False):
|
||||
self.num_steps = num_steps
|
||||
self.decay = decay
|
||||
if not 0. <= warmup < 1:
|
||||
raise ValueError('Invalid value for warmup')
|
||||
self.warmup = warmup
|
||||
self.min_lr = min_lr
|
||||
super().__init__(optimizer, last_epoch, verbose)
|
||||
|
||||
def get_lr(self):
|
||||
if not self._get_lr_called_within_step:
|
||||
warnings.warn("To get the last learning rate computed by the scheduler, "
|
||||
"please use `get_last_lr()`.")
|
||||
|
||||
return self._get_closed_form_lr()
|
||||
|
||||
def _get_closed_form_lr(self):
|
||||
warmup = 1 - self.warmup ** (self.last_epoch + 1)
|
||||
lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch
|
||||
return [warmup * max(self.min_lr, base_lr * lr_mult)
|
||||
for base_lr in self.base_lrs]
|
||||
|
||||
|
||||
def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
|
||||
"""Draws samples from an lognormal distribution."""
|
||||
return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp()
|
||||
|
||||
|
||||
def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
|
||||
"""Draws samples from an optionally truncated log-logistic distribution."""
|
||||
min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64)
|
||||
max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64)
|
||||
min_cdf = min_value.log().sub(loc).div(scale).sigmoid()
|
||||
max_cdf = max_value.log().sub(loc).div(scale).sigmoid()
|
||||
u = torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf
|
||||
return u.logit().mul(scale).add(loc).exp().to(dtype)
|
||||
|
||||
|
||||
def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
|
||||
"""Draws samples from an log-uniform distribution."""
|
||||
min_value = math.log(min_value)
|
||||
max_value = math.log(max_value)
|
||||
return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp()
|
||||
|
||||
|
||||
def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
|
||||
"""Draws samples from a truncated v-diffusion training timestep distribution."""
|
||||
min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi
|
||||
max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi
|
||||
u = torch.rand(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf
|
||||
return torch.tan(u * math.pi / 2) * sigma_data
|
||||
|
||||
|
||||
def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32):
|
||||
"""Draws samples from a split lognormal distribution."""
|
||||
n = torch.randn(shape, device=device, dtype=dtype).abs()
|
||||
u = torch.rand(shape, device=device, dtype=dtype)
|
||||
n_left = n * -scale_1 + loc
|
||||
n_right = n * scale_2 + loc
|
||||
ratio = scale_1 / (scale_1 + scale_2)
|
||||
return torch.where(u < ratio, n_left, n_right).exp()
|
||||
|
||||
|
||||
class FolderOfImages(data.Dataset):
|
||||
"""Recursively finds all images in a directory. It does not support
|
||||
classes/targets."""
|
||||
|
||||
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}
|
||||
|
||||
def __init__(self, root, transform=None):
|
||||
super().__init__()
|
||||
self.root = Path(root)
|
||||
self.transform = nn.Identity() if transform is None else transform
|
||||
self.paths = sorted(path for path in self.root.rglob('*') if path.suffix.lower() in self.IMG_EXTENSIONS)
|
||||
|
||||
def __repr__(self):
|
||||
return f'FolderOfImages(root="{self.root}", len: {len(self)})'
|
||||
|
||||
def __len__(self):
|
||||
return len(self.paths)
|
||||
|
||||
def __getitem__(self, key):
|
||||
path = self.paths[key]
|
||||
with open(path, 'rb') as f:
|
||||
image = Image.open(f).convert('RGB')
|
||||
image = self.transform(image)
|
||||
return image,
|
||||
|
||||
|
||||
class CSVLogger:
|
||||
def __init__(self, filename, columns):
|
||||
self.filename = Path(filename)
|
||||
self.columns = columns
|
||||
if self.filename.exists():
|
||||
self.file = open(self.filename, 'a')
|
||||
else:
|
||||
self.file = open(self.filename, 'w')
|
||||
self.write(*self.columns)
|
||||
|
||||
def write(self, *args):
|
||||
print(*args, sep=',', file=self.file, flush=True)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def tf32_mode(cudnn=None, matmul=None):
|
||||
"""A context manager that sets whether TF32 is allowed on cuDNN or matmul."""
|
||||
cudnn_old = torch.backends.cudnn.allow_tf32
|
||||
matmul_old = torch.backends.cuda.matmul.allow_tf32
|
||||
try:
|
||||
if cudnn is not None:
|
||||
torch.backends.cudnn.allow_tf32 = cudnn
|
||||
if matmul is not None:
|
||||
torch.backends.cuda.matmul.allow_tf32 = matmul
|
||||
yield
|
||||
finally:
|
||||
if cudnn is not None:
|
||||
torch.backends.cudnn.allow_tf32 = cudnn_old
|
||||
if matmul is not None:
|
||||
torch.backends.cuda.matmul.allow_tf32 = matmul_old
|
||||
0
comfy/ldm/data/__init__.py
Normal file
0
comfy/ldm/data/__init__.py
Normal file
24
comfy/ldm/data/util.py
Normal file
24
comfy/ldm/data/util.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
|
||||
from ldm.modules.midas.api import load_midas_transform
|
||||
|
||||
|
||||
class AddMiDaS(object):
|
||||
def __init__(self, model_type):
|
||||
super().__init__()
|
||||
self.transform = load_midas_transform(model_type)
|
||||
|
||||
def pt2np(self, x):
|
||||
x = ((x + 1.0) * .5).detach().cpu().numpy()
|
||||
return x
|
||||
|
||||
def np2pt(self, x):
|
||||
x = torch.from_numpy(x) * 2 - 1.
|
||||
return x
|
||||
|
||||
def __call__(self, sample):
|
||||
# sample['jpg'] is tensor hwc in [-1, 1] at this point
|
||||
x = self.pt2np(sample['jpg'])
|
||||
x = self.transform({"image": x})["image"]
|
||||
sample['midas_in'] = x
|
||||
return sample
|
||||
223
comfy/ldm/models/autoencoder.py
Normal file
223
comfy/ldm/models/autoencoder.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import torch
|
||||
# import pytorch_lightning as pl
|
||||
import torch.nn.functional as F
|
||||
from contextlib import contextmanager
|
||||
|
||||
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.modules.ema import LitEma
|
||||
|
||||
# class AutoencoderKL(pl.LightningModule):
|
||||
class AutoencoderKL(torch.nn.Module):
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
ema_decay=None,
|
||||
learn_logvar=False
|
||||
):
|
||||
super().__init__()
|
||||
self.learn_logvar = learn_logvar
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
assert ddconfig["double_z"]
|
||||
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
self.embed_dim = embed_dim
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels)==int
|
||||
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
|
||||
self.use_ema = ema_decay is not None
|
||||
if self.use_ema:
|
||||
self.ema_decay = ema_decay
|
||||
assert 0. < ema_decay < 1.
|
||||
self.model_ema = LitEma(self, decay=ema_decay)
|
||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
if path.lower().endswith(".safetensors"):
|
||||
import safetensors.torch
|
||||
sd = safetensors.torch.load_file(path, device="cpu")
|
||||
else:
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
self.load_state_dict(sd, strict=False)
|
||||
print(f"Restored from {path}")
|
||||
|
||||
@contextmanager
|
||||
def ema_scope(self, context=None):
|
||||
if self.use_ema:
|
||||
self.model_ema.store(self.parameters())
|
||||
self.model_ema.copy_to(self)
|
||||
if context is not None:
|
||||
print(f"{context}: Switched to EMA weights")
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if self.use_ema:
|
||||
self.model_ema.restore(self.parameters())
|
||||
if context is not None:
|
||||
print(f"{context}: Restored training weights")
|
||||
|
||||
def on_train_batch_end(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
self.model_ema(self)
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
|
||||
|
||||
def decode(self, z):
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
return dec
|
||||
|
||||
def forward(self, input, sample_posterior=True):
|
||||
posterior = self.encode(input)
|
||||
if sample_posterior:
|
||||
z = posterior.sample()
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
return dec, posterior
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
||||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
inputs = self.get_input(batch, self.image_key)
|
||||
reconstructions, posterior = self(inputs)
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# train encoder+decoder+logvar
|
||||
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# train the discriminator
|
||||
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
|
||||
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
log_dict = self._validation_step(batch, batch_idx)
|
||||
with self.ema_scope():
|
||||
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
|
||||
return log_dict
|
||||
|
||||
def _validation_step(self, batch, batch_idx, postfix=""):
|
||||
inputs = self.get_input(batch, self.image_key)
|
||||
reconstructions, posterior = self(inputs)
|
||||
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="val"+postfix)
|
||||
|
||||
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="val"+postfix)
|
||||
|
||||
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr = self.learning_rate
|
||||
ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
|
||||
self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
|
||||
if self.learn_logvar:
|
||||
print(f"{self.__class__.__name__}: Learning logvar")
|
||||
ae_params_list.append(self.loss.logvar)
|
||||
opt_ae = torch.optim.Adam(ae_params_list,
|
||||
lr=lr, betas=(0.5, 0.9))
|
||||
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||
lr=lr, betas=(0.5, 0.9))
|
||||
return [opt_ae, opt_disc], []
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if not only_inputs:
|
||||
xrec, posterior = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
|
||||
log["reconstructions"] = xrec
|
||||
if log_ema or self.use_ema:
|
||||
with self.ema_scope():
|
||||
xrec_ema, posterior_ema = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec_ema.shape[1] > 3
|
||||
xrec_ema = self.to_rgb(xrec_ema)
|
||||
log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
|
||||
log["reconstructions_ema"] = xrec_ema
|
||||
log["inputs"] = x
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == "segmentation"
|
||||
if not hasattr(self, "colorize"):
|
||||
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
||||
return x
|
||||
|
||||
|
||||
class IdentityFirstStage(torch.nn.Module):
|
||||
def __init__(self, *args, vq_interface=False, **kwargs):
|
||||
self.vq_interface = vq_interface
|
||||
super().__init__()
|
||||
|
||||
def encode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def decode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def quantize(self, x, *args, **kwargs):
|
||||
if self.vq_interface:
|
||||
return x, None, [None, None, None]
|
||||
return x
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
0
comfy/ldm/models/diffusion/__init__.py
Normal file
0
comfy/ldm/models/diffusion/__init__.py
Normal file
336
comfy/ldm/models/diffusion/ddim.py
Normal file
336
comfy/ldm/models/diffusion/ddim.py
Normal file
@@ -0,0 +1,336 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
def __init__(self, model, schedule="linear", **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
dynamic_threshold=None,
|
||||
ucg_schedule=None,
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list): ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
|
||||
elif isinstance(conditioning, list):
|
||||
for ctmp in conditioning:
|
||||
if ctmp.shape[0] != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
||||
|
||||
samples, intermediates = self.ddim_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
ucg_schedule=ucg_schedule
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sampling(self, cond, shape,
|
||||
x_T=None, ddim_use_original_steps=False,
|
||||
callback=None, timesteps=None, quantize_denoised=False,
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
|
||||
ucg_schedule=None):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
|
||||
if ucg_schedule is not None:
|
||||
assert len(ucg_schedule) == len(time_range)
|
||||
unconditional_guidance_scale = ucg_schedule[i]
|
||||
|
||||
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold)
|
||||
img, pred_x0 = outs
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None,
|
||||
dynamic_threshold=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
model_output = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
if isinstance(c, dict):
|
||||
assert isinstance(unconditional_conditioning, dict)
|
||||
c_in = dict()
|
||||
for k in c:
|
||||
if isinstance(c[k], list):
|
||||
c_in[k] = [torch.cat([
|
||||
unconditional_conditioning[k][i],
|
||||
c[k][i]]) for i in range(len(c[k]))]
|
||||
else:
|
||||
c_in[k] = torch.cat([
|
||||
unconditional_conditioning[k],
|
||||
c[k]])
|
||||
elif isinstance(c, list):
|
||||
c_in = list()
|
||||
assert isinstance(unconditional_conditioning, list)
|
||||
for i in range(len(c)):
|
||||
c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
|
||||
else:
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
|
||||
|
||||
if self.model.parameterization == "v":
|
||||
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
|
||||
else:
|
||||
e_t = model_output
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps", 'not implemented'
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
if self.model.parameterization != "v":
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
else:
|
||||
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
||||
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
|
||||
if dynamic_threshold is not None:
|
||||
raise NotImplementedError()
|
||||
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
|
||||
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
|
||||
num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
|
||||
|
||||
assert t_enc <= num_reference_steps
|
||||
num_steps = t_enc
|
||||
|
||||
if use_original_steps:
|
||||
alphas_next = self.alphas_cumprod[:num_steps]
|
||||
alphas = self.alphas_cumprod_prev[:num_steps]
|
||||
else:
|
||||
alphas_next = self.ddim_alphas[:num_steps]
|
||||
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
|
||||
|
||||
x_next = x0
|
||||
intermediates = []
|
||||
inter_steps = []
|
||||
for i in tqdm(range(num_steps), desc='Encoding Image'):
|
||||
t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
|
||||
if unconditional_guidance_scale == 1.:
|
||||
noise_pred = self.model.apply_model(x_next, t, c)
|
||||
else:
|
||||
assert unconditional_conditioning is not None
|
||||
e_t_uncond, noise_pred = torch.chunk(
|
||||
self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
|
||||
torch.cat((unconditional_conditioning, c))), 2)
|
||||
noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
|
||||
|
||||
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
|
||||
weighted_noise_pred = alphas_next[i].sqrt() * (
|
||||
(1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
|
||||
x_next = xt_weighted + weighted_noise_pred
|
||||
if return_intermediates and i % (
|
||||
num_steps // return_intermediates) == 0 and i < num_steps - 1:
|
||||
intermediates.append(x_next)
|
||||
inter_steps.append(i)
|
||||
elif return_intermediates and i >= num_steps - 2:
|
||||
intermediates.append(x_next)
|
||||
inter_steps.append(i)
|
||||
if callback: callback(i)
|
||||
|
||||
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
|
||||
if return_intermediates:
|
||||
out.update({'intermediates': intermediates})
|
||||
return x_next, out
|
||||
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||
# fast, but does not allow for exact reconstruction
|
||||
# t serves as an index to gather the correct alphas
|
||||
if use_original_steps:
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||
else:
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
|
||||
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
||||
use_original_steps=False, callback=None):
|
||||
|
||||
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
|
||||
timesteps = timesteps[:t_start]
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||
x_dec = x_latent
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
|
||||
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning)
|
||||
if callback: callback(i)
|
||||
return x_dec
|
||||
1800
comfy/ldm/models/diffusion/ddpm.py
Normal file
1800
comfy/ldm/models/diffusion/ddpm.py
Normal file
File diff suppressed because it is too large
Load Diff
1
comfy/ldm/models/diffusion/dpm_solver/__init__.py
Normal file
1
comfy/ldm/models/diffusion/dpm_solver/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .sampler import DPMSolverSampler
|
||||
1154
comfy/ldm/models/diffusion/dpm_solver/dpm_solver.py
Normal file
1154
comfy/ldm/models/diffusion/dpm_solver/dpm_solver.py
Normal file
File diff suppressed because it is too large
Load Diff
87
comfy/ldm/models/diffusion/dpm_solver/sampler.py
Normal file
87
comfy/ldm/models/diffusion/dpm_solver/sampler.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""SAMPLING ONLY."""
|
||||
import torch
|
||||
|
||||
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
|
||||
|
||||
|
||||
MODEL_TYPES = {
|
||||
"eps": "noise",
|
||||
"v": "v"
|
||||
}
|
||||
|
||||
|
||||
class DPMSolverSampler(object):
|
||||
def __init__(self, model, **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
|
||||
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
setattr(self, name, attr)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
|
||||
print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
|
||||
|
||||
device = self.model.betas.device
|
||||
if x_T is None:
|
||||
img = torch.randn(size, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
|
||||
|
||||
model_fn = model_wrapper(
|
||||
lambda x, t, c: self.model.apply_model(x, t, c),
|
||||
ns,
|
||||
model_type=MODEL_TYPES[self.model.parameterization],
|
||||
guidance_type="classifier-free",
|
||||
condition=conditioning,
|
||||
unconditional_condition=unconditional_conditioning,
|
||||
guidance_scale=unconditional_guidance_scale,
|
||||
)
|
||||
|
||||
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
|
||||
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
|
||||
|
||||
return x.to(device), None
|
||||
244
comfy/ldm/models/diffusion/plms.py
Normal file
244
comfy/ldm/models/diffusion/plms.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
|
||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
||||
from ldm.models.diffusion.sampling_util import norm_thresholding
|
||||
|
||||
|
||||
class PLMSSampler(object):
|
||||
def __init__(self, model, schedule="linear", **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
if ddim_eta != 0:
|
||||
raise ValueError('ddim_eta must be 0 for PLMS')
|
||||
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
dynamic_threshold=None,
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for PLMS sampling is {size}')
|
||||
|
||||
samples, intermediates = self.plms_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sampling(self, cond, shape,
|
||||
x_T=None, ddim_use_original_steps=False,
|
||||
callback=None, timesteps=None, quantize_denoised=False,
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None,
|
||||
dynamic_threshold=None):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
|
||||
old_eps = []
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
|
||||
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
old_eps=old_eps, t_next=ts_next,
|
||||
dynamic_threshold=dynamic_threshold)
|
||||
img, pred_x0, e_t = outs
|
||||
old_eps.append(e_t)
|
||||
if len(old_eps) >= 4:
|
||||
old_eps.pop(0)
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
|
||||
dynamic_threshold=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
def get_model_output(x, t):
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
return e_t
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
if dynamic_threshold is not None:
|
||||
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
e_t = get_model_output(x, t)
|
||||
if len(old_eps) == 0:
|
||||
# Pseudo Improved Euler (2nd order)
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||
e_t_next = get_model_output(x_prev, t_next)
|
||||
e_t_prime = (e_t + e_t_next) / 2
|
||||
elif len(old_eps) == 1:
|
||||
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||
elif len(old_eps) == 2:
|
||||
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||
elif len(old_eps) >= 3:
|
||||
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||
|
||||
return x_prev, pred_x0, e_t
|
||||
22
comfy/ldm/models/diffusion/sampling_util.py
Normal file
22
comfy/ldm/models/diffusion/sampling_util.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def append_dims(x, target_dims):
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.
|
||||
From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
|
||||
def norm_thresholding(x0, value):
|
||||
s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
|
||||
return x0 * (value / s)
|
||||
|
||||
|
||||
def spatial_norm_thresholding(x0, value):
|
||||
# b c h w
|
||||
s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
|
||||
return x0 * (value / s)
|
||||
533
comfy/ldm/modules/attention.py
Normal file
533
comfy/ldm/modules/attention.py
Normal file
@@ -0,0 +1,533 @@
|
||||
from inspect import isfunction
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
from typing import Optional, Any
|
||||
|
||||
from ldm.modules.diffusionmodules.util import checkpoint
|
||||
from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
|
||||
try:
|
||||
import xformers
|
||||
import xformers.ops
|
||||
XFORMERS_IS_AVAILBLE = True
|
||||
except:
|
||||
XFORMERS_IS_AVAILBLE = False
|
||||
|
||||
# CrossAttn precision handling
|
||||
import os
|
||||
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def uniq(arr):
|
||||
return{el: True for el in arr}.keys()
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def max_neg_value(t):
|
||||
return -torch.finfo(t.dtype).max
|
||||
|
||||
|
||||
def init_(tensor):
|
||||
dim = tensor.shape[-1]
|
||||
std = 1 / math.sqrt(dim)
|
||||
tensor.uniform_(-std, std)
|
||||
return tensor
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(
|
||||
nn.Linear(dim, inner_dim),
|
||||
nn.GELU()
|
||||
) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class SpatialSelfAttention(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b,c,h,w = q.shape
|
||||
q = rearrange(q, 'b c h w -> b (h w) c')
|
||||
k = rearrange(k, 'b c h w -> b c (h w)')
|
||||
w_ = torch.einsum('bij,bjk->bik', q, k)
|
||||
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = rearrange(v, 'b c h w -> b c (h w)')
|
||||
w_ = rearrange(w_, 'b i j -> b j i')
|
||||
h_ = torch.einsum('bij,bjk->bik', v, w_)
|
||||
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x+h_
|
||||
|
||||
|
||||
class CrossAttentionBirchSan(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, query_dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
query = self.to_q(x)
|
||||
context = default(context, x)
|
||||
key = self.to_k(context)
|
||||
value = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
key_t = key.transpose(1,2).unflatten(1, (self.heads, -1)).flatten(end_dim=1)
|
||||
del key
|
||||
value = value.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
|
||||
dtype = query.dtype
|
||||
# TODO: do we still need to do *everything* in float32, given how we delay the division?
|
||||
# TODO: do we need to support upcast_softmax too? SD 2.1 seems to work without it
|
||||
# if self.upcast_attention:
|
||||
# query = query.float()
|
||||
# key_t = key_t.float()
|
||||
|
||||
bytes_per_token = torch.finfo(query.dtype).bits//8
|
||||
batch_x_heads, q_tokens, _ = query.shape
|
||||
_, _, k_tokens = key_t.shape
|
||||
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
|
||||
|
||||
stats = torch.cuda.memory_stats(query.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
|
||||
|
||||
kv_chunk_size_min = None
|
||||
|
||||
query_chunk_size_x = 1024 * 4
|
||||
kv_chunk_size_min_x = None
|
||||
kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 1.2) // 1024) * 1024
|
||||
if kv_chunk_size_x < 1024:
|
||||
kv_chunk_size_x = None
|
||||
|
||||
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
|
||||
# the big matmul fits into our memory limit; do everything in 1 chunk,
|
||||
# i.e. send it down the unchunked fast-path
|
||||
query_chunk_size = q_tokens
|
||||
kv_chunk_size = k_tokens
|
||||
else:
|
||||
query_chunk_size = query_chunk_size_x
|
||||
kv_chunk_size = kv_chunk_size_x
|
||||
kv_chunk_size_min = kv_chunk_size_min_x
|
||||
|
||||
hidden_states = efficient_dot_product_attention(
|
||||
query,
|
||||
key_t,
|
||||
value,
|
||||
query_chunk_size=query_chunk_size,
|
||||
kv_chunk_size=kv_chunk_size,
|
||||
kv_chunk_size_min=kv_chunk_size_min,
|
||||
use_checkpoint=self.training,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
hidden_states = hidden_states.unflatten(0, (-1, self.heads)).transpose(1,2).flatten(start_dim=2)
|
||||
|
||||
out_proj, dropout = self.to_out
|
||||
hidden_states = out_proj(hidden_states)
|
||||
hidden_states = dropout(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CrossAttentionDoggettx(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, query_dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k_in = self.to_k(context)
|
||||
v_in = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
||||
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||
modifier = 3 if q.element_size() == 2 else 2.5
|
||||
mem_required = tensor_size * modifier
|
||||
steps = 1
|
||||
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
||||
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
||||
|
||||
if steps > 64:
|
||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
|
||||
|
||||
# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
|
||||
first_op_done = False
|
||||
cleared_cache = False
|
||||
while True:
|
||||
try:
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
if _ATTN_PRECISION =="fp32":
|
||||
with torch.autocast(enabled=False, device_type = 'cuda'):
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * self.scale
|
||||
else:
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
|
||||
first_op_done = True
|
||||
|
||||
s2 = s1.softmax(dim=-1)
|
||||
del s1
|
||||
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
break
|
||||
except torch.cuda.OutOfMemoryError as e:
|
||||
if first_op_done == False:
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
if cleared_cache == False:
|
||||
cleared_cache = True
|
||||
print("out of memory error, emptying cache and trying again")
|
||||
continue
|
||||
steps *= 2
|
||||
if steps > 64:
|
||||
raise e
|
||||
print("out of memory error, increasing steps and trying again", steps)
|
||||
else:
|
||||
raise e
|
||||
|
||||
del q, k, v
|
||||
|
||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
del r1
|
||||
|
||||
return self.to_out(r2)
|
||||
|
||||
class OriginalCrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, query_dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
# force cast to fp32 to avoid overflowing
|
||||
if _ATTN_PRECISION =="fp32":
|
||||
with torch.autocast(enabled=False, device_type = 'cuda'):
|
||||
q, k = q.float(), k.float()
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
else:
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
del q, k
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b ... -> b (...)')
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
sim = sim.softmax(dim=-1)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', sim, v)
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(out)
|
||||
|
||||
class CrossAttention(CrossAttentionDoggettx):
|
||||
pass
|
||||
|
||||
class MemoryEfficientCrossAttention(nn.Module):
|
||||
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
||||
super().__init__()
|
||||
print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
||||
f"{heads} heads.")
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||
self.attention_op: Optional[Any] = None
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
b, _, _ = q.shape
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
# actually compute the attention, what we cannot get enough of
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
||||
|
||||
if exists(mask):
|
||||
raise NotImplementedError
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
||||
)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
ATTENTION_MODES = {
|
||||
"softmax": CrossAttention, # vanilla attention
|
||||
"softmax-xformers": MemoryEfficientCrossAttention
|
||||
}
|
||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
||||
disable_self_attn=False):
|
||||
super().__init__()
|
||||
attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
|
||||
assert attn_mode in self.ATTENTION_MODES
|
||||
attn_cls = self.ATTENTION_MODES[attn_mode]
|
||||
self.disable_self_attn = disable_self_attn
|
||||
self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
||||
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
|
||||
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
||||
|
||||
def _forward(self, x, context=None):
|
||||
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
|
||||
x = self.attn2(self.norm2(x), context=context) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
||||
|
||||
class SpatialTransformer(nn.Module):
|
||||
"""
|
||||
Transformer block for image-like data.
|
||||
First, project the input (aka embedding)
|
||||
and reshape to b, t, d.
|
||||
Then apply standard transformer action.
|
||||
Finally, reshape to image
|
||||
NEW: use_linear for more efficiency instead of the 1x1 convs
|
||||
"""
|
||||
def __init__(self, in_channels, n_heads, d_head,
|
||||
depth=1, dropout=0., context_dim=None,
|
||||
disable_self_attn=False, use_linear=False,
|
||||
use_checkpoint=True):
|
||||
super().__init__()
|
||||
if exists(context_dim) and not isinstance(context_dim, list):
|
||||
context_dim = [context_dim]
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = Normalize(in_channels)
|
||||
if not use_linear:
|
||||
self.proj_in = nn.Conv2d(in_channels,
|
||||
inner_dim,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
else:
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
|
||||
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
|
||||
for d in range(depth)]
|
||||
)
|
||||
if not use_linear:
|
||||
self.proj_out = zero_module(nn.Conv2d(inner_dim,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0))
|
||||
else:
|
||||
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
||||
self.use_linear = use_linear
|
||||
|
||||
def forward(self, x, context=None):
|
||||
# note: if no context is given, cross-attention defaults to self-attention
|
||||
if not isinstance(context, list):
|
||||
context = [context]
|
||||
b, c, h, w = x.shape
|
||||
x_in = x
|
||||
x = self.norm(x)
|
||||
if not self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
|
||||
if self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
x = block(x, context=context[i])
|
||||
if self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
||||
if not self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
||||
|
||||
0
comfy/ldm/modules/diffusionmodules/__init__.py
Normal file
0
comfy/ldm/modules/diffusionmodules/__init__.py
Normal file
852
comfy/ldm/modules/diffusionmodules/model.py
Normal file
852
comfy/ldm/modules/diffusionmodules/model.py
Normal file
@@ -0,0 +1,852 @@
|
||||
# pytorch_diffusion + derived encoder decoder
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
from typing import Optional, Any
|
||||
|
||||
from ldm.modules.attention import MemoryEfficientCrossAttention
|
||||
|
||||
try:
|
||||
import xformers
|
||||
import xformers.ops
|
||||
XFORMERS_IS_AVAILBLE = True
|
||||
except:
|
||||
XFORMERS_IS_AVAILBLE = False
|
||||
print("No module 'xformers'. Proceeding without it.")
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||
From Fairseq.
|
||||
Build sinusoidal embeddings.
|
||||
This matches the implementation in tensor2tensor, but differs slightly
|
||||
from the description in Section 3.5 of "Attention Is All You Need".
|
||||
"""
|
||||
assert len(timesteps.shape) == 1
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
||||
emb = emb.to(device=timesteps.device)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
||||
return emb
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x*torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
pad = (0,1,0,1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
||||
dropout, temb_channels=512):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize(in_channels)
|
||||
self.conv1 = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels,
|
||||
out_channels)
|
||||
self.norm2 = Normalize(out_channels)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = torch.nn.Conv2d(out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x+h
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b,c,h,w = q.shape
|
||||
q = q.reshape(b,c,h*w)
|
||||
q = q.permute(0,2,1) # b,hw,c
|
||||
k = k.reshape(b,c,h*w) # b,c,hw
|
||||
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b,c,h*w)
|
||||
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
|
||||
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b,c,h,w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x+h_
|
||||
|
||||
class MemoryEfficientAttnBlock(nn.Module):
|
||||
"""
|
||||
Uses xformers efficient implementation,
|
||||
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
||||
Note: this is a single-head self-attention operation
|
||||
"""
|
||||
#
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.attention_op: Optional[Any] = None
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
B, C, H, W = q.shape
|
||||
q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(B, t.shape[1], 1, C)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(B * 1, t.shape[1], C)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
||||
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(B, 1, out.shape[1], C)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(B, out.shape[1], C)
|
||||
)
|
||||
out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
|
||||
out = self.proj_out(out)
|
||||
return x+out
|
||||
|
||||
|
||||
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
|
||||
def forward(self, x, context=None, mask=None):
|
||||
b, c, h, w = x.shape
|
||||
x = rearrange(x, 'b c h w -> b (h w) c')
|
||||
out = super().forward(x, context=context, mask=mask)
|
||||
out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
|
||||
return x + out
|
||||
|
||||
|
||||
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
||||
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
|
||||
if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
|
||||
attn_type = "vanilla-xformers"
|
||||
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||
if attn_type == "vanilla":
|
||||
assert attn_kwargs is None
|
||||
return AttnBlock(in_channels)
|
||||
elif attn_type == "vanilla-xformers":
|
||||
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
|
||||
return MemoryEfficientAttnBlock(in_channels)
|
||||
elif type == "memory-efficient-cross-attn":
|
||||
attn_kwargs["query_dim"] = in_channels
|
||||
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
|
||||
elif attn_type == "none":
|
||||
return nn.Identity(in_channels)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = self.ch*4
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.use_timestep = use_timestep
|
||||
if self.use_timestep:
|
||||
# timestep embedding
|
||||
self.temb = nn.Module()
|
||||
self.temb.dense = nn.ModuleList([
|
||||
torch.nn.Linear(self.ch,
|
||||
self.temb_ch),
|
||||
torch.nn.Linear(self.temb_ch,
|
||||
self.temb_ch),
|
||||
])
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch*in_ch_mult[i_level]
|
||||
block_out = ch*ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions-1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch*ch_mult[i_level]
|
||||
skip_in = ch*ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
if i_block == self.num_res_blocks:
|
||||
skip_in = ch*in_ch_mult[i_level]
|
||||
block.append(ResnetBlock(in_channels=block_in+skip_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in,
|
||||
out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x, t=None, context=None):
|
||||
#assert x.shape[2] == x.shape[3] == self.resolution
|
||||
if context is not None:
|
||||
# assume aligned context, cat along channel axis
|
||||
x = torch.cat((x, context), dim=1)
|
||||
if self.use_timestep:
|
||||
# timestep embedding
|
||||
assert t is not None
|
||||
temb = get_timestep_embedding(t, self.ch)
|
||||
temb = self.temb.dense[0](temb)
|
||||
temb = nonlinearity(temb)
|
||||
temb = self.temb.dense[1](temb)
|
||||
else:
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions-1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h = self.up[i_level].block[i_block](
|
||||
torch.cat([h, hs.pop()], dim=1), temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.conv_out.weight
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
|
||||
**ignore_kwargs):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch*in_ch_mult[i_level]
|
||||
block_out = ch*ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions-1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in,
|
||||
2*z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions-1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
|
||||
attn_type="vanilla", **ignorekwargs):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
block_in = ch*ch_mult[self.num_resolutions-1]
|
||||
curr_res = resolution // 2**(self.num_resolutions-1)
|
||||
self.z_shape = (1,z_channels,curr_res,curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(z_channels,
|
||||
block_in,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch*ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
block.append(ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in,
|
||||
out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, z):
|
||||
#assert z.shape[1:] == self.z_shape[1:]
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h = self.up[i_level].block[i_block](h, temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
return h
|
||||
|
||||
|
||||
class SimpleDecoder(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
|
||||
ResnetBlock(in_channels=in_channels,
|
||||
out_channels=2 * in_channels,
|
||||
temb_channels=0, dropout=0.0),
|
||||
ResnetBlock(in_channels=2 * in_channels,
|
||||
out_channels=4 * in_channels,
|
||||
temb_channels=0, dropout=0.0),
|
||||
ResnetBlock(in_channels=4 * in_channels,
|
||||
out_channels=2 * in_channels,
|
||||
temb_channels=0, dropout=0.0),
|
||||
nn.Conv2d(2*in_channels, in_channels, 1),
|
||||
Upsample(in_channels, with_conv=True)])
|
||||
# end
|
||||
self.norm_out = Normalize(in_channels)
|
||||
self.conv_out = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.model):
|
||||
if i in [1,2,3]:
|
||||
x = layer(x, None)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
h = self.norm_out(x)
|
||||
h = nonlinearity(h)
|
||||
x = self.conv_out(h)
|
||||
return x
|
||||
|
||||
|
||||
class UpsampleDecoder(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
|
||||
ch_mult=(2,2), dropout=0.0):
|
||||
super().__init__()
|
||||
# upsampling
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
block_in = in_channels
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.res_blocks = nn.ModuleList()
|
||||
self.upsample_blocks = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
res_block = []
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
res_block.append(ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
self.res_blocks.append(nn.ModuleList(res_block))
|
||||
if i_level != self.num_resolutions - 1:
|
||||
self.upsample_blocks.append(Upsample(block_in, True))
|
||||
curr_res = curr_res * 2
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
# upsampling
|
||||
h = x
|
||||
for k, i_level in enumerate(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.res_blocks[i_level][i_block](h, None)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
h = self.upsample_blocks[k](h)
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class LatentRescaler(nn.Module):
|
||||
def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
|
||||
super().__init__()
|
||||
# residual block, interpolate, residual block
|
||||
self.factor = factor
|
||||
self.conv_in = nn.Conv2d(in_channels,
|
||||
mid_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
|
||||
out_channels=mid_channels,
|
||||
temb_channels=0,
|
||||
dropout=0.0) for _ in range(depth)])
|
||||
self.attn = AttnBlock(mid_channels)
|
||||
self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
|
||||
out_channels=mid_channels,
|
||||
temb_channels=0,
|
||||
dropout=0.0) for _ in range(depth)])
|
||||
|
||||
self.conv_out = nn.Conv2d(mid_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
for block in self.res_block1:
|
||||
x = block(x, None)
|
||||
x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
|
||||
x = self.attn(x)
|
||||
for block in self.res_block2:
|
||||
x = block(x, None)
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class MergedRescaleEncoder(nn.Module):
|
||||
def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True,
|
||||
ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
|
||||
super().__init__()
|
||||
intermediate_chn = ch * ch_mult[-1]
|
||||
self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
|
||||
z_channels=intermediate_chn, double_z=False, resolution=resolution,
|
||||
attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
|
||||
out_ch=None)
|
||||
self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
|
||||
mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.encoder(x)
|
||||
x = self.rescaler(x)
|
||||
return x
|
||||
|
||||
|
||||
class MergedRescaleDecoder(nn.Module):
|
||||
def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
|
||||
dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
|
||||
super().__init__()
|
||||
tmp_chn = z_channels*ch_mult[-1]
|
||||
self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
|
||||
resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
|
||||
ch_mult=ch_mult, resolution=resolution, ch=ch)
|
||||
self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
|
||||
out_channels=tmp_chn, depth=rescale_module_depth)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.rescaler(x)
|
||||
x = self.decoder(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsampler(nn.Module):
|
||||
def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
|
||||
super().__init__()
|
||||
assert out_size >= in_size
|
||||
num_blocks = int(np.log2(out_size//in_size))+1
|
||||
factor_up = 1.+ (out_size % in_size)
|
||||
print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
|
||||
self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
|
||||
out_channels=in_channels)
|
||||
self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
|
||||
attn_resolutions=[], in_channels=None, ch=in_channels,
|
||||
ch_mult=[ch_mult for _ in range(num_blocks)])
|
||||
|
||||
def forward(self, x):
|
||||
x = self.rescaler(x)
|
||||
x = self.decoder(x)
|
||||
return x
|
||||
|
||||
|
||||
class Resize(nn.Module):
|
||||
def __init__(self, in_channels=None, learned=False, mode="bilinear"):
|
||||
super().__init__()
|
||||
self.with_conv = learned
|
||||
self.mode = mode
|
||||
if self.with_conv:
|
||||
print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
|
||||
raise NotImplementedError()
|
||||
assert in_channels is not None
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x, scale_factor=1.0):
|
||||
if scale_factor==1.0:
|
||||
return x
|
||||
else:
|
||||
x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
|
||||
return x
|
||||
786
comfy/ldm/modules/diffusionmodules/openaimodel.py
Normal file
786
comfy/ldm/modules/diffusionmodules/openaimodel.py
Normal file
@@ -0,0 +1,786 @@
|
||||
from abc import abstractmethod
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ldm.modules.diffusionmodules.util import (
|
||||
checkpoint,
|
||||
conv_nd,
|
||||
linear,
|
||||
avg_pool_nd,
|
||||
zero_module,
|
||||
normalization,
|
||||
timestep_embedding,
|
||||
)
|
||||
from ldm.modules.attention import SpatialTransformer
|
||||
from ldm.util import exists
|
||||
|
||||
|
||||
# dummy replace
|
||||
def convert_module_to_f16(x):
|
||||
pass
|
||||
|
||||
def convert_module_to_f32(x):
|
||||
pass
|
||||
|
||||
|
||||
## go
|
||||
class AttentionPool2d(nn.Module):
|
||||
"""
|
||||
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spacial_dim: int,
|
||||
embed_dim: int,
|
||||
num_heads_channels: int,
|
||||
output_dim: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
|
||||
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
|
||||
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
|
||||
self.num_heads = embed_dim // num_heads_channels
|
||||
self.attention = QKVAttention(self.num_heads)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, *_spatial = x.shape
|
||||
x = x.reshape(b, c, -1) # NC(HW)
|
||||
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
|
||||
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
|
||||
x = self.qkv_proj(x)
|
||||
x = self.attention(x)
|
||||
x = self.c_proj(x)
|
||||
return x[:, :, 0]
|
||||
|
||||
|
||||
class TimestepBlock(nn.Module):
|
||||
"""
|
||||
Any module where forward() takes timestep embeddings as a second argument.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, x, emb):
|
||||
"""
|
||||
Apply the module to `x` given `emb` timestep embeddings.
|
||||
"""
|
||||
|
||||
|
||||
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
"""
|
||||
A sequential module that passes timestep embeddings to the children that
|
||||
support it as an extra input.
|
||||
"""
|
||||
|
||||
def forward(self, x, emb, context=None):
|
||||
for layer in self:
|
||||
if isinstance(layer, TimestepBlock):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, SpatialTransformer):
|
||||
x = layer(x, context)
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
"""
|
||||
An upsampling layer with an optional convolution.
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
if use_conv:
|
||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
if self.dims == 3:
|
||||
x = F.interpolate(
|
||||
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
|
||||
)
|
||||
else:
|
||||
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
class TransposedUpsample(nn.Module):
|
||||
'Learned 2x upsampling without padding'
|
||||
def __init__(self, channels, out_channels=None, ks=5):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
|
||||
self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
|
||||
|
||||
def forward(self,x):
|
||||
return self.up(x)
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
"""
|
||||
A downsampling layer with an optional convolution.
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
downsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
stride = 2 if dims != 3 else (1, 2, 2)
|
||||
if use_conv:
|
||||
self.op = conv_nd(
|
||||
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
|
||||
)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class ResBlock(TimestepBlock):
|
||||
"""
|
||||
A residual block that can optionally change the number of channels.
|
||||
:param channels: the number of input channels.
|
||||
:param emb_channels: the number of timestep embedding channels.
|
||||
:param dropout: the rate of dropout.
|
||||
:param out_channels: if specified, the number of out channels.
|
||||
:param use_conv: if True and out_channels is specified, use a spatial
|
||||
convolution instead of a smaller 1x1 convolution to change the
|
||||
channels in the skip connection.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D.
|
||||
:param use_checkpoint: if True, use gradient checkpointing on this module.
|
||||
:param up: if True, use this block for upsampling.
|
||||
:param down: if True, use this block for downsampling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
emb_channels,
|
||||
dropout,
|
||||
out_channels=None,
|
||||
use_conv=False,
|
||||
use_scale_shift_norm=False,
|
||||
dims=2,
|
||||
use_checkpoint=False,
|
||||
up=False,
|
||||
down=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.emb_channels = emb_channels
|
||||
self.dropout = dropout
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.use_scale_shift_norm = use_scale_shift_norm
|
||||
|
||||
self.in_layers = nn.Sequential(
|
||||
normalization(channels),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
||||
)
|
||||
|
||||
self.updown = up or down
|
||||
|
||||
if up:
|
||||
self.h_upd = Upsample(channels, False, dims)
|
||||
self.x_upd = Upsample(channels, False, dims)
|
||||
elif down:
|
||||
self.h_upd = Downsample(channels, False, dims)
|
||||
self.x_upd = Downsample(channels, False, dims)
|
||||
else:
|
||||
self.h_upd = self.x_upd = nn.Identity()
|
||||
|
||||
self.emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
linear(
|
||||
emb_channels,
|
||||
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
||||
),
|
||||
)
|
||||
self.out_layers = nn.Sequential(
|
||||
normalization(self.out_channels),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
|
||||
),
|
||||
)
|
||||
|
||||
if self.out_channels == channels:
|
||||
self.skip_connection = nn.Identity()
|
||||
elif use_conv:
|
||||
self.skip_connection = conv_nd(
|
||||
dims, channels, self.out_channels, 3, padding=1
|
||||
)
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
||||
|
||||
def forward(self, x, emb):
|
||||
"""
|
||||
Apply the block to a Tensor, conditioned on a timestep embedding.
|
||||
:param x: an [N x C x ...] Tensor of features.
|
||||
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
return checkpoint(
|
||||
self._forward, (x, emb), self.parameters(), self.use_checkpoint
|
||||
)
|
||||
|
||||
|
||||
def _forward(self, x, emb):
|
||||
if self.updown:
|
||||
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
||||
h = in_rest(x)
|
||||
h = self.h_upd(h)
|
||||
x = self.x_upd(x)
|
||||
h = in_conv(h)
|
||||
else:
|
||||
h = self.in_layers(x)
|
||||
emb_out = self.emb_layers(emb).type(h.dtype)
|
||||
while len(emb_out.shape) < len(h.shape):
|
||||
emb_out = emb_out[..., None]
|
||||
if self.use_scale_shift_norm:
|
||||
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
||||
scale, shift = th.chunk(emb_out, 2, dim=1)
|
||||
h = out_norm(h) * (1 + scale) + shift
|
||||
h = out_rest(h)
|
||||
else:
|
||||
h = h + emb_out
|
||||
h = self.out_layers(h)
|
||||
return self.skip_connection(x) + h
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
An attention block that allows spatial positions to attend to each other.
|
||||
Originally ported from here, but adapted to the N-d case.
|
||||
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
use_checkpoint=False,
|
||||
use_new_attention_order=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
if num_head_channels == -1:
|
||||
self.num_heads = num_heads
|
||||
else:
|
||||
assert (
|
||||
channels % num_head_channels == 0
|
||||
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
||||
self.num_heads = channels // num_head_channels
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.norm = normalization(channels)
|
||||
self.qkv = conv_nd(1, channels, channels * 3, 1)
|
||||
if use_new_attention_order:
|
||||
# split qkv before split heads
|
||||
self.attention = QKVAttention(self.num_heads)
|
||||
else:
|
||||
# split heads before split qkv
|
||||
self.attention = QKVAttentionLegacy(self.num_heads)
|
||||
|
||||
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
||||
|
||||
def forward(self, x):
|
||||
return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
|
||||
#return pt_checkpoint(self._forward, x) # pytorch
|
||||
|
||||
def _forward(self, x):
|
||||
b, c, *spatial = x.shape
|
||||
x = x.reshape(b, c, -1)
|
||||
qkv = self.qkv(self.norm(x))
|
||||
h = self.attention(qkv)
|
||||
h = self.proj_out(h)
|
||||
return (x + h).reshape(b, c, *spatial)
|
||||
|
||||
|
||||
def count_flops_attn(model, _x, y):
|
||||
"""
|
||||
A counter for the `thop` package to count the operations in an
|
||||
attention operation.
|
||||
Meant to be used like:
|
||||
macs, params = thop.profile(
|
||||
model,
|
||||
inputs=(inputs, timestamps),
|
||||
custom_ops={QKVAttention: QKVAttention.count_flops},
|
||||
)
|
||||
"""
|
||||
b, c, *spatial = y[0].shape
|
||||
num_spatial = int(np.prod(spatial))
|
||||
# We perform two matmuls with the same number of ops.
|
||||
# The first computes the weight matrix, the second computes
|
||||
# the combination of the value vectors.
|
||||
matmul_ops = 2 * b * (num_spatial ** 2) * c
|
||||
model.total_ops += th.DoubleTensor([matmul_ops])
|
||||
|
||||
|
||||
class QKVAttentionLegacy(nn.Module):
|
||||
"""
|
||||
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
|
||||
"""
|
||||
|
||||
def __init__(self, n_heads):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
|
||||
def forward(self, qkv):
|
||||
"""
|
||||
Apply QKV attention.
|
||||
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
||||
:return: an [N x (H * C) x T] tensor after attention.
|
||||
"""
|
||||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
ch = width // (3 * self.n_heads)
|
||||
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = th.einsum(
|
||||
"bct,bcs->bts", q * scale, k * scale
|
||||
) # More stable with f16 than dividing afterwards
|
||||
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = th.einsum("bts,bcs->bct", weight, v)
|
||||
return a.reshape(bs, -1, length)
|
||||
|
||||
@staticmethod
|
||||
def count_flops(model, _x, y):
|
||||
return count_flops_attn(model, _x, y)
|
||||
|
||||
|
||||
class QKVAttention(nn.Module):
|
||||
"""
|
||||
A module which performs QKV attention and splits in a different order.
|
||||
"""
|
||||
|
||||
def __init__(self, n_heads):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
|
||||
def forward(self, qkv):
|
||||
"""
|
||||
Apply QKV attention.
|
||||
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
|
||||
:return: an [N x (H * C) x T] tensor after attention.
|
||||
"""
|
||||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
ch = width // (3 * self.n_heads)
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = th.einsum(
|
||||
"bct,bcs->bts",
|
||||
(q * scale).view(bs * self.n_heads, ch, length),
|
||||
(k * scale).view(bs * self.n_heads, ch, length),
|
||||
) # More stable with f16 than dividing afterwards
|
||||
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
|
||||
return a.reshape(bs, -1, length)
|
||||
|
||||
@staticmethod
|
||||
def count_flops(model, _x, y):
|
||||
return count_flops_attn(model, _x, y)
|
||||
|
||||
|
||||
class UNetModel(nn.Module):
|
||||
"""
|
||||
The full UNet model with attention and timestep embedding.
|
||||
:param in_channels: channels in the input Tensor.
|
||||
:param model_channels: base channel count for the model.
|
||||
:param out_channels: channels in the output Tensor.
|
||||
:param num_res_blocks: number of residual blocks per downsample.
|
||||
:param attention_resolutions: a collection of downsample rates at which
|
||||
attention will take place. May be a set, list, or tuple.
|
||||
For example, if this contains 4, then at 4x downsampling, attention
|
||||
will be used.
|
||||
:param dropout: the dropout probability.
|
||||
:param channel_mult: channel multiplier for each level of the UNet.
|
||||
:param conv_resample: if True, use learned convolutions for upsampling and
|
||||
downsampling.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D.
|
||||
:param num_classes: if specified (as an int), then this model will be
|
||||
class-conditional with `num_classes` classes.
|
||||
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
|
||||
:param num_heads: the number of attention heads in each attention layer.
|
||||
:param num_heads_channels: if specified, ignore num_heads and instead use
|
||||
a fixed channel width per attention head.
|
||||
:param num_heads_upsample: works with num_heads to set a different number
|
||||
of heads for upsampling. Deprecated.
|
||||
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
|
||||
:param resblock_updown: use residual blocks for up/downsampling.
|
||||
:param use_new_attention_order: use a different attention pattern for potentially
|
||||
increased efficiency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
in_channels,
|
||||
model_channels,
|
||||
out_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
num_classes=None,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
num_heads=-1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
use_spatial_transformer=False, # custom transformer support
|
||||
transformer_depth=1, # custom transformer support
|
||||
context_dim=None, # custom transformer support
|
||||
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
||||
legacy=True,
|
||||
disable_self_attentions=None,
|
||||
num_attention_blocks=None,
|
||||
disable_middle_self_attn=False,
|
||||
use_linear_in_transformer=False,
|
||||
):
|
||||
super().__init__()
|
||||
if use_spatial_transformer:
|
||||
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
||||
|
||||
if context_dim is not None:
|
||||
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
||||
from omegaconf.listconfig import ListConfig
|
||||
if type(context_dim) == ListConfig:
|
||||
context_dim = list(context_dim)
|
||||
|
||||
if num_heads_upsample == -1:
|
||||
num_heads_upsample = num_heads
|
||||
|
||||
if num_heads == -1:
|
||||
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
|
||||
if num_head_channels == -1:
|
||||
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
|
||||
self.image_size = image_size
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
if isinstance(num_res_blocks, int):
|
||||
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
||||
else:
|
||||
if len(num_res_blocks) != len(channel_mult):
|
||||
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
||||
"as a list/tuple (per-level) with the same length as channel_mult")
|
||||
self.num_res_blocks = num_res_blocks
|
||||
if disable_self_attentions is not None:
|
||||
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
||||
assert len(disable_self_attentions) == len(channel_mult)
|
||||
if num_attention_blocks is not None:
|
||||
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
||||
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
||||
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
||||
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
||||
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
||||
f"attention will still not be set.")
|
||||
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
self.num_classes = num_classes
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.dtype = th.float16 if use_fp16 else th.float32
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
self.predict_codebook_ids = n_embed is not None
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
if self.num_classes is not None:
|
||||
if isinstance(self.num_classes, int):
|
||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
||||
elif self.num_classes == "continuous":
|
||||
print("setting up linear c_adm embedding layer")
|
||||
self.label_emb = nn.Linear(1, time_embed_dim)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
||||
)
|
||||
]
|
||||
)
|
||||
self._feature_size = model_channels
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for nr in range(self.num_res_blocks[level]):
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=mult * model_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
if exists(disable_self_attentions):
|
||||
disabled_sa = disable_self_attentions[level]
|
||||
else:
|
||||
disabled_sa = False
|
||||
|
||||
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch
|
||||
)
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
)
|
||||
self._feature_size += ch
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
for level, mult in list(enumerate(channel_mult))[::-1]:
|
||||
for i in range(self.num_res_blocks[level] + 1):
|
||||
ich = input_block_chans.pop()
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch + ich,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=model_channels * mult,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = model_channels * mult
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
if exists(disable_self_attentions):
|
||||
disabled_sa = disable_self_attentions[level]
|
||||
else:
|
||||
disabled_sa = False
|
||||
|
||||
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads_upsample,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint
|
||||
)
|
||||
)
|
||||
if level and i == self.num_res_blocks[level]:
|
||||
out_ch = ch
|
||||
layers.append(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
up=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
||||
)
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
||||
)
|
||||
if self.predict_codebook_ids:
|
||||
self.id_predictor = nn.Sequential(
|
||||
normalization(ch),
|
||||
conv_nd(dims, model_channels, n_embed, 1),
|
||||
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||
)
|
||||
|
||||
def convert_to_fp16(self):
|
||||
"""
|
||||
Convert the torso of the model to float16.
|
||||
"""
|
||||
self.input_blocks.apply(convert_module_to_f16)
|
||||
self.middle_block.apply(convert_module_to_f16)
|
||||
self.output_blocks.apply(convert_module_to_f16)
|
||||
|
||||
def convert_to_fp32(self):
|
||||
"""
|
||||
Convert the torso of the model to float32.
|
||||
"""
|
||||
self.input_blocks.apply(convert_module_to_f32)
|
||||
self.middle_block.apply(convert_module_to_f32)
|
||||
self.output_blocks.apply(convert_module_to_f32)
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
:param timesteps: a 1-D batch of timesteps.
|
||||
:param context: conditioning plugged in via crossattn
|
||||
:param y: an [N] Tensor of labels, if class-conditional.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
hs = []
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x.type(self.dtype)
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb, context)
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, context)
|
||||
for module in self.output_blocks:
|
||||
h = th.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb, context)
|
||||
h = h.type(x.dtype)
|
||||
if self.predict_codebook_ids:
|
||||
return self.id_predictor(h)
|
||||
else:
|
||||
return self.out(h)
|
||||
81
comfy/ldm/modules/diffusionmodules/upscaling.py
Normal file
81
comfy/ldm/modules/diffusionmodules/upscaling.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
|
||||
from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
|
||||
from ldm.util import default
|
||||
|
||||
|
||||
class AbstractLowScaleModel(nn.Module):
|
||||
# for concatenating a downsampled image to the latent representation
|
||||
def __init__(self, noise_schedule_config=None):
|
||||
super(AbstractLowScaleModel, self).__init__()
|
||||
if noise_schedule_config is not None:
|
||||
self.register_schedule(**noise_schedule_config)
|
||||
|
||||
def register_schedule(self, beta_schedule="linear", timesteps=1000,
|
||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
|
||||
cosine_s=cosine_s)
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
||||
|
||||
timesteps, = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.linear_start = linear_start
|
||||
self.linear_end = linear_end
|
||||
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
|
||||
|
||||
to_torch = partial(torch.tensor, dtype=torch.float32)
|
||||
|
||||
self.register_buffer('betas', to_torch(betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
||||
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
|
||||
|
||||
def forward(self, x):
|
||||
return x, None
|
||||
|
||||
def decode(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class SimpleImageConcat(AbstractLowScaleModel):
|
||||
# no noise level conditioning
|
||||
def __init__(self):
|
||||
super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
|
||||
self.max_noise_level = 0
|
||||
|
||||
def forward(self, x):
|
||||
# fix to constant noise level
|
||||
return x, torch.zeros(x.shape[0], device=x.device).long()
|
||||
|
||||
|
||||
class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
|
||||
def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
|
||||
super().__init__(noise_schedule_config=noise_schedule_config)
|
||||
self.max_noise_level = max_noise_level
|
||||
|
||||
def forward(self, x, noise_level=None):
|
||||
if noise_level is None:
|
||||
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
|
||||
else:
|
||||
assert isinstance(noise_level, torch.Tensor)
|
||||
z = self.q_sample(x, noise_level)
|
||||
return z, noise_level
|
||||
|
||||
|
||||
|
||||
270
comfy/ldm/modules/diffusionmodules/util.py
Normal file
270
comfy/ldm/modules/diffusionmodules/util.py
Normal file
@@ -0,0 +1,270 @@
|
||||
# adopted from
|
||||
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
||||
# and
|
||||
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
||||
# and
|
||||
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
||||
#
|
||||
# thanks!
|
||||
|
||||
|
||||
import os
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import repeat
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
|
||||
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
if schedule == "linear":
|
||||
betas = (
|
||||
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
|
||||
)
|
||||
|
||||
elif schedule == "cosine":
|
||||
timesteps = (
|
||||
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
|
||||
)
|
||||
alphas = timesteps / (1 + cosine_s) * np.pi / 2
|
||||
alphas = torch.cos(alphas).pow(2)
|
||||
alphas = alphas / alphas[0]
|
||||
betas = 1 - alphas[1:] / alphas[:-1]
|
||||
betas = np.clip(betas, a_min=0, a_max=0.999)
|
||||
|
||||
elif schedule == "sqrt_linear":
|
||||
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
|
||||
elif schedule == "sqrt":
|
||||
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
|
||||
else:
|
||||
raise ValueError(f"schedule '{schedule}' unknown.")
|
||||
return betas.numpy()
|
||||
|
||||
|
||||
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
|
||||
if ddim_discr_method == 'uniform':
|
||||
c = num_ddpm_timesteps // num_ddim_timesteps
|
||||
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
|
||||
elif ddim_discr_method == 'quad':
|
||||
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
|
||||
else:
|
||||
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
|
||||
|
||||
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
|
||||
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
||||
steps_out = ddim_timesteps + 1
|
||||
if verbose:
|
||||
print(f'Selected timesteps for ddim sampler: {steps_out}')
|
||||
return steps_out
|
||||
|
||||
|
||||
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
|
||||
# select alphas for computing the variance schedule
|
||||
alphas = alphacums[ddim_timesteps]
|
||||
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
|
||||
|
||||
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
||||
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
||||
if verbose:
|
||||
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
||||
print(f'For the chosen value of eta, which is {eta}, '
|
||||
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
|
||||
return sigmas, alphas, alphas_prev
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function,
|
||||
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
||||
:param num_diffusion_timesteps: the number of betas to produce.
|
||||
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
||||
produces the cumulative product of (1-beta) up to that
|
||||
part of the diffusion process.
|
||||
:param max_beta: the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
"""
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return np.array(betas)
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
||||
|
||||
|
||||
def checkpoint(func, inputs, params, flag):
|
||||
"""
|
||||
Evaluate a function without caching intermediate activations, allowing for
|
||||
reduced memory at the expense of extra compute in the backward pass.
|
||||
:param func: the function to evaluate.
|
||||
:param inputs: the argument sequence to pass to `func`.
|
||||
:param params: a sequence of parameters `func` depends on but does not
|
||||
explicitly take as arguments.
|
||||
:param flag: if False, disable gradient checkpointing.
|
||||
"""
|
||||
if flag:
|
||||
args = tuple(inputs) + tuple(params)
|
||||
return CheckpointFunction.apply(func, len(inputs), *args)
|
||||
else:
|
||||
return func(*inputs)
|
||||
|
||||
|
||||
class CheckpointFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, run_function, length, *args):
|
||||
ctx.run_function = run_function
|
||||
ctx.input_tensors = list(args[:length])
|
||||
ctx.input_params = list(args[length:])
|
||||
ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
|
||||
"dtype": torch.get_autocast_gpu_dtype(),
|
||||
"cache_enabled": torch.is_autocast_cache_enabled()}
|
||||
with torch.no_grad():
|
||||
output_tensors = ctx.run_function(*ctx.input_tensors)
|
||||
return output_tensors
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *output_grads):
|
||||
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
||||
with torch.enable_grad(), \
|
||||
torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
|
||||
# Fixes a bug where the first op in run_function modifies the
|
||||
# Tensor storage in place, which is not allowed for detach()'d
|
||||
# Tensors.
|
||||
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
|
||||
output_tensors = ctx.run_function(*shallow_copies)
|
||||
input_grads = torch.autograd.grad(
|
||||
output_tensors,
|
||||
ctx.input_tensors + ctx.input_params,
|
||||
output_grads,
|
||||
allow_unused=True,
|
||||
)
|
||||
del ctx.input_tensors
|
||||
del ctx.input_params
|
||||
del output_tensors
|
||||
return (None, None) + input_grads
|
||||
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
if not repeat_only:
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
||||
).to(device=timesteps.device)
|
||||
args = timesteps[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
else:
|
||||
embedding = repeat(timesteps, 'b -> b d', d=dim)
|
||||
return embedding
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def scale_module(module, scale):
|
||||
"""
|
||||
Scale the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().mul_(scale)
|
||||
return module
|
||||
|
||||
|
||||
def mean_flat(tensor):
|
||||
"""
|
||||
Take the mean over all non-batch dimensions.
|
||||
"""
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
def normalization(channels):
|
||||
"""
|
||||
Make a standard normalization layer.
|
||||
:param channels: number of input channels.
|
||||
:return: an nn.Module for normalization.
|
||||
"""
|
||||
return GroupNorm32(32, channels)
|
||||
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||
class SiLU(nn.Module):
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.Conv1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def linear(*args, **kwargs):
|
||||
"""
|
||||
Create a linear module.
|
||||
"""
|
||||
return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
def avg_pool_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D average pooling module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.AvgPool1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.AvgPool2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.AvgPool3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
class HybridConditioner(nn.Module):
|
||||
|
||||
def __init__(self, c_concat_config, c_crossattn_config):
|
||||
super().__init__()
|
||||
self.concat_conditioner = instantiate_from_config(c_concat_config)
|
||||
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
|
||||
|
||||
def forward(self, c_concat, c_crossattn):
|
||||
c_concat = self.concat_conditioner(c_concat)
|
||||
c_crossattn = self.crossattn_conditioner(c_crossattn)
|
||||
return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
|
||||
|
||||
|
||||
def noise_like(shape, device, repeat=False):
|
||||
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
||||
noise = lambda: torch.randn(shape, device=device)
|
||||
return repeat_noise() if repeat else noise()
|
||||
0
comfy/ldm/modules/distributions/__init__.py
Normal file
0
comfy/ldm/modules/distributions/__init__.py
Normal file
92
comfy/ldm/modules/distributions/distributions.py
Normal file
92
comfy/ldm/modules/distributions/distributions.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AbstractDistribution:
|
||||
def sample(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def mode(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DiracDistribution(AbstractDistribution):
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def sample(self):
|
||||
return self.value
|
||||
|
||||
def mode(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
|
||||
def __init__(self, parameters, deterministic=False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
||||
|
||||
def sample(self):
|
||||
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
|
||||
return x
|
||||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
else:
|
||||
if other is None:
|
||||
return 0.5 * torch.sum(torch.pow(self.mean, 2)
|
||||
+ self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3])
|
||||
else:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||
dim=[1, 2, 3])
|
||||
|
||||
def nll(self, sample, dims=[1,2,3]):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(
|
||||
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
||||
dim=dims)
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
"""
|
||||
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
|
||||
Compute the KL divergence between two gaussians.
|
||||
Shapes are automatically broadcasted, so batches can be compared to
|
||||
scalars, among other use cases.
|
||||
"""
|
||||
tensor = None
|
||||
for obj in (mean1, logvar1, mean2, logvar2):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
tensor = obj
|
||||
break
|
||||
assert tensor is not None, "at least one argument must be a Tensor"
|
||||
|
||||
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
||||
# Tensors, but it does not work for torch.exp().
|
||||
logvar1, logvar2 = [
|
||||
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
|
||||
for x in (logvar1, logvar2)
|
||||
]
|
||||
|
||||
return 0.5 * (
|
||||
-1.0
|
||||
+ logvar2
|
||||
- logvar1
|
||||
+ torch.exp(logvar1 - logvar2)
|
||||
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
|
||||
)
|
||||
80
comfy/ldm/modules/ema.py
Normal file
80
comfy/ldm/modules/ema.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class LitEma(nn.Module):
|
||||
def __init__(self, model, decay=0.9999, use_num_upates=True):
|
||||
super().__init__()
|
||||
if decay < 0.0 or decay > 1.0:
|
||||
raise ValueError('Decay must be between 0 and 1')
|
||||
|
||||
self.m_name2s_name = {}
|
||||
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
|
||||
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
|
||||
else torch.tensor(-1, dtype=torch.int))
|
||||
|
||||
for name, p in model.named_parameters():
|
||||
if p.requires_grad:
|
||||
# remove as '.'-character is not allowed in buffers
|
||||
s_name = name.replace('.', '')
|
||||
self.m_name2s_name.update({name: s_name})
|
||||
self.register_buffer(s_name, p.clone().detach().data)
|
||||
|
||||
self.collected_params = []
|
||||
|
||||
def reset_num_updates(self):
|
||||
del self.num_updates
|
||||
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
|
||||
|
||||
def forward(self, model):
|
||||
decay = self.decay
|
||||
|
||||
if self.num_updates >= 0:
|
||||
self.num_updates += 1
|
||||
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
|
||||
|
||||
one_minus_decay = 1.0 - decay
|
||||
|
||||
with torch.no_grad():
|
||||
m_param = dict(model.named_parameters())
|
||||
shadow_params = dict(self.named_buffers())
|
||||
|
||||
for key in m_param:
|
||||
if m_param[key].requires_grad:
|
||||
sname = self.m_name2s_name[key]
|
||||
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
|
||||
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
|
||||
else:
|
||||
assert not key in self.m_name2s_name
|
||||
|
||||
def copy_to(self, model):
|
||||
m_param = dict(model.named_parameters())
|
||||
shadow_params = dict(self.named_buffers())
|
||||
for key in m_param:
|
||||
if m_param[key].requires_grad:
|
||||
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
|
||||
else:
|
||||
assert not key in self.m_name2s_name
|
||||
|
||||
def store(self, parameters):
|
||||
"""
|
||||
Save the current parameters for restoring later.
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
||||
temporarily stored.
|
||||
"""
|
||||
self.collected_params = [param.clone() for param in parameters]
|
||||
|
||||
def restore(self, parameters):
|
||||
"""
|
||||
Restore the parameters stored with the `store` method.
|
||||
Useful to validate the model with EMA parameters without affecting the
|
||||
original optimization process. Store the parameters before the
|
||||
`copy_to` method. After validation (or model saving), use this to
|
||||
restore the former parameters.
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
||||
updated with the stored parameters.
|
||||
"""
|
||||
for c_param, param in zip(self.collected_params, parameters):
|
||||
param.data.copy_(c_param.data)
|
||||
0
comfy/ldm/modules/encoders/__init__.py
Normal file
0
comfy/ldm/modules/encoders/__init__.py
Normal file
213
comfy/ldm/modules/encoders/modules.py
Normal file
213
comfy/ldm/modules/encoders/modules.py
Normal file
@@ -0,0 +1,213 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
|
||||
|
||||
import open_clip
|
||||
from ldm.util import default, count_params
|
||||
|
||||
|
||||
class AbstractEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def encode(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class IdentityEncoder(AbstractEncoder):
|
||||
|
||||
def encode(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class ClassEmbedder(nn.Module):
|
||||
def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
|
||||
super().__init__()
|
||||
self.key = key
|
||||
self.embedding = nn.Embedding(n_classes, embed_dim)
|
||||
self.n_classes = n_classes
|
||||
self.ucg_rate = ucg_rate
|
||||
|
||||
def forward(self, batch, key=None, disable_dropout=False):
|
||||
if key is None:
|
||||
key = self.key
|
||||
# this is for use in crossattn
|
||||
c = batch[key][:, None]
|
||||
if self.ucg_rate > 0. and not disable_dropout:
|
||||
mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
|
||||
c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
|
||||
c = c.long()
|
||||
c = self.embedding(c)
|
||||
return c
|
||||
|
||||
def get_unconditional_conditioning(self, bs, device="cuda"):
|
||||
uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
|
||||
uc = torch.ones((bs,), device=device) * uc_class
|
||||
uc = {self.key: uc}
|
||||
return uc
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
does not change anymore."""
|
||||
return self
|
||||
|
||||
|
||||
class FrozenT5Embedder(AbstractEncoder):
|
||||
"""Uses the T5 transformer encoder for text"""
|
||||
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
||||
super().__init__()
|
||||
self.tokenizer = T5Tokenizer.from_pretrained(version)
|
||||
self.transformer = T5EncoderModel.from_pretrained(version)
|
||||
self.device = device
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
if freeze:
|
||||
self.freeze()
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
#self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
outputs = self.transformer(input_ids=tokens)
|
||||
|
||||
z = outputs.last_hidden_state
|
||||
return z
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
LAYERS = [
|
||||
"last",
|
||||
"pooled",
|
||||
"hidden"
|
||||
]
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
|
||||
freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
self.transformer = CLIPTextModel.from_pretrained(version)
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.layer = layer
|
||||
self.layer_idx = layer_idx
|
||||
if layer == "hidden":
|
||||
assert layer_idx is not None
|
||||
assert 0 <= abs(layer_idx) <= 12
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
#self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
|
||||
if self.layer == "last":
|
||||
z = outputs.last_hidden_state
|
||||
elif self.layer == "pooled":
|
||||
z = outputs.pooler_output[:, None, :]
|
||||
else:
|
||||
z = outputs.hidden_states[self.layer_idx]
|
||||
return z
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenOpenCLIPEmbedder(AbstractEncoder):
|
||||
"""
|
||||
Uses the OpenCLIP transformer encoder for text
|
||||
"""
|
||||
LAYERS = [
|
||||
#"pooled",
|
||||
"last",
|
||||
"penultimate"
|
||||
]
|
||||
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
|
||||
freeze=True, layer="last"):
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
|
||||
del model.visual
|
||||
self.model = model
|
||||
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.layer = layer
|
||||
if self.layer == "last":
|
||||
self.layer_idx = 0
|
||||
elif self.layer == "penultimate":
|
||||
self.layer_idx = 1
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def freeze(self):
|
||||
self.model = self.model.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
tokens = open_clip.tokenize(text)
|
||||
z = self.encode_with_transformer(tokens.to(self.device))
|
||||
return z
|
||||
|
||||
def encode_with_transformer(self, text):
|
||||
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
|
||||
x = x + self.model.positional_embedding
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
x = self.model.ln_final(x)
|
||||
return x
|
||||
|
||||
def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
|
||||
for i, r in enumerate(self.model.transformer.resblocks):
|
||||
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
||||
break
|
||||
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint(r, x, attn_mask)
|
||||
else:
|
||||
x = r(x, attn_mask=attn_mask)
|
||||
return x
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenCLIPT5Encoder(AbstractEncoder):
|
||||
def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
|
||||
clip_max_length=77, t5_max_length=77):
|
||||
super().__init__()
|
||||
self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
|
||||
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
|
||||
print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
|
||||
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
def forward(self, text):
|
||||
clip_z = self.clip_encoder.encode(text)
|
||||
t5_z = self.t5_encoder.encode(text)
|
||||
return [clip_z, t5_z]
|
||||
|
||||
|
||||
2
comfy/ldm/modules/image_degradation/__init__.py
Normal file
2
comfy/ldm/modules/image_degradation/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
|
||||
from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
|
||||
730
comfy/ldm/modules/image_degradation/bsrgan.py
Normal file
730
comfy/ldm/modules/image_degradation/bsrgan.py
Normal file
@@ -0,0 +1,730 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# Super-Resolution
|
||||
# --------------------------------------------
|
||||
#
|
||||
# Kai Zhang (cskaizhang@gmail.com)
|
||||
# https://github.com/cszn
|
||||
# From 2019/03--2021/08
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from functools import partial
|
||||
import random
|
||||
from scipy import ndimage
|
||||
import scipy
|
||||
import scipy.stats as ss
|
||||
from scipy.interpolate import interp2d
|
||||
from scipy.linalg import orth
|
||||
import albumentations
|
||||
|
||||
import ldm.modules.image_degradation.utils_image as util
|
||||
|
||||
|
||||
def modcrop_np(img, sf):
|
||||
'''
|
||||
Args:
|
||||
img: numpy image, WxH or WxHxC
|
||||
sf: scale factor
|
||||
Return:
|
||||
cropped image
|
||||
'''
|
||||
w, h = img.shape[:2]
|
||||
im = np.copy(img)
|
||||
return im[:w - w % sf, :h - h % sf, ...]
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# anisotropic Gaussian kernels
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def analytic_kernel(k):
|
||||
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
|
||||
k_size = k.shape[0]
|
||||
# Calculate the big kernels size
|
||||
big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
|
||||
# Loop over the small kernel to fill the big one
|
||||
for r in range(k_size):
|
||||
for c in range(k_size):
|
||||
big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
|
||||
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
|
||||
crop = k_size // 2
|
||||
cropped_big_k = big_k[crop:-crop, crop:-crop]
|
||||
# Normalize to 1
|
||||
return cropped_big_k / cropped_big_k.sum()
|
||||
|
||||
|
||||
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
|
||||
""" generate an anisotropic Gaussian kernel
|
||||
Args:
|
||||
ksize : e.g., 15, kernel size
|
||||
theta : [0, pi], rotation angle range
|
||||
l1 : [0.1,50], scaling of eigenvalues
|
||||
l2 : [0.1,l1], scaling of eigenvalues
|
||||
If l1 = l2, will get an isotropic Gaussian kernel.
|
||||
Returns:
|
||||
k : kernel
|
||||
"""
|
||||
|
||||
v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
|
||||
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
|
||||
D = np.array([[l1, 0], [0, l2]])
|
||||
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
|
||||
k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
|
||||
|
||||
return k
|
||||
|
||||
|
||||
def gm_blur_kernel(mean, cov, size=15):
|
||||
center = size / 2.0 + 0.5
|
||||
k = np.zeros([size, size])
|
||||
for y in range(size):
|
||||
for x in range(size):
|
||||
cy = y - center + 1
|
||||
cx = x - center + 1
|
||||
k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
|
||||
|
||||
k = k / np.sum(k)
|
||||
return k
|
||||
|
||||
|
||||
def shift_pixel(x, sf, upper_left=True):
|
||||
"""shift pixel for super-resolution with different scale factors
|
||||
Args:
|
||||
x: WxHxC or WxH
|
||||
sf: scale factor
|
||||
upper_left: shift direction
|
||||
"""
|
||||
h, w = x.shape[:2]
|
||||
shift = (sf - 1) * 0.5
|
||||
xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
|
||||
if upper_left:
|
||||
x1 = xv + shift
|
||||
y1 = yv + shift
|
||||
else:
|
||||
x1 = xv - shift
|
||||
y1 = yv - shift
|
||||
|
||||
x1 = np.clip(x1, 0, w - 1)
|
||||
y1 = np.clip(y1, 0, h - 1)
|
||||
|
||||
if x.ndim == 2:
|
||||
x = interp2d(xv, yv, x)(x1, y1)
|
||||
if x.ndim == 3:
|
||||
for i in range(x.shape[-1]):
|
||||
x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def blur(x, k):
|
||||
'''
|
||||
x: image, NxcxHxW
|
||||
k: kernel, Nx1xhxw
|
||||
'''
|
||||
n, c = x.shape[:2]
|
||||
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
|
||||
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
|
||||
k = k.repeat(1, c, 1, 1)
|
||||
k = k.view(-1, 1, k.shape[2], k.shape[3])
|
||||
x = x.view(1, -1, x.shape[2], x.shape[3])
|
||||
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
|
||||
x = x.view(n, c, x.shape[2], x.shape[3])
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
|
||||
""""
|
||||
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
|
||||
# Kai Zhang
|
||||
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
|
||||
# max_var = 2.5 * sf
|
||||
"""
|
||||
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
|
||||
lambda_1 = min_var + np.random.rand() * (max_var - min_var)
|
||||
lambda_2 = min_var + np.random.rand() * (max_var - min_var)
|
||||
theta = np.random.rand() * np.pi # random theta
|
||||
noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
|
||||
|
||||
# Set COV matrix using Lambdas and Theta
|
||||
LAMBDA = np.diag([lambda_1, lambda_2])
|
||||
Q = np.array([[np.cos(theta), -np.sin(theta)],
|
||||
[np.sin(theta), np.cos(theta)]])
|
||||
SIGMA = Q @ LAMBDA @ Q.T
|
||||
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
|
||||
|
||||
# Set expectation position (shifting kernel for aligned image)
|
||||
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
|
||||
MU = MU[None, None, :, None]
|
||||
|
||||
# Create meshgrid for Gaussian
|
||||
[X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
|
||||
Z = np.stack([X, Y], 2)[:, :, :, None]
|
||||
|
||||
# Calcualte Gaussian for every pixel of the kernel
|
||||
ZZ = Z - MU
|
||||
ZZ_t = ZZ.transpose(0, 1, 3, 2)
|
||||
raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
|
||||
|
||||
# shift the kernel so it will be centered
|
||||
# raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
|
||||
|
||||
# Normalize the kernel and return
|
||||
# kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
|
||||
kernel = raw_kernel / np.sum(raw_kernel)
|
||||
return kernel
|
||||
|
||||
|
||||
def fspecial_gaussian(hsize, sigma):
|
||||
hsize = [hsize, hsize]
|
||||
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
|
||||
std = sigma
|
||||
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
|
||||
arg = -(x * x + y * y) / (2 * std * std)
|
||||
h = np.exp(arg)
|
||||
h[h < scipy.finfo(float).eps * h.max()] = 0
|
||||
sumh = h.sum()
|
||||
if sumh != 0:
|
||||
h = h / sumh
|
||||
return h
|
||||
|
||||
|
||||
def fspecial_laplacian(alpha):
|
||||
alpha = max([0, min([alpha, 1])])
|
||||
h1 = alpha / (alpha + 1)
|
||||
h2 = (1 - alpha) / (alpha + 1)
|
||||
h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
|
||||
h = np.array(h)
|
||||
return h
|
||||
|
||||
|
||||
def fspecial(filter_type, *args, **kwargs):
|
||||
'''
|
||||
python code from:
|
||||
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
|
||||
'''
|
||||
if filter_type == 'gaussian':
|
||||
return fspecial_gaussian(*args, **kwargs)
|
||||
if filter_type == 'laplacian':
|
||||
return fspecial_laplacian(*args, **kwargs)
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# degradation models
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def bicubic_degradation(x, sf=3):
|
||||
'''
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
bicubicly downsampled LR image
|
||||
'''
|
||||
x = util.imresize_np(x, scale=1 / sf)
|
||||
return x
|
||||
|
||||
|
||||
def srmd_degradation(x, k, sf=3):
|
||||
''' blur + bicubic downsampling
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
Reference:
|
||||
@inproceedings{zhang2018learning,
|
||||
title={Learning a single convolutional super-resolution network for multiple degradations},
|
||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
||||
pages={3262--3271},
|
||||
year={2018}
|
||||
}
|
||||
'''
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
|
||||
x = bicubic_degradation(x, sf=sf)
|
||||
return x
|
||||
|
||||
|
||||
def dpsr_degradation(x, k, sf=3):
|
||||
''' bicubic downsampling + blur
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
Reference:
|
||||
@inproceedings{zhang2019deep,
|
||||
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
|
||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
||||
pages={1671--1681},
|
||||
year={2019}
|
||||
}
|
||||
'''
|
||||
x = bicubic_degradation(x, sf=sf)
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
||||
return x
|
||||
|
||||
|
||||
def classical_degradation(x, k, sf=3):
|
||||
''' blur + downsampling
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]/[0, 255]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
'''
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
||||
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
|
||||
st = 0
|
||||
return x[st::sf, st::sf, ...]
|
||||
|
||||
|
||||
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
|
||||
"""USM sharpening. borrowed from real-ESRGAN
|
||||
Input image: I; Blurry image: B.
|
||||
1. K = I + weight * (I - B)
|
||||
2. Mask = 1 if abs(I - B) > threshold, else: 0
|
||||
3. Blur mask:
|
||||
4. Out = Mask * K + (1 - Mask) * I
|
||||
Args:
|
||||
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
|
||||
weight (float): Sharp weight. Default: 1.
|
||||
radius (float): Kernel size of Gaussian blur. Default: 50.
|
||||
threshold (int):
|
||||
"""
|
||||
if radius % 2 == 0:
|
||||
radius += 1
|
||||
blur = cv2.GaussianBlur(img, (radius, radius), 0)
|
||||
residual = img - blur
|
||||
mask = np.abs(residual) * 255 > threshold
|
||||
mask = mask.astype('float32')
|
||||
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
|
||||
|
||||
K = img + weight * residual
|
||||
K = np.clip(K, 0, 1)
|
||||
return soft_mask * K + (1 - soft_mask) * img
|
||||
|
||||
|
||||
def add_blur(img, sf=4):
|
||||
wd2 = 4.0 + sf
|
||||
wd = 2.0 + 0.2 * sf
|
||||
if random.random() < 0.5:
|
||||
l1 = wd2 * random.random()
|
||||
l2 = wd2 * random.random()
|
||||
k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
|
||||
else:
|
||||
k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
|
||||
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def add_resize(img, sf=4):
|
||||
rnum = np.random.rand()
|
||||
if rnum > 0.8: # up
|
||||
sf1 = random.uniform(1, 2)
|
||||
elif rnum < 0.7: # down
|
||||
sf1 = random.uniform(0.5 / sf, 1)
|
||||
else:
|
||||
sf1 = 1.0
|
||||
img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
||||
# noise_level = random.randint(noise_level1, noise_level2)
|
||||
# rnum = np.random.rand()
|
||||
# if rnum > 0.6: # add color Gaussian noise
|
||||
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
# elif rnum < 0.4: # add grayscale Gaussian noise
|
||||
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
# else: # add noise
|
||||
# L = noise_level2 / 255.
|
||||
# D = np.diag(np.random.rand(3))
|
||||
# U = orth(np.random.rand(3, 3))
|
||||
# conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
||||
# img = np.clip(img, 0.0, 1.0)
|
||||
# return img
|
||||
|
||||
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
||||
noise_level = random.randint(noise_level1, noise_level2)
|
||||
rnum = np.random.rand()
|
||||
if rnum > 0.6: # add color Gaussian noise
|
||||
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
elif rnum < 0.4: # add grayscale Gaussian noise
|
||||
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
else: # add noise
|
||||
L = noise_level2 / 255.
|
||||
D = np.diag(np.random.rand(3))
|
||||
U = orth(np.random.rand(3, 3))
|
||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
|
||||
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
|
||||
noise_level = random.randint(noise_level1, noise_level2)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
rnum = random.random()
|
||||
if rnum > 0.6:
|
||||
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
elif rnum < 0.4:
|
||||
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
else:
|
||||
L = noise_level2 / 255.
|
||||
D = np.diag(np.random.rand(3))
|
||||
U = orth(np.random.rand(3, 3))
|
||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
|
||||
def add_Poisson_noise(img):
|
||||
img = np.clip((img * 255.0).round(), 0, 255) / 255.
|
||||
vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
|
||||
if random.random() < 0.5:
|
||||
img = np.random.poisson(img * vals).astype(np.float32) / vals
|
||||
else:
|
||||
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
|
||||
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
|
||||
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
|
||||
img += noise_gray[:, :, np.newaxis]
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
|
||||
def add_JPEG_noise(img):
|
||||
quality_factor = random.randint(30, 95)
|
||||
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
|
||||
result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
|
||||
img = cv2.imdecode(encimg, 1)
|
||||
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
|
||||
|
||||
def random_crop(lq, hq, sf=4, lq_patchsize=64):
|
||||
h, w = lq.shape[:2]
|
||||
rnd_h = random.randint(0, h - lq_patchsize)
|
||||
rnd_w = random.randint(0, w - lq_patchsize)
|
||||
lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
|
||||
|
||||
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
|
||||
hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
|
||||
return lq, hq
|
||||
|
||||
|
||||
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
|
||||
"""
|
||||
This is the degradation model of BSRGAN from the paper
|
||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
||||
----------
|
||||
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
|
||||
sf: scale factor
|
||||
isp_model: camera ISP model
|
||||
Returns
|
||||
-------
|
||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
||||
"""
|
||||
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
|
||||
sf_ori = sf
|
||||
|
||||
h1, w1 = img.shape[:2]
|
||||
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
|
||||
h, w = img.shape[:2]
|
||||
|
||||
if h < lq_patchsize * sf or w < lq_patchsize * sf:
|
||||
raise ValueError(f'img size ({h1}X{w1}) is too small!')
|
||||
|
||||
hq = img.copy()
|
||||
|
||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
||||
if np.random.rand() < 0.5:
|
||||
img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
else:
|
||||
img = util.imresize_np(img, 1 / 2, True)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
sf = 2
|
||||
|
||||
shuffle_order = random.sample(range(7), 7)
|
||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
||||
if idx1 > idx2: # keep downsample3 last
|
||||
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
|
||||
|
||||
for i in shuffle_order:
|
||||
|
||||
if i == 0:
|
||||
img = add_blur(img, sf=sf)
|
||||
|
||||
elif i == 1:
|
||||
img = add_blur(img, sf=sf)
|
||||
|
||||
elif i == 2:
|
||||
a, b = img.shape[1], img.shape[0]
|
||||
# downsample2
|
||||
if random.random() < 0.75:
|
||||
sf1 = random.uniform(1, 2 * sf)
|
||||
img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
else:
|
||||
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
|
||||
k_shifted = shift_pixel(k, sf)
|
||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
||||
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
|
||||
img = img[0::sf, 0::sf, ...] # nearest downsampling
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
elif i == 3:
|
||||
# downsample3
|
||||
img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
elif i == 4:
|
||||
# add Gaussian noise
|
||||
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
|
||||
|
||||
elif i == 5:
|
||||
# add JPEG noise
|
||||
if random.random() < jpeg_prob:
|
||||
img = add_JPEG_noise(img)
|
||||
|
||||
elif i == 6:
|
||||
# add processed camera sensor noise
|
||||
if random.random() < isp_prob and isp_model is not None:
|
||||
with torch.no_grad():
|
||||
img, hq = isp_model.forward(img.copy(), hq)
|
||||
|
||||
# add final JPEG compression noise
|
||||
img = add_JPEG_noise(img)
|
||||
|
||||
# random crop
|
||||
img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
|
||||
|
||||
return img, hq
|
||||
|
||||
|
||||
# todo no isp_model?
|
||||
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
|
||||
"""
|
||||
This is the degradation model of BSRGAN from the paper
|
||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
||||
----------
|
||||
sf: scale factor
|
||||
isp_model: camera ISP model
|
||||
Returns
|
||||
-------
|
||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
||||
"""
|
||||
image = util.uint2single(image)
|
||||
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
|
||||
sf_ori = sf
|
||||
|
||||
h1, w1 = image.shape[:2]
|
||||
image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
|
||||
h, w = image.shape[:2]
|
||||
|
||||
hq = image.copy()
|
||||
|
||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
||||
if np.random.rand() < 0.5:
|
||||
image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
else:
|
||||
image = util.imresize_np(image, 1 / 2, True)
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
sf = 2
|
||||
|
||||
shuffle_order = random.sample(range(7), 7)
|
||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
||||
if idx1 > idx2: # keep downsample3 last
|
||||
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
|
||||
|
||||
for i in shuffle_order:
|
||||
|
||||
if i == 0:
|
||||
image = add_blur(image, sf=sf)
|
||||
|
||||
elif i == 1:
|
||||
image = add_blur(image, sf=sf)
|
||||
|
||||
elif i == 2:
|
||||
a, b = image.shape[1], image.shape[0]
|
||||
# downsample2
|
||||
if random.random() < 0.75:
|
||||
sf1 = random.uniform(1, 2 * sf)
|
||||
image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
else:
|
||||
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
|
||||
k_shifted = shift_pixel(k, sf)
|
||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
||||
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
|
||||
image = image[0::sf, 0::sf, ...] # nearest downsampling
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
|
||||
elif i == 3:
|
||||
# downsample3
|
||||
image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
|
||||
elif i == 4:
|
||||
# add Gaussian noise
|
||||
image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
|
||||
|
||||
elif i == 5:
|
||||
# add JPEG noise
|
||||
if random.random() < jpeg_prob:
|
||||
image = add_JPEG_noise(image)
|
||||
|
||||
# elif i == 6:
|
||||
# # add processed camera sensor noise
|
||||
# if random.random() < isp_prob and isp_model is not None:
|
||||
# with torch.no_grad():
|
||||
# img, hq = isp_model.forward(img.copy(), hq)
|
||||
|
||||
# add final JPEG compression noise
|
||||
image = add_JPEG_noise(image)
|
||||
image = util.single2uint(image)
|
||||
example = {"image":image}
|
||||
return example
|
||||
|
||||
|
||||
# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
|
||||
def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
|
||||
"""
|
||||
This is an extended degradation model by combining
|
||||
the degradation models of BSRGAN and Real-ESRGAN
|
||||
----------
|
||||
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
|
||||
sf: scale factor
|
||||
use_shuffle: the degradation shuffle
|
||||
use_sharp: sharpening the img
|
||||
Returns
|
||||
-------
|
||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
||||
"""
|
||||
|
||||
h1, w1 = img.shape[:2]
|
||||
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
|
||||
h, w = img.shape[:2]
|
||||
|
||||
if h < lq_patchsize * sf or w < lq_patchsize * sf:
|
||||
raise ValueError(f'img size ({h1}X{w1}) is too small!')
|
||||
|
||||
if use_sharp:
|
||||
img = add_sharpening(img)
|
||||
hq = img.copy()
|
||||
|
||||
if random.random() < shuffle_prob:
|
||||
shuffle_order = random.sample(range(13), 13)
|
||||
else:
|
||||
shuffle_order = list(range(13))
|
||||
# local shuffle for noise, JPEG is always the last one
|
||||
shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
|
||||
shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
|
||||
|
||||
poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
|
||||
|
||||
for i in shuffle_order:
|
||||
if i == 0:
|
||||
img = add_blur(img, sf=sf)
|
||||
elif i == 1:
|
||||
img = add_resize(img, sf=sf)
|
||||
elif i == 2:
|
||||
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
|
||||
elif i == 3:
|
||||
if random.random() < poisson_prob:
|
||||
img = add_Poisson_noise(img)
|
||||
elif i == 4:
|
||||
if random.random() < speckle_prob:
|
||||
img = add_speckle_noise(img)
|
||||
elif i == 5:
|
||||
if random.random() < isp_prob and isp_model is not None:
|
||||
with torch.no_grad():
|
||||
img, hq = isp_model.forward(img.copy(), hq)
|
||||
elif i == 6:
|
||||
img = add_JPEG_noise(img)
|
||||
elif i == 7:
|
||||
img = add_blur(img, sf=sf)
|
||||
elif i == 8:
|
||||
img = add_resize(img, sf=sf)
|
||||
elif i == 9:
|
||||
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
|
||||
elif i == 10:
|
||||
if random.random() < poisson_prob:
|
||||
img = add_Poisson_noise(img)
|
||||
elif i == 11:
|
||||
if random.random() < speckle_prob:
|
||||
img = add_speckle_noise(img)
|
||||
elif i == 12:
|
||||
if random.random() < isp_prob and isp_model is not None:
|
||||
with torch.no_grad():
|
||||
img, hq = isp_model.forward(img.copy(), hq)
|
||||
else:
|
||||
print('check the shuffle!')
|
||||
|
||||
# resize to desired size
|
||||
img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
|
||||
# add final JPEG compression noise
|
||||
img = add_JPEG_noise(img)
|
||||
|
||||
# random crop
|
||||
img, hq = random_crop(img, hq, sf, lq_patchsize)
|
||||
|
||||
return img, hq
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("hey")
|
||||
img = util.imread_uint('utils/test.png', 3)
|
||||
print(img)
|
||||
img = util.uint2single(img)
|
||||
print(img)
|
||||
img = img[:448, :448]
|
||||
h = img.shape[0] // 4
|
||||
print("resizing to", h)
|
||||
sf = 4
|
||||
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
|
||||
for i in range(20):
|
||||
print(i)
|
||||
img_lq = deg_fn(img)
|
||||
print(img_lq)
|
||||
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
|
||||
print(img_lq.shape)
|
||||
print("bicubic", img_lq_bicubic.shape)
|
||||
print(img_hq.shape)
|
||||
lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
||||
interpolation=0)
|
||||
lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
||||
interpolation=0)
|
||||
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
|
||||
util.imsave(img_concat, str(i) + '.png')
|
||||
|
||||
|
||||
651
comfy/ldm/modules/image_degradation/bsrgan_light.py
Normal file
651
comfy/ldm/modules/image_degradation/bsrgan_light.py
Normal file
@@ -0,0 +1,651 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from functools import partial
|
||||
import random
|
||||
from scipy import ndimage
|
||||
import scipy
|
||||
import scipy.stats as ss
|
||||
from scipy.interpolate import interp2d
|
||||
from scipy.linalg import orth
|
||||
import albumentations
|
||||
|
||||
import ldm.modules.image_degradation.utils_image as util
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# Super-Resolution
|
||||
# --------------------------------------------
|
||||
#
|
||||
# Kai Zhang (cskaizhang@gmail.com)
|
||||
# https://github.com/cszn
|
||||
# From 2019/03--2021/08
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
def modcrop_np(img, sf):
|
||||
'''
|
||||
Args:
|
||||
img: numpy image, WxH or WxHxC
|
||||
sf: scale factor
|
||||
Return:
|
||||
cropped image
|
||||
'''
|
||||
w, h = img.shape[:2]
|
||||
im = np.copy(img)
|
||||
return im[:w - w % sf, :h - h % sf, ...]
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# anisotropic Gaussian kernels
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def analytic_kernel(k):
|
||||
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
|
||||
k_size = k.shape[0]
|
||||
# Calculate the big kernels size
|
||||
big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
|
||||
# Loop over the small kernel to fill the big one
|
||||
for r in range(k_size):
|
||||
for c in range(k_size):
|
||||
big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
|
||||
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
|
||||
crop = k_size // 2
|
||||
cropped_big_k = big_k[crop:-crop, crop:-crop]
|
||||
# Normalize to 1
|
||||
return cropped_big_k / cropped_big_k.sum()
|
||||
|
||||
|
||||
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
|
||||
""" generate an anisotropic Gaussian kernel
|
||||
Args:
|
||||
ksize : e.g., 15, kernel size
|
||||
theta : [0, pi], rotation angle range
|
||||
l1 : [0.1,50], scaling of eigenvalues
|
||||
l2 : [0.1,l1], scaling of eigenvalues
|
||||
If l1 = l2, will get an isotropic Gaussian kernel.
|
||||
Returns:
|
||||
k : kernel
|
||||
"""
|
||||
|
||||
v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
|
||||
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
|
||||
D = np.array([[l1, 0], [0, l2]])
|
||||
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
|
||||
k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
|
||||
|
||||
return k
|
||||
|
||||
|
||||
def gm_blur_kernel(mean, cov, size=15):
|
||||
center = size / 2.0 + 0.5
|
||||
k = np.zeros([size, size])
|
||||
for y in range(size):
|
||||
for x in range(size):
|
||||
cy = y - center + 1
|
||||
cx = x - center + 1
|
||||
k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
|
||||
|
||||
k = k / np.sum(k)
|
||||
return k
|
||||
|
||||
|
||||
def shift_pixel(x, sf, upper_left=True):
|
||||
"""shift pixel for super-resolution with different scale factors
|
||||
Args:
|
||||
x: WxHxC or WxH
|
||||
sf: scale factor
|
||||
upper_left: shift direction
|
||||
"""
|
||||
h, w = x.shape[:2]
|
||||
shift = (sf - 1) * 0.5
|
||||
xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
|
||||
if upper_left:
|
||||
x1 = xv + shift
|
||||
y1 = yv + shift
|
||||
else:
|
||||
x1 = xv - shift
|
||||
y1 = yv - shift
|
||||
|
||||
x1 = np.clip(x1, 0, w - 1)
|
||||
y1 = np.clip(y1, 0, h - 1)
|
||||
|
||||
if x.ndim == 2:
|
||||
x = interp2d(xv, yv, x)(x1, y1)
|
||||
if x.ndim == 3:
|
||||
for i in range(x.shape[-1]):
|
||||
x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def blur(x, k):
|
||||
'''
|
||||
x: image, NxcxHxW
|
||||
k: kernel, Nx1xhxw
|
||||
'''
|
||||
n, c = x.shape[:2]
|
||||
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
|
||||
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
|
||||
k = k.repeat(1, c, 1, 1)
|
||||
k = k.view(-1, 1, k.shape[2], k.shape[3])
|
||||
x = x.view(1, -1, x.shape[2], x.shape[3])
|
||||
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
|
||||
x = x.view(n, c, x.shape[2], x.shape[3])
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
|
||||
""""
|
||||
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
|
||||
# Kai Zhang
|
||||
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
|
||||
# max_var = 2.5 * sf
|
||||
"""
|
||||
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
|
||||
lambda_1 = min_var + np.random.rand() * (max_var - min_var)
|
||||
lambda_2 = min_var + np.random.rand() * (max_var - min_var)
|
||||
theta = np.random.rand() * np.pi # random theta
|
||||
noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
|
||||
|
||||
# Set COV matrix using Lambdas and Theta
|
||||
LAMBDA = np.diag([lambda_1, lambda_2])
|
||||
Q = np.array([[np.cos(theta), -np.sin(theta)],
|
||||
[np.sin(theta), np.cos(theta)]])
|
||||
SIGMA = Q @ LAMBDA @ Q.T
|
||||
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
|
||||
|
||||
# Set expectation position (shifting kernel for aligned image)
|
||||
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
|
||||
MU = MU[None, None, :, None]
|
||||
|
||||
# Create meshgrid for Gaussian
|
||||
[X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
|
||||
Z = np.stack([X, Y], 2)[:, :, :, None]
|
||||
|
||||
# Calcualte Gaussian for every pixel of the kernel
|
||||
ZZ = Z - MU
|
||||
ZZ_t = ZZ.transpose(0, 1, 3, 2)
|
||||
raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
|
||||
|
||||
# shift the kernel so it will be centered
|
||||
# raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
|
||||
|
||||
# Normalize the kernel and return
|
||||
# kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
|
||||
kernel = raw_kernel / np.sum(raw_kernel)
|
||||
return kernel
|
||||
|
||||
|
||||
def fspecial_gaussian(hsize, sigma):
|
||||
hsize = [hsize, hsize]
|
||||
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
|
||||
std = sigma
|
||||
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
|
||||
arg = -(x * x + y * y) / (2 * std * std)
|
||||
h = np.exp(arg)
|
||||
h[h < scipy.finfo(float).eps * h.max()] = 0
|
||||
sumh = h.sum()
|
||||
if sumh != 0:
|
||||
h = h / sumh
|
||||
return h
|
||||
|
||||
|
||||
def fspecial_laplacian(alpha):
|
||||
alpha = max([0, min([alpha, 1])])
|
||||
h1 = alpha / (alpha + 1)
|
||||
h2 = (1 - alpha) / (alpha + 1)
|
||||
h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
|
||||
h = np.array(h)
|
||||
return h
|
||||
|
||||
|
||||
def fspecial(filter_type, *args, **kwargs):
|
||||
'''
|
||||
python code from:
|
||||
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
|
||||
'''
|
||||
if filter_type == 'gaussian':
|
||||
return fspecial_gaussian(*args, **kwargs)
|
||||
if filter_type == 'laplacian':
|
||||
return fspecial_laplacian(*args, **kwargs)
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# degradation models
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def bicubic_degradation(x, sf=3):
|
||||
'''
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
bicubicly downsampled LR image
|
||||
'''
|
||||
x = util.imresize_np(x, scale=1 / sf)
|
||||
return x
|
||||
|
||||
|
||||
def srmd_degradation(x, k, sf=3):
|
||||
''' blur + bicubic downsampling
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
Reference:
|
||||
@inproceedings{zhang2018learning,
|
||||
title={Learning a single convolutional super-resolution network for multiple degradations},
|
||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
||||
pages={3262--3271},
|
||||
year={2018}
|
||||
}
|
||||
'''
|
||||
x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
|
||||
x = bicubic_degradation(x, sf=sf)
|
||||
return x
|
||||
|
||||
|
||||
def dpsr_degradation(x, k, sf=3):
|
||||
''' bicubic downsampling + blur
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
Reference:
|
||||
@inproceedings{zhang2019deep,
|
||||
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
|
||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
||||
pages={1671--1681},
|
||||
year={2019}
|
||||
}
|
||||
'''
|
||||
x = bicubic_degradation(x, sf=sf)
|
||||
x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
||||
return x
|
||||
|
||||
|
||||
def classical_degradation(x, k, sf=3):
|
||||
''' blur + downsampling
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]/[0, 255]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
'''
|
||||
x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
||||
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
|
||||
st = 0
|
||||
return x[st::sf, st::sf, ...]
|
||||
|
||||
|
||||
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
|
||||
"""USM sharpening. borrowed from real-ESRGAN
|
||||
Input image: I; Blurry image: B.
|
||||
1. K = I + weight * (I - B)
|
||||
2. Mask = 1 if abs(I - B) > threshold, else: 0
|
||||
3. Blur mask:
|
||||
4. Out = Mask * K + (1 - Mask) * I
|
||||
Args:
|
||||
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
|
||||
weight (float): Sharp weight. Default: 1.
|
||||
radius (float): Kernel size of Gaussian blur. Default: 50.
|
||||
threshold (int):
|
||||
"""
|
||||
if radius % 2 == 0:
|
||||
radius += 1
|
||||
blur = cv2.GaussianBlur(img, (radius, radius), 0)
|
||||
residual = img - blur
|
||||
mask = np.abs(residual) * 255 > threshold
|
||||
mask = mask.astype('float32')
|
||||
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
|
||||
|
||||
K = img + weight * residual
|
||||
K = np.clip(K, 0, 1)
|
||||
return soft_mask * K + (1 - soft_mask) * img
|
||||
|
||||
|
||||
def add_blur(img, sf=4):
|
||||
wd2 = 4.0 + sf
|
||||
wd = 2.0 + 0.2 * sf
|
||||
|
||||
wd2 = wd2/4
|
||||
wd = wd/4
|
||||
|
||||
if random.random() < 0.5:
|
||||
l1 = wd2 * random.random()
|
||||
l2 = wd2 * random.random()
|
||||
k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
|
||||
else:
|
||||
k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
|
||||
img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def add_resize(img, sf=4):
|
||||
rnum = np.random.rand()
|
||||
if rnum > 0.8: # up
|
||||
sf1 = random.uniform(1, 2)
|
||||
elif rnum < 0.7: # down
|
||||
sf1 = random.uniform(0.5 / sf, 1)
|
||||
else:
|
||||
sf1 = 1.0
|
||||
img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
||||
# noise_level = random.randint(noise_level1, noise_level2)
|
||||
# rnum = np.random.rand()
|
||||
# if rnum > 0.6: # add color Gaussian noise
|
||||
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
# elif rnum < 0.4: # add grayscale Gaussian noise
|
||||
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
# else: # add noise
|
||||
# L = noise_level2 / 255.
|
||||
# D = np.diag(np.random.rand(3))
|
||||
# U = orth(np.random.rand(3, 3))
|
||||
# conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
||||
# img = np.clip(img, 0.0, 1.0)
|
||||
# return img
|
||||
|
||||
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
||||
noise_level = random.randint(noise_level1, noise_level2)
|
||||
rnum = np.random.rand()
|
||||
if rnum > 0.6: # add color Gaussian noise
|
||||
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
elif rnum < 0.4: # add grayscale Gaussian noise
|
||||
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
else: # add noise
|
||||
L = noise_level2 / 255.
|
||||
D = np.diag(np.random.rand(3))
|
||||
U = orth(np.random.rand(3, 3))
|
||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
|
||||
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
|
||||
noise_level = random.randint(noise_level1, noise_level2)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
rnum = random.random()
|
||||
if rnum > 0.6:
|
||||
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
elif rnum < 0.4:
|
||||
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
else:
|
||||
L = noise_level2 / 255.
|
||||
D = np.diag(np.random.rand(3))
|
||||
U = orth(np.random.rand(3, 3))
|
||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
|
||||
def add_Poisson_noise(img):
|
||||
img = np.clip((img * 255.0).round(), 0, 255) / 255.
|
||||
vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
|
||||
if random.random() < 0.5:
|
||||
img = np.random.poisson(img * vals).astype(np.float32) / vals
|
||||
else:
|
||||
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
|
||||
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
|
||||
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
|
||||
img += noise_gray[:, :, np.newaxis]
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
|
||||
def add_JPEG_noise(img):
|
||||
quality_factor = random.randint(80, 95)
|
||||
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
|
||||
result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
|
||||
img = cv2.imdecode(encimg, 1)
|
||||
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
|
||||
|
||||
def random_crop(lq, hq, sf=4, lq_patchsize=64):
|
||||
h, w = lq.shape[:2]
|
||||
rnd_h = random.randint(0, h - lq_patchsize)
|
||||
rnd_w = random.randint(0, w - lq_patchsize)
|
||||
lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
|
||||
|
||||
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
|
||||
hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
|
||||
return lq, hq
|
||||
|
||||
|
||||
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
|
||||
"""
|
||||
This is the degradation model of BSRGAN from the paper
|
||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
||||
----------
|
||||
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
|
||||
sf: scale factor
|
||||
isp_model: camera ISP model
|
||||
Returns
|
||||
-------
|
||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
||||
"""
|
||||
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
|
||||
sf_ori = sf
|
||||
|
||||
h1, w1 = img.shape[:2]
|
||||
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
|
||||
h, w = img.shape[:2]
|
||||
|
||||
if h < lq_patchsize * sf or w < lq_patchsize * sf:
|
||||
raise ValueError(f'img size ({h1}X{w1}) is too small!')
|
||||
|
||||
hq = img.copy()
|
||||
|
||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
||||
if np.random.rand() < 0.5:
|
||||
img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
else:
|
||||
img = util.imresize_np(img, 1 / 2, True)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
sf = 2
|
||||
|
||||
shuffle_order = random.sample(range(7), 7)
|
||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
||||
if idx1 > idx2: # keep downsample3 last
|
||||
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
|
||||
|
||||
for i in shuffle_order:
|
||||
|
||||
if i == 0:
|
||||
img = add_blur(img, sf=sf)
|
||||
|
||||
elif i == 1:
|
||||
img = add_blur(img, sf=sf)
|
||||
|
||||
elif i == 2:
|
||||
a, b = img.shape[1], img.shape[0]
|
||||
# downsample2
|
||||
if random.random() < 0.75:
|
||||
sf1 = random.uniform(1, 2 * sf)
|
||||
img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
else:
|
||||
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
|
||||
k_shifted = shift_pixel(k, sf)
|
||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
||||
img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
|
||||
img = img[0::sf, 0::sf, ...] # nearest downsampling
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
elif i == 3:
|
||||
# downsample3
|
||||
img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
elif i == 4:
|
||||
# add Gaussian noise
|
||||
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
|
||||
|
||||
elif i == 5:
|
||||
# add JPEG noise
|
||||
if random.random() < jpeg_prob:
|
||||
img = add_JPEG_noise(img)
|
||||
|
||||
elif i == 6:
|
||||
# add processed camera sensor noise
|
||||
if random.random() < isp_prob and isp_model is not None:
|
||||
with torch.no_grad():
|
||||
img, hq = isp_model.forward(img.copy(), hq)
|
||||
|
||||
# add final JPEG compression noise
|
||||
img = add_JPEG_noise(img)
|
||||
|
||||
# random crop
|
||||
img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
|
||||
|
||||
return img, hq
|
||||
|
||||
|
||||
# todo no isp_model?
|
||||
def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False):
|
||||
"""
|
||||
This is the degradation model of BSRGAN from the paper
|
||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
||||
----------
|
||||
sf: scale factor
|
||||
isp_model: camera ISP model
|
||||
Returns
|
||||
-------
|
||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
||||
"""
|
||||
image = util.uint2single(image)
|
||||
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
|
||||
sf_ori = sf
|
||||
|
||||
h1, w1 = image.shape[:2]
|
||||
image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
|
||||
h, w = image.shape[:2]
|
||||
|
||||
hq = image.copy()
|
||||
|
||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
||||
if np.random.rand() < 0.5:
|
||||
image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
else:
|
||||
image = util.imresize_np(image, 1 / 2, True)
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
sf = 2
|
||||
|
||||
shuffle_order = random.sample(range(7), 7)
|
||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
||||
if idx1 > idx2: # keep downsample3 last
|
||||
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
|
||||
|
||||
for i in shuffle_order:
|
||||
|
||||
if i == 0:
|
||||
image = add_blur(image, sf=sf)
|
||||
|
||||
# elif i == 1:
|
||||
# image = add_blur(image, sf=sf)
|
||||
|
||||
if i == 0:
|
||||
pass
|
||||
|
||||
elif i == 2:
|
||||
a, b = image.shape[1], image.shape[0]
|
||||
# downsample2
|
||||
if random.random() < 0.8:
|
||||
sf1 = random.uniform(1, 2 * sf)
|
||||
image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
else:
|
||||
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
|
||||
k_shifted = shift_pixel(k, sf)
|
||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
||||
image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
|
||||
image = image[0::sf, 0::sf, ...] # nearest downsampling
|
||||
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
|
||||
elif i == 3:
|
||||
# downsample3
|
||||
image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
|
||||
elif i == 4:
|
||||
# add Gaussian noise
|
||||
image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
|
||||
|
||||
elif i == 5:
|
||||
# add JPEG noise
|
||||
if random.random() < jpeg_prob:
|
||||
image = add_JPEG_noise(image)
|
||||
#
|
||||
# elif i == 6:
|
||||
# # add processed camera sensor noise
|
||||
# if random.random() < isp_prob and isp_model is not None:
|
||||
# with torch.no_grad():
|
||||
# img, hq = isp_model.forward(img.copy(), hq)
|
||||
|
||||
# add final JPEG compression noise
|
||||
image = add_JPEG_noise(image)
|
||||
image = util.single2uint(image)
|
||||
if up:
|
||||
image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then
|
||||
example = {"image": image}
|
||||
return example
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("hey")
|
||||
img = util.imread_uint('utils/test.png', 3)
|
||||
img = img[:448, :448]
|
||||
h = img.shape[0] // 4
|
||||
print("resizing to", h)
|
||||
sf = 4
|
||||
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
|
||||
for i in range(20):
|
||||
print(i)
|
||||
img_hq = img
|
||||
img_lq = deg_fn(img)["image"]
|
||||
img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
|
||||
print(img_lq)
|
||||
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
|
||||
print(img_lq.shape)
|
||||
print("bicubic", img_lq_bicubic.shape)
|
||||
print(img_hq.shape)
|
||||
lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
||||
interpolation=0)
|
||||
lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
|
||||
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
||||
interpolation=0)
|
||||
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
|
||||
util.imsave(img_concat, str(i) + '.png')
|
||||
BIN
comfy/ldm/modules/image_degradation/utils/test.png
Normal file
BIN
comfy/ldm/modules/image_degradation/utils/test.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 431 KiB |
916
comfy/ldm/modules/image_degradation/utils_image.py
Normal file
916
comfy/ldm/modules/image_degradation/utils_image.py
Normal file
@@ -0,0 +1,916 @@
|
||||
import os
|
||||
import math
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
import cv2
|
||||
from torchvision.utils import make_grid
|
||||
from datetime import datetime
|
||||
#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
|
||||
|
||||
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# Kai Zhang (github: https://github.com/cszn)
|
||||
# 03/Mar/2019
|
||||
# --------------------------------------------
|
||||
# https://github.com/twhui/SRGAN-pyTorch
|
||||
# https://github.com/xinntao/BasicSR
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
|
||||
|
||||
|
||||
def is_image_file(filename):
|
||||
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
|
||||
|
||||
|
||||
def get_timestamp():
|
||||
return datetime.now().strftime('%y%m%d-%H%M%S')
|
||||
|
||||
|
||||
def imshow(x, title=None, cbar=False, figsize=None):
|
||||
plt.figure(figsize=figsize)
|
||||
plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
|
||||
if title:
|
||||
plt.title(title)
|
||||
if cbar:
|
||||
plt.colorbar()
|
||||
plt.show()
|
||||
|
||||
|
||||
def surf(Z, cmap='rainbow', figsize=None):
|
||||
plt.figure(figsize=figsize)
|
||||
ax3 = plt.axes(projection='3d')
|
||||
|
||||
w, h = Z.shape[:2]
|
||||
xx = np.arange(0,w,1)
|
||||
yy = np.arange(0,h,1)
|
||||
X, Y = np.meshgrid(xx, yy)
|
||||
ax3.plot_surface(X,Y,Z,cmap=cmap)
|
||||
#ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
|
||||
plt.show()
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# get image pathes
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
def get_image_paths(dataroot):
|
||||
paths = None # return None if dataroot is None
|
||||
if dataroot is not None:
|
||||
paths = sorted(_get_paths_from_images(dataroot))
|
||||
return paths
|
||||
|
||||
|
||||
def _get_paths_from_images(path):
|
||||
assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
|
||||
images = []
|
||||
for dirpath, _, fnames in sorted(os.walk(path)):
|
||||
for fname in sorted(fnames):
|
||||
if is_image_file(fname):
|
||||
img_path = os.path.join(dirpath, fname)
|
||||
images.append(img_path)
|
||||
assert images, '{:s} has no valid image file'.format(path)
|
||||
return images
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# split large images into small images
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
|
||||
w, h = img.shape[:2]
|
||||
patches = []
|
||||
if w > p_max and h > p_max:
|
||||
w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int))
|
||||
h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int))
|
||||
w1.append(w-p_size)
|
||||
h1.append(h-p_size)
|
||||
# print(w1)
|
||||
# print(h1)
|
||||
for i in w1:
|
||||
for j in h1:
|
||||
patches.append(img[i:i+p_size, j:j+p_size,:])
|
||||
else:
|
||||
patches.append(img)
|
||||
|
||||
return patches
|
||||
|
||||
|
||||
def imssave(imgs, img_path):
|
||||
"""
|
||||
imgs: list, N images of size WxHxC
|
||||
"""
|
||||
img_name, ext = os.path.splitext(os.path.basename(img_path))
|
||||
|
||||
for i, img in enumerate(imgs):
|
||||
if img.ndim == 3:
|
||||
img = img[:, :, [2, 1, 0]]
|
||||
new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
|
||||
cv2.imwrite(new_path, img)
|
||||
|
||||
|
||||
def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
|
||||
"""
|
||||
split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
|
||||
and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
|
||||
will be splitted.
|
||||
Args:
|
||||
original_dataroot:
|
||||
taget_dataroot:
|
||||
p_size: size of small images
|
||||
p_overlap: patch size in training is a good choice
|
||||
p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
|
||||
"""
|
||||
paths = get_image_paths(original_dataroot)
|
||||
for img_path in paths:
|
||||
# img_name, ext = os.path.splitext(os.path.basename(img_path))
|
||||
img = imread_uint(img_path, n_channels=n_channels)
|
||||
patches = patches_from_image(img, p_size, p_overlap, p_max)
|
||||
imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path)))
|
||||
#if original_dataroot == taget_dataroot:
|
||||
#del img_path
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# makedir
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
def mkdir(path):
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
|
||||
def mkdirs(paths):
|
||||
if isinstance(paths, str):
|
||||
mkdir(paths)
|
||||
else:
|
||||
for path in paths:
|
||||
mkdir(path)
|
||||
|
||||
|
||||
def mkdir_and_rename(path):
|
||||
if os.path.exists(path):
|
||||
new_name = path + '_archived_' + get_timestamp()
|
||||
print('Path already exists. Rename it to [{:s}]'.format(new_name))
|
||||
os.rename(path, new_name)
|
||||
os.makedirs(path)
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# read image from path
|
||||
# opencv is fast, but read BGR numpy image
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# get uint8 image of size HxWxn_channles (RGB)
|
||||
# --------------------------------------------
|
||||
def imread_uint(path, n_channels=3):
|
||||
# input: path
|
||||
# output: HxWx3(RGB or GGG), or HxWx1 (G)
|
||||
if n_channels == 1:
|
||||
img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
|
||||
img = np.expand_dims(img, axis=2) # HxWx1
|
||||
elif n_channels == 3:
|
||||
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
|
||||
if img.ndim == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
|
||||
else:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
|
||||
return img
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# matlab's imwrite
|
||||
# --------------------------------------------
|
||||
def imsave(img, img_path):
|
||||
img = np.squeeze(img)
|
||||
if img.ndim == 3:
|
||||
img = img[:, :, [2, 1, 0]]
|
||||
cv2.imwrite(img_path, img)
|
||||
|
||||
def imwrite(img, img_path):
|
||||
img = np.squeeze(img)
|
||||
if img.ndim == 3:
|
||||
img = img[:, :, [2, 1, 0]]
|
||||
cv2.imwrite(img_path, img)
|
||||
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# get single image of size HxWxn_channles (BGR)
|
||||
# --------------------------------------------
|
||||
def read_img(path):
|
||||
# read image by cv2
|
||||
# return: Numpy float32, HWC, BGR, [0,1]
|
||||
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
|
||||
img = img.astype(np.float32) / 255.
|
||||
if img.ndim == 2:
|
||||
img = np.expand_dims(img, axis=2)
|
||||
# some images have 4 channels
|
||||
if img.shape[2] > 3:
|
||||
img = img[:, :, :3]
|
||||
return img
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# image format conversion
|
||||
# --------------------------------------------
|
||||
# numpy(single) <---> numpy(unit)
|
||||
# numpy(single) <---> tensor
|
||||
# numpy(unit) <---> tensor
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# numpy(single) [0, 1] <---> numpy(unit)
|
||||
# --------------------------------------------
|
||||
|
||||
|
||||
def uint2single(img):
|
||||
|
||||
return np.float32(img/255.)
|
||||
|
||||
|
||||
def single2uint(img):
|
||||
|
||||
return np.uint8((img.clip(0, 1)*255.).round())
|
||||
|
||||
|
||||
def uint162single(img):
|
||||
|
||||
return np.float32(img/65535.)
|
||||
|
||||
|
||||
def single2uint16(img):
|
||||
|
||||
return np.uint16((img.clip(0, 1)*65535.).round())
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# numpy(unit) (HxWxC or HxW) <---> tensor
|
||||
# --------------------------------------------
|
||||
|
||||
|
||||
# convert uint to 4-dimensional torch tensor
|
||||
def uint2tensor4(img):
|
||||
if img.ndim == 2:
|
||||
img = np.expand_dims(img, axis=2)
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
|
||||
|
||||
|
||||
# convert uint to 3-dimensional torch tensor
|
||||
def uint2tensor3(img):
|
||||
if img.ndim == 2:
|
||||
img = np.expand_dims(img, axis=2)
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
|
||||
|
||||
|
||||
# convert 2/3/4-dimensional torch tensor to uint
|
||||
def tensor2uint(img):
|
||||
img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
|
||||
if img.ndim == 3:
|
||||
img = np.transpose(img, (1, 2, 0))
|
||||
return np.uint8((img*255.0).round())
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# numpy(single) (HxWxC) <---> tensor
|
||||
# --------------------------------------------
|
||||
|
||||
|
||||
# convert single (HxWxC) to 3-dimensional torch tensor
|
||||
def single2tensor3(img):
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
|
||||
|
||||
|
||||
# convert single (HxWxC) to 4-dimensional torch tensor
|
||||
def single2tensor4(img):
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
|
||||
|
||||
|
||||
# convert torch tensor to single
|
||||
def tensor2single(img):
|
||||
img = img.data.squeeze().float().cpu().numpy()
|
||||
if img.ndim == 3:
|
||||
img = np.transpose(img, (1, 2, 0))
|
||||
|
||||
return img
|
||||
|
||||
# convert torch tensor to single
|
||||
def tensor2single3(img):
|
||||
img = img.data.squeeze().float().cpu().numpy()
|
||||
if img.ndim == 3:
|
||||
img = np.transpose(img, (1, 2, 0))
|
||||
elif img.ndim == 2:
|
||||
img = np.expand_dims(img, axis=2)
|
||||
return img
|
||||
|
||||
|
||||
def single2tensor5(img):
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
|
||||
|
||||
|
||||
def single32tensor5(img):
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
|
||||
|
||||
|
||||
def single42tensor4(img):
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
|
||||
|
||||
|
||||
# from skimage.io import imread, imsave
|
||||
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
|
||||
'''
|
||||
Converts a torch Tensor into an image Numpy array of BGR channel order
|
||||
Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
|
||||
Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
|
||||
'''
|
||||
tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
|
||||
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
|
||||
n_dim = tensor.dim()
|
||||
if n_dim == 4:
|
||||
n_img = len(tensor)
|
||||
img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
|
||||
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
|
||||
elif n_dim == 3:
|
||||
img_np = tensor.numpy()
|
||||
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
|
||||
elif n_dim == 2:
|
||||
img_np = tensor.numpy()
|
||||
else:
|
||||
raise TypeError(
|
||||
'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
|
||||
if out_type == np.uint8:
|
||||
img_np = (img_np * 255.0).round()
|
||||
# Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
|
||||
return img_np.astype(out_type)
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# Augmentation, flipe and/or rotate
|
||||
# --------------------------------------------
|
||||
# The following two are enough.
|
||||
# (1) augmet_img: numpy image of WxHxC or WxH
|
||||
# (2) augment_img_tensor4: tensor image 1xCxWxH
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
def augment_img(img, mode=0):
|
||||
'''Kai Zhang (github: https://github.com/cszn)
|
||||
'''
|
||||
if mode == 0:
|
||||
return img
|
||||
elif mode == 1:
|
||||
return np.flipud(np.rot90(img))
|
||||
elif mode == 2:
|
||||
return np.flipud(img)
|
||||
elif mode == 3:
|
||||
return np.rot90(img, k=3)
|
||||
elif mode == 4:
|
||||
return np.flipud(np.rot90(img, k=2))
|
||||
elif mode == 5:
|
||||
return np.rot90(img)
|
||||
elif mode == 6:
|
||||
return np.rot90(img, k=2)
|
||||
elif mode == 7:
|
||||
return np.flipud(np.rot90(img, k=3))
|
||||
|
||||
|
||||
def augment_img_tensor4(img, mode=0):
|
||||
'''Kai Zhang (github: https://github.com/cszn)
|
||||
'''
|
||||
if mode == 0:
|
||||
return img
|
||||
elif mode == 1:
|
||||
return img.rot90(1, [2, 3]).flip([2])
|
||||
elif mode == 2:
|
||||
return img.flip([2])
|
||||
elif mode == 3:
|
||||
return img.rot90(3, [2, 3])
|
||||
elif mode == 4:
|
||||
return img.rot90(2, [2, 3]).flip([2])
|
||||
elif mode == 5:
|
||||
return img.rot90(1, [2, 3])
|
||||
elif mode == 6:
|
||||
return img.rot90(2, [2, 3])
|
||||
elif mode == 7:
|
||||
return img.rot90(3, [2, 3]).flip([2])
|
||||
|
||||
|
||||
def augment_img_tensor(img, mode=0):
|
||||
'''Kai Zhang (github: https://github.com/cszn)
|
||||
'''
|
||||
img_size = img.size()
|
||||
img_np = img.data.cpu().numpy()
|
||||
if len(img_size) == 3:
|
||||
img_np = np.transpose(img_np, (1, 2, 0))
|
||||
elif len(img_size) == 4:
|
||||
img_np = np.transpose(img_np, (2, 3, 1, 0))
|
||||
img_np = augment_img(img_np, mode=mode)
|
||||
img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
|
||||
if len(img_size) == 3:
|
||||
img_tensor = img_tensor.permute(2, 0, 1)
|
||||
elif len(img_size) == 4:
|
||||
img_tensor = img_tensor.permute(3, 2, 0, 1)
|
||||
|
||||
return img_tensor.type_as(img)
|
||||
|
||||
|
||||
def augment_img_np3(img, mode=0):
|
||||
if mode == 0:
|
||||
return img
|
||||
elif mode == 1:
|
||||
return img.transpose(1, 0, 2)
|
||||
elif mode == 2:
|
||||
return img[::-1, :, :]
|
||||
elif mode == 3:
|
||||
img = img[::-1, :, :]
|
||||
img = img.transpose(1, 0, 2)
|
||||
return img
|
||||
elif mode == 4:
|
||||
return img[:, ::-1, :]
|
||||
elif mode == 5:
|
||||
img = img[:, ::-1, :]
|
||||
img = img.transpose(1, 0, 2)
|
||||
return img
|
||||
elif mode == 6:
|
||||
img = img[:, ::-1, :]
|
||||
img = img[::-1, :, :]
|
||||
return img
|
||||
elif mode == 7:
|
||||
img = img[:, ::-1, :]
|
||||
img = img[::-1, :, :]
|
||||
img = img.transpose(1, 0, 2)
|
||||
return img
|
||||
|
||||
|
||||
def augment_imgs(img_list, hflip=True, rot=True):
|
||||
# horizontal flip OR rotate
|
||||
hflip = hflip and random.random() < 0.5
|
||||
vflip = rot and random.random() < 0.5
|
||||
rot90 = rot and random.random() < 0.5
|
||||
|
||||
def _augment(img):
|
||||
if hflip:
|
||||
img = img[:, ::-1, :]
|
||||
if vflip:
|
||||
img = img[::-1, :, :]
|
||||
if rot90:
|
||||
img = img.transpose(1, 0, 2)
|
||||
return img
|
||||
|
||||
return [_augment(img) for img in img_list]
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# modcrop and shave
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
def modcrop(img_in, scale):
|
||||
# img_in: Numpy, HWC or HW
|
||||
img = np.copy(img_in)
|
||||
if img.ndim == 2:
|
||||
H, W = img.shape
|
||||
H_r, W_r = H % scale, W % scale
|
||||
img = img[:H - H_r, :W - W_r]
|
||||
elif img.ndim == 3:
|
||||
H, W, C = img.shape
|
||||
H_r, W_r = H % scale, W % scale
|
||||
img = img[:H - H_r, :W - W_r, :]
|
||||
else:
|
||||
raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
|
||||
return img
|
||||
|
||||
|
||||
def shave(img_in, border=0):
|
||||
# img_in: Numpy, HWC or HW
|
||||
img = np.copy(img_in)
|
||||
h, w = img.shape[:2]
|
||||
img = img[border:h-border, border:w-border]
|
||||
return img
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# image processing process on numpy image
|
||||
# channel_convert(in_c, tar_type, img_list):
|
||||
# rgb2ycbcr(img, only_y=True):
|
||||
# bgr2ycbcr(img, only_y=True):
|
||||
# ycbcr2rgb(img):
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
def rgb2ycbcr(img, only_y=True):
|
||||
'''same as matlab rgb2ycbcr
|
||||
only_y: only return Y channel
|
||||
Input:
|
||||
uint8, [0, 255]
|
||||
float, [0, 1]
|
||||
'''
|
||||
in_img_type = img.dtype
|
||||
img.astype(np.float32)
|
||||
if in_img_type != np.uint8:
|
||||
img *= 255.
|
||||
# convert
|
||||
if only_y:
|
||||
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
|
||||
else:
|
||||
rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
|
||||
[24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
|
||||
if in_img_type == np.uint8:
|
||||
rlt = rlt.round()
|
||||
else:
|
||||
rlt /= 255.
|
||||
return rlt.astype(in_img_type)
|
||||
|
||||
|
||||
def ycbcr2rgb(img):
|
||||
'''same as matlab ycbcr2rgb
|
||||
Input:
|
||||
uint8, [0, 255]
|
||||
float, [0, 1]
|
||||
'''
|
||||
in_img_type = img.dtype
|
||||
img.astype(np.float32)
|
||||
if in_img_type != np.uint8:
|
||||
img *= 255.
|
||||
# convert
|
||||
rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
|
||||
[0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
|
||||
if in_img_type == np.uint8:
|
||||
rlt = rlt.round()
|
||||
else:
|
||||
rlt /= 255.
|
||||
return rlt.astype(in_img_type)
|
||||
|
||||
|
||||
def bgr2ycbcr(img, only_y=True):
|
||||
'''bgr version of rgb2ycbcr
|
||||
only_y: only return Y channel
|
||||
Input:
|
||||
uint8, [0, 255]
|
||||
float, [0, 1]
|
||||
'''
|
||||
in_img_type = img.dtype
|
||||
img.astype(np.float32)
|
||||
if in_img_type != np.uint8:
|
||||
img *= 255.
|
||||
# convert
|
||||
if only_y:
|
||||
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
|
||||
else:
|
||||
rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
|
||||
[65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
|
||||
if in_img_type == np.uint8:
|
||||
rlt = rlt.round()
|
||||
else:
|
||||
rlt /= 255.
|
||||
return rlt.astype(in_img_type)
|
||||
|
||||
|
||||
def channel_convert(in_c, tar_type, img_list):
|
||||
# conversion among BGR, gray and y
|
||||
if in_c == 3 and tar_type == 'gray': # BGR to gray
|
||||
gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
|
||||
return [np.expand_dims(img, axis=2) for img in gray_list]
|
||||
elif in_c == 3 and tar_type == 'y': # BGR to y
|
||||
y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
|
||||
return [np.expand_dims(img, axis=2) for img in y_list]
|
||||
elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
|
||||
return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
|
||||
else:
|
||||
return img_list
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# metric, PSNR and SSIM
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# PSNR
|
||||
# --------------------------------------------
|
||||
def calculate_psnr(img1, img2, border=0):
|
||||
# img1 and img2 have range [0, 255]
|
||||
#img1 = img1.squeeze()
|
||||
#img2 = img2.squeeze()
|
||||
if not img1.shape == img2.shape:
|
||||
raise ValueError('Input images must have the same dimensions.')
|
||||
h, w = img1.shape[:2]
|
||||
img1 = img1[border:h-border, border:w-border]
|
||||
img2 = img2[border:h-border, border:w-border]
|
||||
|
||||
img1 = img1.astype(np.float64)
|
||||
img2 = img2.astype(np.float64)
|
||||
mse = np.mean((img1 - img2)**2)
|
||||
if mse == 0:
|
||||
return float('inf')
|
||||
return 20 * math.log10(255.0 / math.sqrt(mse))
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# SSIM
|
||||
# --------------------------------------------
|
||||
def calculate_ssim(img1, img2, border=0):
|
||||
'''calculate SSIM
|
||||
the same outputs as MATLAB's
|
||||
img1, img2: [0, 255]
|
||||
'''
|
||||
#img1 = img1.squeeze()
|
||||
#img2 = img2.squeeze()
|
||||
if not img1.shape == img2.shape:
|
||||
raise ValueError('Input images must have the same dimensions.')
|
||||
h, w = img1.shape[:2]
|
||||
img1 = img1[border:h-border, border:w-border]
|
||||
img2 = img2[border:h-border, border:w-border]
|
||||
|
||||
if img1.ndim == 2:
|
||||
return ssim(img1, img2)
|
||||
elif img1.ndim == 3:
|
||||
if img1.shape[2] == 3:
|
||||
ssims = []
|
||||
for i in range(3):
|
||||
ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
|
||||
return np.array(ssims).mean()
|
||||
elif img1.shape[2] == 1:
|
||||
return ssim(np.squeeze(img1), np.squeeze(img2))
|
||||
else:
|
||||
raise ValueError('Wrong input image dimensions.')
|
||||
|
||||
|
||||
def ssim(img1, img2):
|
||||
C1 = (0.01 * 255)**2
|
||||
C2 = (0.03 * 255)**2
|
||||
|
||||
img1 = img1.astype(np.float64)
|
||||
img2 = img2.astype(np.float64)
|
||||
kernel = cv2.getGaussianKernel(11, 1.5)
|
||||
window = np.outer(kernel, kernel.transpose())
|
||||
|
||||
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
|
||||
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
|
||||
mu1_sq = mu1**2
|
||||
mu2_sq = mu2**2
|
||||
mu1_mu2 = mu1 * mu2
|
||||
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
|
||||
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
|
||||
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
||||
|
||||
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
|
||||
(sigma1_sq + sigma2_sq + C2))
|
||||
return ssim_map.mean()
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# matlab's bicubic imresize (numpy and torch) [0, 1]
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
# matlab 'imresize' function, now only support 'bicubic'
|
||||
def cubic(x):
|
||||
absx = torch.abs(x)
|
||||
absx2 = absx**2
|
||||
absx3 = absx**3
|
||||
return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
|
||||
(-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
|
||||
|
||||
|
||||
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
|
||||
if (scale < 1) and (antialiasing):
|
||||
# Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
|
||||
kernel_width = kernel_width / scale
|
||||
|
||||
# Output-space coordinates
|
||||
x = torch.linspace(1, out_length, out_length)
|
||||
|
||||
# Input-space coordinates. Calculate the inverse mapping such that 0.5
|
||||
# in output space maps to 0.5 in input space, and 0.5+scale in output
|
||||
# space maps to 1.5 in input space.
|
||||
u = x / scale + 0.5 * (1 - 1 / scale)
|
||||
|
||||
# What is the left-most pixel that can be involved in the computation?
|
||||
left = torch.floor(u - kernel_width / 2)
|
||||
|
||||
# What is the maximum number of pixels that can be involved in the
|
||||
# computation? Note: it's OK to use an extra pixel here; if the
|
||||
# corresponding weights are all zero, it will be eliminated at the end
|
||||
# of this function.
|
||||
P = math.ceil(kernel_width) + 2
|
||||
|
||||
# The indices of the input pixels involved in computing the k-th output
|
||||
# pixel are in row k of the indices matrix.
|
||||
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
|
||||
1, P).expand(out_length, P)
|
||||
|
||||
# The weights used to compute the k-th output pixel are in row k of the
|
||||
# weights matrix.
|
||||
distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
|
||||
# apply cubic kernel
|
||||
if (scale < 1) and (antialiasing):
|
||||
weights = scale * cubic(distance_to_center * scale)
|
||||
else:
|
||||
weights = cubic(distance_to_center)
|
||||
# Normalize the weights matrix so that each row sums to 1.
|
||||
weights_sum = torch.sum(weights, 1).view(out_length, 1)
|
||||
weights = weights / weights_sum.expand(out_length, P)
|
||||
|
||||
# If a column in weights is all zero, get rid of it. only consider the first and last column.
|
||||
weights_zero_tmp = torch.sum((weights == 0), 0)
|
||||
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
|
||||
indices = indices.narrow(1, 1, P - 2)
|
||||
weights = weights.narrow(1, 1, P - 2)
|
||||
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
|
||||
indices = indices.narrow(1, 0, P - 2)
|
||||
weights = weights.narrow(1, 0, P - 2)
|
||||
weights = weights.contiguous()
|
||||
indices = indices.contiguous()
|
||||
sym_len_s = -indices.min() + 1
|
||||
sym_len_e = indices.max() - in_length
|
||||
indices = indices + sym_len_s - 1
|
||||
return weights, indices, int(sym_len_s), int(sym_len_e)
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# imresize for tensor image [0, 1]
|
||||
# --------------------------------------------
|
||||
def imresize(img, scale, antialiasing=True):
|
||||
# Now the scale should be the same for H and W
|
||||
# input: img: pytorch tensor, CHW or HW [0,1]
|
||||
# output: CHW or HW [0,1] w/o round
|
||||
need_squeeze = True if img.dim() == 2 else False
|
||||
if need_squeeze:
|
||||
img.unsqueeze_(0)
|
||||
in_C, in_H, in_W = img.size()
|
||||
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
|
||||
kernel_width = 4
|
||||
kernel = 'cubic'
|
||||
|
||||
# Return the desired dimension order for performing the resize. The
|
||||
# strategy is to perform the resize first along the dimension with the
|
||||
# smallest scale factor.
|
||||
# Now we do not support this.
|
||||
|
||||
# get weights and indices
|
||||
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
|
||||
in_H, out_H, scale, kernel, kernel_width, antialiasing)
|
||||
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
|
||||
in_W, out_W, scale, kernel, kernel_width, antialiasing)
|
||||
# process H dimension
|
||||
# symmetric copying
|
||||
img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
|
||||
img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
|
||||
|
||||
sym_patch = img[:, :sym_len_Hs, :]
|
||||
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
||||
img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
|
||||
|
||||
sym_patch = img[:, -sym_len_He:, :]
|
||||
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
||||
img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
|
||||
|
||||
out_1 = torch.FloatTensor(in_C, out_H, in_W)
|
||||
kernel_width = weights_H.size(1)
|
||||
for i in range(out_H):
|
||||
idx = int(indices_H[i][0])
|
||||
for j in range(out_C):
|
||||
out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
|
||||
|
||||
# process W dimension
|
||||
# symmetric copying
|
||||
out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
|
||||
out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
|
||||
|
||||
sym_patch = out_1[:, :, :sym_len_Ws]
|
||||
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
||||
out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
|
||||
|
||||
sym_patch = out_1[:, :, -sym_len_We:]
|
||||
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
||||
out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
|
||||
|
||||
out_2 = torch.FloatTensor(in_C, out_H, out_W)
|
||||
kernel_width = weights_W.size(1)
|
||||
for i in range(out_W):
|
||||
idx = int(indices_W[i][0])
|
||||
for j in range(out_C):
|
||||
out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
|
||||
if need_squeeze:
|
||||
out_2.squeeze_()
|
||||
return out_2
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# imresize for numpy image [0, 1]
|
||||
# --------------------------------------------
|
||||
def imresize_np(img, scale, antialiasing=True):
|
||||
# Now the scale should be the same for H and W
|
||||
# input: img: Numpy, HWC or HW [0,1]
|
||||
# output: HWC or HW [0,1] w/o round
|
||||
img = torch.from_numpy(img)
|
||||
need_squeeze = True if img.dim() == 2 else False
|
||||
if need_squeeze:
|
||||
img.unsqueeze_(2)
|
||||
|
||||
in_H, in_W, in_C = img.size()
|
||||
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
|
||||
kernel_width = 4
|
||||
kernel = 'cubic'
|
||||
|
||||
# Return the desired dimension order for performing the resize. The
|
||||
# strategy is to perform the resize first along the dimension with the
|
||||
# smallest scale factor.
|
||||
# Now we do not support this.
|
||||
|
||||
# get weights and indices
|
||||
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
|
||||
in_H, out_H, scale, kernel, kernel_width, antialiasing)
|
||||
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
|
||||
in_W, out_W, scale, kernel, kernel_width, antialiasing)
|
||||
# process H dimension
|
||||
# symmetric copying
|
||||
img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
|
||||
img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
|
||||
|
||||
sym_patch = img[:sym_len_Hs, :, :]
|
||||
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(0, inv_idx)
|
||||
img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
|
||||
|
||||
sym_patch = img[-sym_len_He:, :, :]
|
||||
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(0, inv_idx)
|
||||
img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
|
||||
|
||||
out_1 = torch.FloatTensor(out_H, in_W, in_C)
|
||||
kernel_width = weights_H.size(1)
|
||||
for i in range(out_H):
|
||||
idx = int(indices_H[i][0])
|
||||
for j in range(out_C):
|
||||
out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
|
||||
|
||||
# process W dimension
|
||||
# symmetric copying
|
||||
out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
|
||||
out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
|
||||
|
||||
sym_patch = out_1[:, :sym_len_Ws, :]
|
||||
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
||||
out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
|
||||
|
||||
sym_patch = out_1[:, -sym_len_We:, :]
|
||||
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
||||
out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
|
||||
|
||||
out_2 = torch.FloatTensor(out_H, out_W, in_C)
|
||||
kernel_width = weights_W.size(1)
|
||||
for i in range(out_W):
|
||||
idx = int(indices_W[i][0])
|
||||
for j in range(out_C):
|
||||
out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
|
||||
if need_squeeze:
|
||||
out_2.squeeze_()
|
||||
|
||||
return out_2.numpy()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('---')
|
||||
# img = imread_uint('test.bmp', 3)
|
||||
# img = uint2single(img)
|
||||
# img_bicubic = imresize_np(img, 1/4)
|
||||
0
comfy/ldm/modules/midas/__init__.py
Normal file
0
comfy/ldm/modules/midas/__init__.py
Normal file
170
comfy/ldm/modules/midas/api.py
Normal file
170
comfy/ldm/modules/midas/api.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# based on https://github.com/isl-org/MiDaS
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
|
||||
from ldm.modules.midas.midas.midas_net import MidasNet
|
||||
from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
|
||||
from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
|
||||
|
||||
|
||||
ISL_PATHS = {
|
||||
"dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
|
||||
"dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
|
||||
"midas_v21": "",
|
||||
"midas_v21_small": "",
|
||||
}
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
does not change anymore."""
|
||||
return self
|
||||
|
||||
|
||||
def load_midas_transform(model_type):
|
||||
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
||||
# load transform only
|
||||
if model_type == "dpt_large": # DPT-Large
|
||||
net_w, net_h = 384, 384
|
||||
resize_mode = "minimal"
|
||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
|
||||
elif model_type == "dpt_hybrid": # DPT-Hybrid
|
||||
net_w, net_h = 384, 384
|
||||
resize_mode = "minimal"
|
||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
|
||||
elif model_type == "midas_v21":
|
||||
net_w, net_h = 384, 384
|
||||
resize_mode = "upper_bound"
|
||||
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
|
||||
elif model_type == "midas_v21_small":
|
||||
net_w, net_h = 256, 256
|
||||
resize_mode = "upper_bound"
|
||||
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
|
||||
else:
|
||||
assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
|
||||
|
||||
transform = Compose(
|
||||
[
|
||||
Resize(
|
||||
net_w,
|
||||
net_h,
|
||||
resize_target=None,
|
||||
keep_aspect_ratio=True,
|
||||
ensure_multiple_of=32,
|
||||
resize_method=resize_mode,
|
||||
image_interpolation_method=cv2.INTER_CUBIC,
|
||||
),
|
||||
normalization,
|
||||
PrepareForNet(),
|
||||
]
|
||||
)
|
||||
|
||||
return transform
|
||||
|
||||
|
||||
def load_model(model_type):
|
||||
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
||||
# load network
|
||||
model_path = ISL_PATHS[model_type]
|
||||
if model_type == "dpt_large": # DPT-Large
|
||||
model = DPTDepthModel(
|
||||
path=model_path,
|
||||
backbone="vitl16_384",
|
||||
non_negative=True,
|
||||
)
|
||||
net_w, net_h = 384, 384
|
||||
resize_mode = "minimal"
|
||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
|
||||
elif model_type == "dpt_hybrid": # DPT-Hybrid
|
||||
model = DPTDepthModel(
|
||||
path=model_path,
|
||||
backbone="vitb_rn50_384",
|
||||
non_negative=True,
|
||||
)
|
||||
net_w, net_h = 384, 384
|
||||
resize_mode = "minimal"
|
||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
|
||||
elif model_type == "midas_v21":
|
||||
model = MidasNet(model_path, non_negative=True)
|
||||
net_w, net_h = 384, 384
|
||||
resize_mode = "upper_bound"
|
||||
normalization = NormalizeImage(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
)
|
||||
|
||||
elif model_type == "midas_v21_small":
|
||||
model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
|
||||
non_negative=True, blocks={'expand': True})
|
||||
net_w, net_h = 256, 256
|
||||
resize_mode = "upper_bound"
|
||||
normalization = NormalizeImage(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
)
|
||||
|
||||
else:
|
||||
print(f"model_type '{model_type}' not implemented, use: --model_type large")
|
||||
assert False
|
||||
|
||||
transform = Compose(
|
||||
[
|
||||
Resize(
|
||||
net_w,
|
||||
net_h,
|
||||
resize_target=None,
|
||||
keep_aspect_ratio=True,
|
||||
ensure_multiple_of=32,
|
||||
resize_method=resize_mode,
|
||||
image_interpolation_method=cv2.INTER_CUBIC,
|
||||
),
|
||||
normalization,
|
||||
PrepareForNet(),
|
||||
]
|
||||
)
|
||||
|
||||
return model.eval(), transform
|
||||
|
||||
|
||||
class MiDaSInference(nn.Module):
|
||||
MODEL_TYPES_TORCH_HUB = [
|
||||
"DPT_Large",
|
||||
"DPT_Hybrid",
|
||||
"MiDaS_small"
|
||||
]
|
||||
MODEL_TYPES_ISL = [
|
||||
"dpt_large",
|
||||
"dpt_hybrid",
|
||||
"midas_v21",
|
||||
"midas_v21_small",
|
||||
]
|
||||
|
||||
def __init__(self, model_type):
|
||||
super().__init__()
|
||||
assert (model_type in self.MODEL_TYPES_ISL)
|
||||
model, _ = load_model(model_type)
|
||||
self.model = model
|
||||
self.model.train = disabled_train
|
||||
|
||||
def forward(self, x):
|
||||
# x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
|
||||
# NOTE: we expect that the correct transform has been called during dataloading.
|
||||
with torch.no_grad():
|
||||
prediction = self.model(x)
|
||||
prediction = torch.nn.functional.interpolate(
|
||||
prediction.unsqueeze(1),
|
||||
size=x.shape[2:],
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
|
||||
return prediction
|
||||
|
||||
0
comfy/ldm/modules/midas/midas/__init__.py
Normal file
0
comfy/ldm/modules/midas/midas/__init__.py
Normal file
16
comfy/ldm/modules/midas/midas/base_model.py
Normal file
16
comfy/ldm/modules/midas/midas/base_model.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import torch
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
def load(self, path):
|
||||
"""Load model from file.
|
||||
|
||||
Args:
|
||||
path (str): file path
|
||||
"""
|
||||
parameters = torch.load(path, map_location=torch.device('cpu'))
|
||||
|
||||
if "optimizer" in parameters:
|
||||
parameters = parameters["model"]
|
||||
|
||||
self.load_state_dict(parameters)
|
||||
342
comfy/ldm/modules/midas/midas/blocks.py
Normal file
342
comfy/ldm/modules/midas/midas/blocks.py
Normal file
@@ -0,0 +1,342 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .vit import (
|
||||
_make_pretrained_vitb_rn50_384,
|
||||
_make_pretrained_vitl16_384,
|
||||
_make_pretrained_vitb16_384,
|
||||
forward_vit,
|
||||
)
|
||||
|
||||
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
|
||||
if backbone == "vitl16_384":
|
||||
pretrained = _make_pretrained_vitl16_384(
|
||||
use_pretrained, hooks=hooks, use_readout=use_readout
|
||||
)
|
||||
scratch = _make_scratch(
|
||||
[256, 512, 1024, 1024], features, groups=groups, expand=expand
|
||||
) # ViT-L/16 - 85.0% Top1 (backbone)
|
||||
elif backbone == "vitb_rn50_384":
|
||||
pretrained = _make_pretrained_vitb_rn50_384(
|
||||
use_pretrained,
|
||||
hooks=hooks,
|
||||
use_vit_only=use_vit_only,
|
||||
use_readout=use_readout,
|
||||
)
|
||||
scratch = _make_scratch(
|
||||
[256, 512, 768, 768], features, groups=groups, expand=expand
|
||||
) # ViT-H/16 - 85.0% Top1 (backbone)
|
||||
elif backbone == "vitb16_384":
|
||||
pretrained = _make_pretrained_vitb16_384(
|
||||
use_pretrained, hooks=hooks, use_readout=use_readout
|
||||
)
|
||||
scratch = _make_scratch(
|
||||
[96, 192, 384, 768], features, groups=groups, expand=expand
|
||||
) # ViT-B/16 - 84.6% Top1 (backbone)
|
||||
elif backbone == "resnext101_wsl":
|
||||
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
|
||||
scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
|
||||
elif backbone == "efficientnet_lite3":
|
||||
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
|
||||
scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
|
||||
else:
|
||||
print(f"Backbone '{backbone}' not implemented")
|
||||
assert False
|
||||
|
||||
return pretrained, scratch
|
||||
|
||||
|
||||
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
||||
scratch = nn.Module()
|
||||
|
||||
out_shape1 = out_shape
|
||||
out_shape2 = out_shape
|
||||
out_shape3 = out_shape
|
||||
out_shape4 = out_shape
|
||||
if expand==True:
|
||||
out_shape1 = out_shape
|
||||
out_shape2 = out_shape*2
|
||||
out_shape3 = out_shape*4
|
||||
out_shape4 = out_shape*8
|
||||
|
||||
scratch.layer1_rn = nn.Conv2d(
|
||||
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
scratch.layer2_rn = nn.Conv2d(
|
||||
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
scratch.layer3_rn = nn.Conv2d(
|
||||
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
scratch.layer4_rn = nn.Conv2d(
|
||||
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
|
||||
return scratch
|
||||
|
||||
|
||||
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
|
||||
efficientnet = torch.hub.load(
|
||||
"rwightman/gen-efficientnet-pytorch",
|
||||
"tf_efficientnet_lite3",
|
||||
pretrained=use_pretrained,
|
||||
exportable=exportable
|
||||
)
|
||||
return _make_efficientnet_backbone(efficientnet)
|
||||
|
||||
|
||||
def _make_efficientnet_backbone(effnet):
|
||||
pretrained = nn.Module()
|
||||
|
||||
pretrained.layer1 = nn.Sequential(
|
||||
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
|
||||
)
|
||||
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
|
||||
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
|
||||
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
|
||||
|
||||
return pretrained
|
||||
|
||||
|
||||
def _make_resnet_backbone(resnet):
|
||||
pretrained = nn.Module()
|
||||
pretrained.layer1 = nn.Sequential(
|
||||
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
|
||||
)
|
||||
|
||||
pretrained.layer2 = resnet.layer2
|
||||
pretrained.layer3 = resnet.layer3
|
||||
pretrained.layer4 = resnet.layer4
|
||||
|
||||
return pretrained
|
||||
|
||||
|
||||
def _make_pretrained_resnext101_wsl(use_pretrained):
|
||||
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
|
||||
return _make_resnet_backbone(resnet)
|
||||
|
||||
|
||||
|
||||
class Interpolate(nn.Module):
|
||||
"""Interpolation module.
|
||||
"""
|
||||
|
||||
def __init__(self, scale_factor, mode, align_corners=False):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
scale_factor (float): scaling
|
||||
mode (str): interpolation mode
|
||||
"""
|
||||
super(Interpolate, self).__init__()
|
||||
|
||||
self.interp = nn.functional.interpolate
|
||||
self.scale_factor = scale_factor
|
||||
self.mode = mode
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x (tensor): input
|
||||
|
||||
Returns:
|
||||
tensor: interpolated data
|
||||
"""
|
||||
|
||||
x = self.interp(
|
||||
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ResidualConvUnit(nn.Module):
|
||||
"""Residual convolution module.
|
||||
"""
|
||||
|
||||
def __init__(self, features):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
features (int): number of features
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
||||
)
|
||||
|
||||
self.conv2 = nn.Conv2d(
|
||||
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
||||
)
|
||||
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x (tensor): input
|
||||
|
||||
Returns:
|
||||
tensor: output
|
||||
"""
|
||||
out = self.relu(x)
|
||||
out = self.conv1(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv2(out)
|
||||
|
||||
return out + x
|
||||
|
||||
|
||||
class FeatureFusionBlock(nn.Module):
|
||||
"""Feature fusion block.
|
||||
"""
|
||||
|
||||
def __init__(self, features):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
features (int): number of features
|
||||
"""
|
||||
super(FeatureFusionBlock, self).__init__()
|
||||
|
||||
self.resConfUnit1 = ResidualConvUnit(features)
|
||||
self.resConfUnit2 = ResidualConvUnit(features)
|
||||
|
||||
def forward(self, *xs):
|
||||
"""Forward pass.
|
||||
|
||||
Returns:
|
||||
tensor: output
|
||||
"""
|
||||
output = xs[0]
|
||||
|
||||
if len(xs) == 2:
|
||||
output += self.resConfUnit1(xs[1])
|
||||
|
||||
output = self.resConfUnit2(output)
|
||||
|
||||
output = nn.functional.interpolate(
|
||||
output, scale_factor=2, mode="bilinear", align_corners=True
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
|
||||
|
||||
class ResidualConvUnit_custom(nn.Module):
|
||||
"""Residual convolution module.
|
||||
"""
|
||||
|
||||
def __init__(self, features, activation, bn):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
features (int): number of features
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.bn = bn
|
||||
|
||||
self.groups=1
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
||||
)
|
||||
|
||||
self.conv2 = nn.Conv2d(
|
||||
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
||||
)
|
||||
|
||||
if self.bn==True:
|
||||
self.bn1 = nn.BatchNorm2d(features)
|
||||
self.bn2 = nn.BatchNorm2d(features)
|
||||
|
||||
self.activation = activation
|
||||
|
||||
self.skip_add = nn.quantized.FloatFunctional()
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x (tensor): input
|
||||
|
||||
Returns:
|
||||
tensor: output
|
||||
"""
|
||||
|
||||
out = self.activation(x)
|
||||
out = self.conv1(out)
|
||||
if self.bn==True:
|
||||
out = self.bn1(out)
|
||||
|
||||
out = self.activation(out)
|
||||
out = self.conv2(out)
|
||||
if self.bn==True:
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.groups > 1:
|
||||
out = self.conv_merge(out)
|
||||
|
||||
return self.skip_add.add(out, x)
|
||||
|
||||
# return out + x
|
||||
|
||||
|
||||
class FeatureFusionBlock_custom(nn.Module):
|
||||
"""Feature fusion block.
|
||||
"""
|
||||
|
||||
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
features (int): number of features
|
||||
"""
|
||||
super(FeatureFusionBlock_custom, self).__init__()
|
||||
|
||||
self.deconv = deconv
|
||||
self.align_corners = align_corners
|
||||
|
||||
self.groups=1
|
||||
|
||||
self.expand = expand
|
||||
out_features = features
|
||||
if self.expand==True:
|
||||
out_features = features//2
|
||||
|
||||
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
||||
|
||||
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
|
||||
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
|
||||
|
||||
self.skip_add = nn.quantized.FloatFunctional()
|
||||
|
||||
def forward(self, *xs):
|
||||
"""Forward pass.
|
||||
|
||||
Returns:
|
||||
tensor: output
|
||||
"""
|
||||
output = xs[0]
|
||||
|
||||
if len(xs) == 2:
|
||||
res = self.resConfUnit1(xs[1])
|
||||
output = self.skip_add.add(output, res)
|
||||
# output += res
|
||||
|
||||
output = self.resConfUnit2(output)
|
||||
|
||||
output = nn.functional.interpolate(
|
||||
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
|
||||
)
|
||||
|
||||
output = self.out_conv(output)
|
||||
|
||||
return output
|
||||
|
||||
109
comfy/ldm/modules/midas/midas/dpt_depth.py
Normal file
109
comfy/ldm/modules/midas/midas/dpt_depth.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .base_model import BaseModel
|
||||
from .blocks import (
|
||||
FeatureFusionBlock,
|
||||
FeatureFusionBlock_custom,
|
||||
Interpolate,
|
||||
_make_encoder,
|
||||
forward_vit,
|
||||
)
|
||||
|
||||
|
||||
def _make_fusion_block(features, use_bn):
|
||||
return FeatureFusionBlock_custom(
|
||||
features,
|
||||
nn.ReLU(False),
|
||||
deconv=False,
|
||||
bn=use_bn,
|
||||
expand=False,
|
||||
align_corners=True,
|
||||
)
|
||||
|
||||
|
||||
class DPT(BaseModel):
|
||||
def __init__(
|
||||
self,
|
||||
head,
|
||||
features=256,
|
||||
backbone="vitb_rn50_384",
|
||||
readout="project",
|
||||
channels_last=False,
|
||||
use_bn=False,
|
||||
):
|
||||
|
||||
super(DPT, self).__init__()
|
||||
|
||||
self.channels_last = channels_last
|
||||
|
||||
hooks = {
|
||||
"vitb_rn50_384": [0, 1, 8, 11],
|
||||
"vitb16_384": [2, 5, 8, 11],
|
||||
"vitl16_384": [5, 11, 17, 23],
|
||||
}
|
||||
|
||||
# Instantiate backbone and reassemble blocks
|
||||
self.pretrained, self.scratch = _make_encoder(
|
||||
backbone,
|
||||
features,
|
||||
False, # Set to true of you want to train from scratch, uses ImageNet weights
|
||||
groups=1,
|
||||
expand=False,
|
||||
exportable=False,
|
||||
hooks=hooks[backbone],
|
||||
use_readout=readout,
|
||||
)
|
||||
|
||||
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
||||
|
||||
self.scratch.output_conv = head
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
if self.channels_last == True:
|
||||
x.contiguous(memory_format=torch.channels_last)
|
||||
|
||||
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
|
||||
|
||||
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
||||
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
||||
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
||||
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
||||
|
||||
path_4 = self.scratch.refinenet4(layer_4_rn)
|
||||
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
||||
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
||||
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
||||
|
||||
out = self.scratch.output_conv(path_1)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class DPTDepthModel(DPT):
|
||||
def __init__(self, path=None, non_negative=True, **kwargs):
|
||||
features = kwargs["features"] if "features" in kwargs else 256
|
||||
|
||||
head = nn.Sequential(
|
||||
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
|
||||
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
||||
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
||||
nn.ReLU(True) if non_negative else nn.Identity(),
|
||||
nn.Identity(),
|
||||
)
|
||||
|
||||
super().__init__(head, **kwargs)
|
||||
|
||||
if path is not None:
|
||||
self.load(path)
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x).squeeze(dim=1)
|
||||
|
||||
76
comfy/ldm/modules/midas/midas/midas_net.py
Normal file
76
comfy/ldm/modules/midas/midas/midas_net.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
||||
This file contains code that is adapted from
|
||||
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .base_model import BaseModel
|
||||
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
|
||||
|
||||
|
||||
class MidasNet(BaseModel):
|
||||
"""Network for monocular depth estimation.
|
||||
"""
|
||||
|
||||
def __init__(self, path=None, features=256, non_negative=True):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
path (str, optional): Path to saved model. Defaults to None.
|
||||
features (int, optional): Number of features. Defaults to 256.
|
||||
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
||||
"""
|
||||
print("Loading weights: ", path)
|
||||
|
||||
super(MidasNet, self).__init__()
|
||||
|
||||
use_pretrained = False if path is None else True
|
||||
|
||||
self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
|
||||
|
||||
self.scratch.refinenet4 = FeatureFusionBlock(features)
|
||||
self.scratch.refinenet3 = FeatureFusionBlock(features)
|
||||
self.scratch.refinenet2 = FeatureFusionBlock(features)
|
||||
self.scratch.refinenet1 = FeatureFusionBlock(features)
|
||||
|
||||
self.scratch.output_conv = nn.Sequential(
|
||||
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
|
||||
Interpolate(scale_factor=2, mode="bilinear"),
|
||||
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
||||
nn.ReLU(True) if non_negative else nn.Identity(),
|
||||
)
|
||||
|
||||
if path:
|
||||
self.load(path)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x (tensor): input data (image)
|
||||
|
||||
Returns:
|
||||
tensor: depth
|
||||
"""
|
||||
|
||||
layer_1 = self.pretrained.layer1(x)
|
||||
layer_2 = self.pretrained.layer2(layer_1)
|
||||
layer_3 = self.pretrained.layer3(layer_2)
|
||||
layer_4 = self.pretrained.layer4(layer_3)
|
||||
|
||||
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
||||
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
||||
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
||||
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
||||
|
||||
path_4 = self.scratch.refinenet4(layer_4_rn)
|
||||
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
||||
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
||||
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
||||
|
||||
out = self.scratch.output_conv(path_1)
|
||||
|
||||
return torch.squeeze(out, dim=1)
|
||||
128
comfy/ldm/modules/midas/midas/midas_net_custom.py
Normal file
128
comfy/ldm/modules/midas/midas/midas_net_custom.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
||||
This file contains code that is adapted from
|
||||
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .base_model import BaseModel
|
||||
from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
|
||||
|
||||
|
||||
class MidasNet_small(BaseModel):
|
||||
"""Network for monocular depth estimation.
|
||||
"""
|
||||
|
||||
def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
|
||||
blocks={'expand': True}):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
path (str, optional): Path to saved model. Defaults to None.
|
||||
features (int, optional): Number of features. Defaults to 256.
|
||||
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
||||
"""
|
||||
print("Loading weights: ", path)
|
||||
|
||||
super(MidasNet_small, self).__init__()
|
||||
|
||||
use_pretrained = False if path else True
|
||||
|
||||
self.channels_last = channels_last
|
||||
self.blocks = blocks
|
||||
self.backbone = backbone
|
||||
|
||||
self.groups = 1
|
||||
|
||||
features1=features
|
||||
features2=features
|
||||
features3=features
|
||||
features4=features
|
||||
self.expand = False
|
||||
if "expand" in self.blocks and self.blocks['expand'] == True:
|
||||
self.expand = True
|
||||
features1=features
|
||||
features2=features*2
|
||||
features3=features*4
|
||||
features4=features*8
|
||||
|
||||
self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
|
||||
|
||||
self.scratch.activation = nn.ReLU(False)
|
||||
|
||||
self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
||||
self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
||||
self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
||||
self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
|
||||
|
||||
|
||||
self.scratch.output_conv = nn.Sequential(
|
||||
nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
|
||||
Interpolate(scale_factor=2, mode="bilinear"),
|
||||
nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
|
||||
self.scratch.activation,
|
||||
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
||||
nn.ReLU(True) if non_negative else nn.Identity(),
|
||||
nn.Identity(),
|
||||
)
|
||||
|
||||
if path:
|
||||
self.load(path)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x (tensor): input data (image)
|
||||
|
||||
Returns:
|
||||
tensor: depth
|
||||
"""
|
||||
if self.channels_last==True:
|
||||
print("self.channels_last = ", self.channels_last)
|
||||
x.contiguous(memory_format=torch.channels_last)
|
||||
|
||||
|
||||
layer_1 = self.pretrained.layer1(x)
|
||||
layer_2 = self.pretrained.layer2(layer_1)
|
||||
layer_3 = self.pretrained.layer3(layer_2)
|
||||
layer_4 = self.pretrained.layer4(layer_3)
|
||||
|
||||
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
||||
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
||||
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
||||
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
||||
|
||||
|
||||
path_4 = self.scratch.refinenet4(layer_4_rn)
|
||||
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
||||
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
||||
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
||||
|
||||
out = self.scratch.output_conv(path_1)
|
||||
|
||||
return torch.squeeze(out, dim=1)
|
||||
|
||||
|
||||
|
||||
def fuse_model(m):
|
||||
prev_previous_type = nn.Identity()
|
||||
prev_previous_name = ''
|
||||
previous_type = nn.Identity()
|
||||
previous_name = ''
|
||||
for name, module in m.named_modules():
|
||||
if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
|
||||
# print("FUSED ", prev_previous_name, previous_name, name)
|
||||
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
|
||||
elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
|
||||
# print("FUSED ", prev_previous_name, previous_name)
|
||||
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
|
||||
# elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
|
||||
# print("FUSED ", previous_name, name)
|
||||
# torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
|
||||
|
||||
prev_previous_type = previous_type
|
||||
prev_previous_name = previous_name
|
||||
previous_type = type(module)
|
||||
previous_name = name
|
||||
234
comfy/ldm/modules/midas/midas/transforms.py
Normal file
234
comfy/ldm/modules/midas/midas/transforms.py
Normal file
@@ -0,0 +1,234 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
import math
|
||||
|
||||
|
||||
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
|
||||
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
|
||||
|
||||
Args:
|
||||
sample (dict): sample
|
||||
size (tuple): image size
|
||||
|
||||
Returns:
|
||||
tuple: new size
|
||||
"""
|
||||
shape = list(sample["disparity"].shape)
|
||||
|
||||
if shape[0] >= size[0] and shape[1] >= size[1]:
|
||||
return sample
|
||||
|
||||
scale = [0, 0]
|
||||
scale[0] = size[0] / shape[0]
|
||||
scale[1] = size[1] / shape[1]
|
||||
|
||||
scale = max(scale)
|
||||
|
||||
shape[0] = math.ceil(scale * shape[0])
|
||||
shape[1] = math.ceil(scale * shape[1])
|
||||
|
||||
# resize
|
||||
sample["image"] = cv2.resize(
|
||||
sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
|
||||
)
|
||||
|
||||
sample["disparity"] = cv2.resize(
|
||||
sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
|
||||
)
|
||||
sample["mask"] = cv2.resize(
|
||||
sample["mask"].astype(np.float32),
|
||||
tuple(shape[::-1]),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
sample["mask"] = sample["mask"].astype(bool)
|
||||
|
||||
return tuple(shape)
|
||||
|
||||
|
||||
class Resize(object):
|
||||
"""Resize sample to given size (width, height).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width,
|
||||
height,
|
||||
resize_target=True,
|
||||
keep_aspect_ratio=False,
|
||||
ensure_multiple_of=1,
|
||||
resize_method="lower_bound",
|
||||
image_interpolation_method=cv2.INTER_AREA,
|
||||
):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
width (int): desired output width
|
||||
height (int): desired output height
|
||||
resize_target (bool, optional):
|
||||
True: Resize the full sample (image, mask, target).
|
||||
False: Resize image only.
|
||||
Defaults to True.
|
||||
keep_aspect_ratio (bool, optional):
|
||||
True: Keep the aspect ratio of the input sample.
|
||||
Output sample might not have the given width and height, and
|
||||
resize behaviour depends on the parameter 'resize_method'.
|
||||
Defaults to False.
|
||||
ensure_multiple_of (int, optional):
|
||||
Output width and height is constrained to be multiple of this parameter.
|
||||
Defaults to 1.
|
||||
resize_method (str, optional):
|
||||
"lower_bound": Output will be at least as large as the given size.
|
||||
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
|
||||
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
||||
Defaults to "lower_bound".
|
||||
"""
|
||||
self.__width = width
|
||||
self.__height = height
|
||||
|
||||
self.__resize_target = resize_target
|
||||
self.__keep_aspect_ratio = keep_aspect_ratio
|
||||
self.__multiple_of = ensure_multiple_of
|
||||
self.__resize_method = resize_method
|
||||
self.__image_interpolation_method = image_interpolation_method
|
||||
|
||||
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
||||
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||
|
||||
if max_val is not None and y > max_val:
|
||||
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||
|
||||
if y < min_val:
|
||||
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||
|
||||
return y
|
||||
|
||||
def get_size(self, width, height):
|
||||
# determine new height and width
|
||||
scale_height = self.__height / height
|
||||
scale_width = self.__width / width
|
||||
|
||||
if self.__keep_aspect_ratio:
|
||||
if self.__resize_method == "lower_bound":
|
||||
# scale such that output size is lower bound
|
||||
if scale_width > scale_height:
|
||||
# fit width
|
||||
scale_height = scale_width
|
||||
else:
|
||||
# fit height
|
||||
scale_width = scale_height
|
||||
elif self.__resize_method == "upper_bound":
|
||||
# scale such that output size is upper bound
|
||||
if scale_width < scale_height:
|
||||
# fit width
|
||||
scale_height = scale_width
|
||||
else:
|
||||
# fit height
|
||||
scale_width = scale_height
|
||||
elif self.__resize_method == "minimal":
|
||||
# scale as least as possbile
|
||||
if abs(1 - scale_width) < abs(1 - scale_height):
|
||||
# fit width
|
||||
scale_height = scale_width
|
||||
else:
|
||||
# fit height
|
||||
scale_width = scale_height
|
||||
else:
|
||||
raise ValueError(
|
||||
f"resize_method {self.__resize_method} not implemented"
|
||||
)
|
||||
|
||||
if self.__resize_method == "lower_bound":
|
||||
new_height = self.constrain_to_multiple_of(
|
||||
scale_height * height, min_val=self.__height
|
||||
)
|
||||
new_width = self.constrain_to_multiple_of(
|
||||
scale_width * width, min_val=self.__width
|
||||
)
|
||||
elif self.__resize_method == "upper_bound":
|
||||
new_height = self.constrain_to_multiple_of(
|
||||
scale_height * height, max_val=self.__height
|
||||
)
|
||||
new_width = self.constrain_to_multiple_of(
|
||||
scale_width * width, max_val=self.__width
|
||||
)
|
||||
elif self.__resize_method == "minimal":
|
||||
new_height = self.constrain_to_multiple_of(scale_height * height)
|
||||
new_width = self.constrain_to_multiple_of(scale_width * width)
|
||||
else:
|
||||
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
||||
|
||||
return (new_width, new_height)
|
||||
|
||||
def __call__(self, sample):
|
||||
width, height = self.get_size(
|
||||
sample["image"].shape[1], sample["image"].shape[0]
|
||||
)
|
||||
|
||||
# resize sample
|
||||
sample["image"] = cv2.resize(
|
||||
sample["image"],
|
||||
(width, height),
|
||||
interpolation=self.__image_interpolation_method,
|
||||
)
|
||||
|
||||
if self.__resize_target:
|
||||
if "disparity" in sample:
|
||||
sample["disparity"] = cv2.resize(
|
||||
sample["disparity"],
|
||||
(width, height),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
|
||||
if "depth" in sample:
|
||||
sample["depth"] = cv2.resize(
|
||||
sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
|
||||
)
|
||||
|
||||
sample["mask"] = cv2.resize(
|
||||
sample["mask"].astype(np.float32),
|
||||
(width, height),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
sample["mask"] = sample["mask"].astype(bool)
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class NormalizeImage(object):
|
||||
"""Normlize image by given mean and std.
|
||||
"""
|
||||
|
||||
def __init__(self, mean, std):
|
||||
self.__mean = mean
|
||||
self.__std = std
|
||||
|
||||
def __call__(self, sample):
|
||||
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class PrepareForNet(object):
|
||||
"""Prepare sample for usage as network input.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, sample):
|
||||
image = np.transpose(sample["image"], (2, 0, 1))
|
||||
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
||||
|
||||
if "mask" in sample:
|
||||
sample["mask"] = sample["mask"].astype(np.float32)
|
||||
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
||||
|
||||
if "disparity" in sample:
|
||||
disparity = sample["disparity"].astype(np.float32)
|
||||
sample["disparity"] = np.ascontiguousarray(disparity)
|
||||
|
||||
if "depth" in sample:
|
||||
depth = sample["depth"].astype(np.float32)
|
||||
sample["depth"] = np.ascontiguousarray(depth)
|
||||
|
||||
return sample
|
||||
491
comfy/ldm/modules/midas/midas/vit.py
Normal file
491
comfy/ldm/modules/midas/midas/vit.py
Normal file
@@ -0,0 +1,491 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import timm
|
||||
import types
|
||||
import math
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Slice(nn.Module):
|
||||
def __init__(self, start_index=1):
|
||||
super(Slice, self).__init__()
|
||||
self.start_index = start_index
|
||||
|
||||
def forward(self, x):
|
||||
return x[:, self.start_index :]
|
||||
|
||||
|
||||
class AddReadout(nn.Module):
|
||||
def __init__(self, start_index=1):
|
||||
super(AddReadout, self).__init__()
|
||||
self.start_index = start_index
|
||||
|
||||
def forward(self, x):
|
||||
if self.start_index == 2:
|
||||
readout = (x[:, 0] + x[:, 1]) / 2
|
||||
else:
|
||||
readout = x[:, 0]
|
||||
return x[:, self.start_index :] + readout.unsqueeze(1)
|
||||
|
||||
|
||||
class ProjectReadout(nn.Module):
|
||||
def __init__(self, in_features, start_index=1):
|
||||
super(ProjectReadout, self).__init__()
|
||||
self.start_index = start_index
|
||||
|
||||
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
|
||||
|
||||
def forward(self, x):
|
||||
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
|
||||
features = torch.cat((x[:, self.start_index :], readout), -1)
|
||||
|
||||
return self.project(features)
|
||||
|
||||
|
||||
class Transpose(nn.Module):
|
||||
def __init__(self, dim0, dim1):
|
||||
super(Transpose, self).__init__()
|
||||
self.dim0 = dim0
|
||||
self.dim1 = dim1
|
||||
|
||||
def forward(self, x):
|
||||
x = x.transpose(self.dim0, self.dim1)
|
||||
return x
|
||||
|
||||
|
||||
def forward_vit(pretrained, x):
|
||||
b, c, h, w = x.shape
|
||||
|
||||
glob = pretrained.model.forward_flex(x)
|
||||
|
||||
layer_1 = pretrained.activations["1"]
|
||||
layer_2 = pretrained.activations["2"]
|
||||
layer_3 = pretrained.activations["3"]
|
||||
layer_4 = pretrained.activations["4"]
|
||||
|
||||
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
|
||||
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
|
||||
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
|
||||
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
|
||||
|
||||
unflatten = nn.Sequential(
|
||||
nn.Unflatten(
|
||||
2,
|
||||
torch.Size(
|
||||
[
|
||||
h // pretrained.model.patch_size[1],
|
||||
w // pretrained.model.patch_size[0],
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if layer_1.ndim == 3:
|
||||
layer_1 = unflatten(layer_1)
|
||||
if layer_2.ndim == 3:
|
||||
layer_2 = unflatten(layer_2)
|
||||
if layer_3.ndim == 3:
|
||||
layer_3 = unflatten(layer_3)
|
||||
if layer_4.ndim == 3:
|
||||
layer_4 = unflatten(layer_4)
|
||||
|
||||
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
|
||||
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
|
||||
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
|
||||
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
|
||||
|
||||
return layer_1, layer_2, layer_3, layer_4
|
||||
|
||||
|
||||
def _resize_pos_embed(self, posemb, gs_h, gs_w):
|
||||
posemb_tok, posemb_grid = (
|
||||
posemb[:, : self.start_index],
|
||||
posemb[0, self.start_index :],
|
||||
)
|
||||
|
||||
gs_old = int(math.sqrt(len(posemb_grid)))
|
||||
|
||||
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
||||
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
|
||||
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
|
||||
|
||||
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
||||
|
||||
return posemb
|
||||
|
||||
|
||||
def forward_flex(self, x):
|
||||
b, c, h, w = x.shape
|
||||
|
||||
pos_embed = self._resize_pos_embed(
|
||||
self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
|
||||
)
|
||||
|
||||
B = x.shape[0]
|
||||
|
||||
if hasattr(self.patch_embed, "backbone"):
|
||||
x = self.patch_embed.backbone(x)
|
||||
if isinstance(x, (list, tuple)):
|
||||
x = x[-1] # last feature if backbone outputs list/tuple of features
|
||||
|
||||
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
|
||||
|
||||
if getattr(self, "dist_token", None) is not None:
|
||||
cls_tokens = self.cls_token.expand(
|
||||
B, -1, -1
|
||||
) # stole cls_tokens impl from Phil Wang, thanks
|
||||
dist_token = self.dist_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, dist_token, x), dim=1)
|
||||
else:
|
||||
cls_tokens = self.cls_token.expand(
|
||||
B, -1, -1
|
||||
) # stole cls_tokens impl from Phil Wang, thanks
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
|
||||
x = x + pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
activations = {}
|
||||
|
||||
|
||||
def get_activation(name):
|
||||
def hook(model, input, output):
|
||||
activations[name] = output
|
||||
|
||||
return hook
|
||||
|
||||
|
||||
def get_readout_oper(vit_features, features, use_readout, start_index=1):
|
||||
if use_readout == "ignore":
|
||||
readout_oper = [Slice(start_index)] * len(features)
|
||||
elif use_readout == "add":
|
||||
readout_oper = [AddReadout(start_index)] * len(features)
|
||||
elif use_readout == "project":
|
||||
readout_oper = [
|
||||
ProjectReadout(vit_features, start_index) for out_feat in features
|
||||
]
|
||||
else:
|
||||
assert (
|
||||
False
|
||||
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
|
||||
|
||||
return readout_oper
|
||||
|
||||
|
||||
def _make_vit_b16_backbone(
|
||||
model,
|
||||
features=[96, 192, 384, 768],
|
||||
size=[384, 384],
|
||||
hooks=[2, 5, 8, 11],
|
||||
vit_features=768,
|
||||
use_readout="ignore",
|
||||
start_index=1,
|
||||
):
|
||||
pretrained = nn.Module()
|
||||
|
||||
pretrained.model = model
|
||||
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
||||
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
||||
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
||||
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
||||
|
||||
pretrained.activations = activations
|
||||
|
||||
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
||||
|
||||
# 32, 48, 136, 384
|
||||
pretrained.act_postprocess1 = nn.Sequential(
|
||||
readout_oper[0],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[0],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=features[0],
|
||||
out_channels=features[0],
|
||||
kernel_size=4,
|
||||
stride=4,
|
||||
padding=0,
|
||||
bias=True,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
),
|
||||
)
|
||||
|
||||
pretrained.act_postprocess2 = nn.Sequential(
|
||||
readout_oper[1],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[1],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=features[1],
|
||||
out_channels=features[1],
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
padding=0,
|
||||
bias=True,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
),
|
||||
)
|
||||
|
||||
pretrained.act_postprocess3 = nn.Sequential(
|
||||
readout_oper[2],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[2],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
)
|
||||
|
||||
pretrained.act_postprocess4 = nn.Sequential(
|
||||
readout_oper[3],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[3],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
nn.Conv2d(
|
||||
in_channels=features[3],
|
||||
out_channels=features[3],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
),
|
||||
)
|
||||
|
||||
pretrained.model.start_index = start_index
|
||||
pretrained.model.patch_size = [16, 16]
|
||||
|
||||
# We inject this function into the VisionTransformer instances so that
|
||||
# we can use it with interpolated position embeddings without modifying the library source.
|
||||
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
||||
pretrained.model._resize_pos_embed = types.MethodType(
|
||||
_resize_pos_embed, pretrained.model
|
||||
)
|
||||
|
||||
return pretrained
|
||||
|
||||
|
||||
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
|
||||
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
|
||||
|
||||
hooks = [5, 11, 17, 23] if hooks == None else hooks
|
||||
return _make_vit_b16_backbone(
|
||||
model,
|
||||
features=[256, 512, 1024, 1024],
|
||||
hooks=hooks,
|
||||
vit_features=1024,
|
||||
use_readout=use_readout,
|
||||
)
|
||||
|
||||
|
||||
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
|
||||
model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
|
||||
|
||||
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
||||
return _make_vit_b16_backbone(
|
||||
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
|
||||
)
|
||||
|
||||
|
||||
def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
|
||||
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
|
||||
|
||||
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
||||
return _make_vit_b16_backbone(
|
||||
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
|
||||
)
|
||||
|
||||
|
||||
def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
|
||||
model = timm.create_model(
|
||||
"vit_deit_base_distilled_patch16_384", pretrained=pretrained
|
||||
)
|
||||
|
||||
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
||||
return _make_vit_b16_backbone(
|
||||
model,
|
||||
features=[96, 192, 384, 768],
|
||||
hooks=hooks,
|
||||
use_readout=use_readout,
|
||||
start_index=2,
|
||||
)
|
||||
|
||||
|
||||
def _make_vit_b_rn50_backbone(
|
||||
model,
|
||||
features=[256, 512, 768, 768],
|
||||
size=[384, 384],
|
||||
hooks=[0, 1, 8, 11],
|
||||
vit_features=768,
|
||||
use_vit_only=False,
|
||||
use_readout="ignore",
|
||||
start_index=1,
|
||||
):
|
||||
pretrained = nn.Module()
|
||||
|
||||
pretrained.model = model
|
||||
|
||||
if use_vit_only == True:
|
||||
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
||||
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
||||
else:
|
||||
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
|
||||
get_activation("1")
|
||||
)
|
||||
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
|
||||
get_activation("2")
|
||||
)
|
||||
|
||||
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
||||
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
||||
|
||||
pretrained.activations = activations
|
||||
|
||||
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
||||
|
||||
if use_vit_only == True:
|
||||
pretrained.act_postprocess1 = nn.Sequential(
|
||||
readout_oper[0],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[0],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=features[0],
|
||||
out_channels=features[0],
|
||||
kernel_size=4,
|
||||
stride=4,
|
||||
padding=0,
|
||||
bias=True,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
),
|
||||
)
|
||||
|
||||
pretrained.act_postprocess2 = nn.Sequential(
|
||||
readout_oper[1],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[1],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=features[1],
|
||||
out_channels=features[1],
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
padding=0,
|
||||
bias=True,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
),
|
||||
)
|
||||
else:
|
||||
pretrained.act_postprocess1 = nn.Sequential(
|
||||
nn.Identity(), nn.Identity(), nn.Identity()
|
||||
)
|
||||
pretrained.act_postprocess2 = nn.Sequential(
|
||||
nn.Identity(), nn.Identity(), nn.Identity()
|
||||
)
|
||||
|
||||
pretrained.act_postprocess3 = nn.Sequential(
|
||||
readout_oper[2],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[2],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
)
|
||||
|
||||
pretrained.act_postprocess4 = nn.Sequential(
|
||||
readout_oper[3],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[3],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
nn.Conv2d(
|
||||
in_channels=features[3],
|
||||
out_channels=features[3],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
),
|
||||
)
|
||||
|
||||
pretrained.model.start_index = start_index
|
||||
pretrained.model.patch_size = [16, 16]
|
||||
|
||||
# We inject this function into the VisionTransformer instances so that
|
||||
# we can use it with interpolated position embeddings without modifying the library source.
|
||||
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
||||
|
||||
# We inject this function into the VisionTransformer instances so that
|
||||
# we can use it with interpolated position embeddings without modifying the library source.
|
||||
pretrained.model._resize_pos_embed = types.MethodType(
|
||||
_resize_pos_embed, pretrained.model
|
||||
)
|
||||
|
||||
return pretrained
|
||||
|
||||
|
||||
def _make_pretrained_vitb_rn50_384(
|
||||
pretrained, use_readout="ignore", hooks=None, use_vit_only=False
|
||||
):
|
||||
model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
|
||||
|
||||
hooks = [0, 1, 8, 11] if hooks == None else hooks
|
||||
return _make_vit_b_rn50_backbone(
|
||||
model,
|
||||
features=[256, 512, 768, 768],
|
||||
size=[384, 384],
|
||||
hooks=hooks,
|
||||
use_vit_only=use_vit_only,
|
||||
use_readout=use_readout,
|
||||
)
|
||||
189
comfy/ldm/modules/midas/utils.py
Normal file
189
comfy/ldm/modules/midas/utils.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""Utils for monoDepth."""
|
||||
import sys
|
||||
import re
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
|
||||
def read_pfm(path):
|
||||
"""Read pfm file.
|
||||
|
||||
Args:
|
||||
path (str): path to file
|
||||
|
||||
Returns:
|
||||
tuple: (data, scale)
|
||||
"""
|
||||
with open(path, "rb") as file:
|
||||
|
||||
color = None
|
||||
width = None
|
||||
height = None
|
||||
scale = None
|
||||
endian = None
|
||||
|
||||
header = file.readline().rstrip()
|
||||
if header.decode("ascii") == "PF":
|
||||
color = True
|
||||
elif header.decode("ascii") == "Pf":
|
||||
color = False
|
||||
else:
|
||||
raise Exception("Not a PFM file: " + path)
|
||||
|
||||
dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
|
||||
if dim_match:
|
||||
width, height = list(map(int, dim_match.groups()))
|
||||
else:
|
||||
raise Exception("Malformed PFM header.")
|
||||
|
||||
scale = float(file.readline().decode("ascii").rstrip())
|
||||
if scale < 0:
|
||||
# little-endian
|
||||
endian = "<"
|
||||
scale = -scale
|
||||
else:
|
||||
# big-endian
|
||||
endian = ">"
|
||||
|
||||
data = np.fromfile(file, endian + "f")
|
||||
shape = (height, width, 3) if color else (height, width)
|
||||
|
||||
data = np.reshape(data, shape)
|
||||
data = np.flipud(data)
|
||||
|
||||
return data, scale
|
||||
|
||||
|
||||
def write_pfm(path, image, scale=1):
|
||||
"""Write pfm file.
|
||||
|
||||
Args:
|
||||
path (str): pathto file
|
||||
image (array): data
|
||||
scale (int, optional): Scale. Defaults to 1.
|
||||
"""
|
||||
|
||||
with open(path, "wb") as file:
|
||||
color = None
|
||||
|
||||
if image.dtype.name != "float32":
|
||||
raise Exception("Image dtype must be float32.")
|
||||
|
||||
image = np.flipud(image)
|
||||
|
||||
if len(image.shape) == 3 and image.shape[2] == 3: # color image
|
||||
color = True
|
||||
elif (
|
||||
len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
|
||||
): # greyscale
|
||||
color = False
|
||||
else:
|
||||
raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
|
||||
|
||||
file.write("PF\n" if color else "Pf\n".encode())
|
||||
file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
|
||||
|
||||
endian = image.dtype.byteorder
|
||||
|
||||
if endian == "<" or endian == "=" and sys.byteorder == "little":
|
||||
scale = -scale
|
||||
|
||||
file.write("%f\n".encode() % scale)
|
||||
|
||||
image.tofile(file)
|
||||
|
||||
|
||||
def read_image(path):
|
||||
"""Read image and output RGB image (0-1).
|
||||
|
||||
Args:
|
||||
path (str): path to file
|
||||
|
||||
Returns:
|
||||
array: RGB image (0-1)
|
||||
"""
|
||||
img = cv2.imread(path)
|
||||
|
||||
if img.ndim == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def resize_image(img):
|
||||
"""Resize image and make it fit for network.
|
||||
|
||||
Args:
|
||||
img (array): image
|
||||
|
||||
Returns:
|
||||
tensor: data ready for network
|
||||
"""
|
||||
height_orig = img.shape[0]
|
||||
width_orig = img.shape[1]
|
||||
|
||||
if width_orig > height_orig:
|
||||
scale = width_orig / 384
|
||||
else:
|
||||
scale = height_orig / 384
|
||||
|
||||
height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
|
||||
width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
|
||||
|
||||
img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
|
||||
|
||||
img_resized = (
|
||||
torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
|
||||
)
|
||||
img_resized = img_resized.unsqueeze(0)
|
||||
|
||||
return img_resized
|
||||
|
||||
|
||||
def resize_depth(depth, width, height):
|
||||
"""Resize depth map and bring to CPU (numpy).
|
||||
|
||||
Args:
|
||||
depth (tensor): depth
|
||||
width (int): image width
|
||||
height (int): image height
|
||||
|
||||
Returns:
|
||||
array: processed depth
|
||||
"""
|
||||
depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
|
||||
|
||||
depth_resized = cv2.resize(
|
||||
depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
|
||||
)
|
||||
|
||||
return depth_resized
|
||||
|
||||
def write_depth(path, depth, bits=1):
|
||||
"""Write depth map to pfm and png file.
|
||||
|
||||
Args:
|
||||
path (str): filepath without extension
|
||||
depth (array): depth
|
||||
"""
|
||||
write_pfm(path + ".pfm", depth.astype(np.float32))
|
||||
|
||||
depth_min = depth.min()
|
||||
depth_max = depth.max()
|
||||
|
||||
max_val = (2**(8*bits))-1
|
||||
|
||||
if depth_max - depth_min > np.finfo("float").eps:
|
||||
out = max_val * (depth - depth_min) / (depth_max - depth_min)
|
||||
else:
|
||||
out = np.zeros(depth.shape, dtype=depth.type)
|
||||
|
||||
if bits == 1:
|
||||
cv2.imwrite(path + ".png", out.astype("uint8"))
|
||||
elif bits == 2:
|
||||
cv2.imwrite(path + ".png", out.astype("uint16"))
|
||||
|
||||
return
|
||||
204
comfy/ldm/modules/sub_quadratic_attention.py
Normal file
204
comfy/ldm/modules/sub_quadratic_attention.py
Normal file
@@ -0,0 +1,204 @@
|
||||
# original source:
|
||||
# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
|
||||
# license:
|
||||
# MIT
|
||||
# credit:
|
||||
# Amin Rezaei (original author)
|
||||
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
|
||||
# implementation of:
|
||||
# Self-attention Does Not Need O(n2) Memory":
|
||||
# https://arxiv.org/abs/2112.05682v2
|
||||
|
||||
from functools import partial
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
import math
|
||||
from typing import Optional, NamedTuple, Protocol, List
|
||||
|
||||
from torch import Tensor
|
||||
from typing import List
|
||||
|
||||
def dynamic_slice(
|
||||
x: Tensor,
|
||||
starts: List[int],
|
||||
sizes: List[int],
|
||||
) -> Tensor:
|
||||
slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
|
||||
return x[slicing]
|
||||
|
||||
class AttnChunk(NamedTuple):
|
||||
exp_values: Tensor
|
||||
exp_weights_sum: Tensor
|
||||
max_score: Tensor
|
||||
|
||||
class SummarizeChunk(Protocol):
|
||||
@staticmethod
|
||||
def __call__(
|
||||
query: Tensor,
|
||||
key_t: Tensor,
|
||||
value: Tensor,
|
||||
) -> AttnChunk: ...
|
||||
|
||||
class ComputeQueryChunkAttn(Protocol):
|
||||
@staticmethod
|
||||
def __call__(
|
||||
query: Tensor,
|
||||
key_t: Tensor,
|
||||
value: Tensor,
|
||||
) -> Tensor: ...
|
||||
|
||||
def _summarize_chunk(
|
||||
query: Tensor,
|
||||
key_t: Tensor,
|
||||
value: Tensor,
|
||||
scale: float,
|
||||
) -> AttnChunk:
|
||||
attn_weights = torch.baddbmm(
|
||||
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
|
||||
query,
|
||||
key_t,
|
||||
alpha=scale,
|
||||
beta=0,
|
||||
)
|
||||
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
|
||||
max_score = max_score.detach()
|
||||
exp_weights = torch.exp(attn_weights - max_score)
|
||||
exp_values = torch.bmm(exp_weights, value)
|
||||
max_score = max_score.squeeze(-1)
|
||||
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
|
||||
|
||||
def _query_chunk_attention(
|
||||
query: Tensor,
|
||||
key_t: Tensor,
|
||||
value: Tensor,
|
||||
summarize_chunk: SummarizeChunk,
|
||||
kv_chunk_size: int,
|
||||
) -> Tensor:
|
||||
batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
|
||||
_, _, v_channels_per_head = value.shape
|
||||
|
||||
def chunk_scanner(chunk_idx: int) -> AttnChunk:
|
||||
key_chunk = dynamic_slice(
|
||||
key_t,
|
||||
(0, 0, chunk_idx),
|
||||
(batch_x_heads, k_channels_per_head, kv_chunk_size)
|
||||
)
|
||||
value_chunk = dynamic_slice(
|
||||
value,
|
||||
(0, chunk_idx, 0),
|
||||
(batch_x_heads, kv_chunk_size, v_channels_per_head)
|
||||
)
|
||||
return summarize_chunk(query, key_chunk, value_chunk)
|
||||
|
||||
chunks: List[AttnChunk] = [
|
||||
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
|
||||
]
|
||||
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
|
||||
chunk_values, chunk_weights, chunk_max = acc_chunk
|
||||
|
||||
global_max, _ = torch.max(chunk_max, 0, keepdim=True)
|
||||
max_diffs = torch.exp(chunk_max - global_max)
|
||||
chunk_values *= torch.unsqueeze(max_diffs, -1)
|
||||
chunk_weights *= max_diffs
|
||||
|
||||
all_values = chunk_values.sum(dim=0)
|
||||
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
|
||||
return all_values / all_weights
|
||||
|
||||
# TODO: refactor CrossAttention#get_attention_scores to share code with this
|
||||
def _get_attention_scores_no_kv_chunking(
|
||||
query: Tensor,
|
||||
key_t: Tensor,
|
||||
value: Tensor,
|
||||
scale: float,
|
||||
) -> Tensor:
|
||||
attn_scores = torch.baddbmm(
|
||||
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
|
||||
query,
|
||||
key_t,
|
||||
alpha=scale,
|
||||
beta=0,
|
||||
)
|
||||
attn_probs = attn_scores.softmax(dim=-1)
|
||||
del attn_scores
|
||||
hidden_states_slice = torch.bmm(attn_probs, value)
|
||||
return hidden_states_slice
|
||||
|
||||
class ScannedChunk(NamedTuple):
|
||||
chunk_idx: int
|
||||
attn_chunk: AttnChunk
|
||||
|
||||
def efficient_dot_product_attention(
|
||||
query: Tensor,
|
||||
key_t: Tensor,
|
||||
value: Tensor,
|
||||
query_chunk_size=1024,
|
||||
kv_chunk_size: Optional[int] = None,
|
||||
kv_chunk_size_min: Optional[int] = None,
|
||||
use_checkpoint=True,
|
||||
):
|
||||
"""Computes efficient dot-product attention given query, transposed key, and value.
|
||||
This is efficient version of attention presented in
|
||||
https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
|
||||
Args:
|
||||
query: queries for calculating attention with shape of
|
||||
`[batch * num_heads, tokens, channels_per_head]`.
|
||||
key_t: keys for calculating attention with shape of
|
||||
`[batch * num_heads, channels_per_head, tokens]`.
|
||||
value: values to be used in attention with shape of
|
||||
`[batch * num_heads, tokens, channels_per_head]`.
|
||||
query_chunk_size: int: query chunks size
|
||||
kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
|
||||
kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
|
||||
use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
|
||||
Returns:
|
||||
Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
|
||||
"""
|
||||
batch_x_heads, q_tokens, q_channels_per_head = query.shape
|
||||
_, _, k_tokens = key_t.shape
|
||||
scale = q_channels_per_head ** -0.5
|
||||
|
||||
kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
|
||||
if kv_chunk_size_min is not None:
|
||||
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
|
||||
|
||||
def get_query_chunk(chunk_idx: int) -> Tensor:
|
||||
return dynamic_slice(
|
||||
query,
|
||||
(0, chunk_idx, 0),
|
||||
(batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
|
||||
)
|
||||
|
||||
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
|
||||
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
|
||||
compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
|
||||
_get_attention_scores_no_kv_chunking,
|
||||
scale=scale
|
||||
) if k_tokens <= kv_chunk_size else (
|
||||
# fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
|
||||
partial(
|
||||
_query_chunk_attention,
|
||||
kv_chunk_size=kv_chunk_size,
|
||||
summarize_chunk=summarize_chunk,
|
||||
)
|
||||
)
|
||||
|
||||
if q_tokens <= query_chunk_size:
|
||||
# fast-path for when there's just 1 query chunk
|
||||
return compute_query_chunk_attn(
|
||||
query=query,
|
||||
key_t=key_t,
|
||||
value=value,
|
||||
)
|
||||
|
||||
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
|
||||
# and pass slices to be mutated, instead of torch.cat()ing the returned slices
|
||||
res = torch.cat([
|
||||
compute_query_chunk_attn(
|
||||
query=get_query_chunk(i * query_chunk_size),
|
||||
key_t=key_t,
|
||||
value=value,
|
||||
) for i in range(math.ceil(q_tokens / query_chunk_size))
|
||||
], dim=1)
|
||||
return res
|
||||
197
comfy/ldm/util.py
Normal file
197
comfy/ldm/util.py
Normal file
@@ -0,0 +1,197 @@
|
||||
import importlib
|
||||
|
||||
import torch
|
||||
from torch import optim
|
||||
import numpy as np
|
||||
|
||||
from inspect import isfunction
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
|
||||
def log_txt_as_img(wh, xc, size=10):
|
||||
# wh a tuple of (width, height)
|
||||
# xc a list of captions to plot
|
||||
b = len(xc)
|
||||
txts = list()
|
||||
for bi in range(b):
|
||||
txt = Image.new("RGB", wh, color="white")
|
||||
draw = ImageDraw.Draw(txt)
|
||||
font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
|
||||
nc = int(40 * (wh[0] / 256))
|
||||
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
|
||||
|
||||
try:
|
||||
draw.text((0, 0), lines, fill="black", font=font)
|
||||
except UnicodeEncodeError:
|
||||
print("Cant encode string for logging. Skipping.")
|
||||
|
||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||
txts.append(txt)
|
||||
txts = np.stack(txts)
|
||||
txts = torch.tensor(txts)
|
||||
return txts
|
||||
|
||||
|
||||
def ismap(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return False
|
||||
return (len(x.shape) == 4) and (x.shape[1] > 3)
|
||||
|
||||
|
||||
def isimage(x):
|
||||
if not isinstance(x,torch.Tensor):
|
||||
return False
|
||||
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
|
||||
|
||||
|
||||
def exists(x):
|
||||
return x is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def mean_flat(tensor):
|
||||
"""
|
||||
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
|
||||
Take the mean over all non-batch dimensions.
|
||||
"""
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
def count_params(model, verbose=False):
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
if verbose:
|
||||
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
|
||||
return total_params
|
||||
|
||||
|
||||
def instantiate_from_config(config):
|
||||
if not "target" in config:
|
||||
if config == '__is_first_stage__':
|
||||
return None
|
||||
elif config == "__is_unconditional__":
|
||||
return None
|
||||
raise KeyError("Expected key `target` to instantiate.")
|
||||
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
||||
|
||||
|
||||
def get_obj_from_str(string, reload=False):
|
||||
module, cls = string.rsplit(".", 1)
|
||||
if reload:
|
||||
module_imp = importlib.import_module(module)
|
||||
importlib.reload(module_imp)
|
||||
return getattr(importlib.import_module(module, package=None), cls)
|
||||
|
||||
|
||||
class AdamWwithEMAandWings(optim.Optimizer):
|
||||
# credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
|
||||
def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
|
||||
weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
|
||||
ema_power=1., param_names=()):
|
||||
"""AdamW that saves EMA versions of the parameters."""
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError("Invalid epsilon value: {}".format(eps))
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
||||
if not 0.0 <= weight_decay:
|
||||
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
||||
if not 0.0 <= ema_decay <= 1.0:
|
||||
raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps,
|
||||
weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
|
||||
ema_power=ema_power, param_names=param_names)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
def __setstate__(self, state):
|
||||
super().__setstate__(state)
|
||||
for group in self.param_groups:
|
||||
group.setdefault('amsgrad', False)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
Args:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
params_with_grad = []
|
||||
grads = []
|
||||
exp_avgs = []
|
||||
exp_avg_sqs = []
|
||||
ema_params_with_grad = []
|
||||
state_sums = []
|
||||
max_exp_avg_sqs = []
|
||||
state_steps = []
|
||||
amsgrad = group['amsgrad']
|
||||
beta1, beta2 = group['betas']
|
||||
ema_decay = group['ema_decay']
|
||||
ema_power = group['ema_power']
|
||||
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
params_with_grad.append(p)
|
||||
if p.grad.is_sparse:
|
||||
raise RuntimeError('AdamW does not support sparse gradients')
|
||||
grads.append(p.grad)
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
# Exponential moving average of gradient values
|
||||
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||
# Exponential moving average of squared gradient values
|
||||
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||
if amsgrad:
|
||||
# Maintains max of all exp. moving avg. of sq. grad. values
|
||||
state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||
# Exponential moving average of parameter values
|
||||
state['param_exp_avg'] = p.detach().float().clone()
|
||||
|
||||
exp_avgs.append(state['exp_avg'])
|
||||
exp_avg_sqs.append(state['exp_avg_sq'])
|
||||
ema_params_with_grad.append(state['param_exp_avg'])
|
||||
|
||||
if amsgrad:
|
||||
max_exp_avg_sqs.append(state['max_exp_avg_sq'])
|
||||
|
||||
# update the steps for each param group update
|
||||
state['step'] += 1
|
||||
# record the step after step update
|
||||
state_steps.append(state['step'])
|
||||
|
||||
optim._functional.adamw(params_with_grad,
|
||||
grads,
|
||||
exp_avgs,
|
||||
exp_avg_sqs,
|
||||
max_exp_avg_sqs,
|
||||
state_steps,
|
||||
amsgrad=amsgrad,
|
||||
beta1=beta1,
|
||||
beta2=beta2,
|
||||
lr=group['lr'],
|
||||
weight_decay=group['weight_decay'],
|
||||
eps=group['eps'],
|
||||
maximize=False)
|
||||
|
||||
cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
|
||||
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
|
||||
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
|
||||
|
||||
return loss
|
||||
114
comfy/samplers.py
Normal file
114
comfy/samplers.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import k_diffusion.sampling
|
||||
import k_diffusion.external
|
||||
import torch
|
||||
import contextlib
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
if len(uncond[0]) == len(cond[0]) and x.shape[0] * x.shape[2] * x.shape[3] <= (96 * 96): #TODO check memory instead
|
||||
x_in = torch.cat([x] * 2)
|
||||
sigma_in = torch.cat([sigma] * 2)
|
||||
cond_in = torch.cat([uncond, cond])
|
||||
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
||||
else:
|
||||
cond = self.inner_model(x, sigma, cond=cond)
|
||||
uncond = self.inner_model(x, sigma, cond=uncond)
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
|
||||
|
||||
def simple_scheduler(model, steps):
|
||||
sigs = []
|
||||
ss = len(model.sigmas) / steps
|
||||
for x in range(steps):
|
||||
sigs += [float(model.sigmas[-(1 + int(x * ss))])]
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
|
||||
class KSampler:
|
||||
SCHEDULERS = ["karras", "normal", "simple"]
|
||||
SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral",
|
||||
"sample_lms", "sample_dpm_fast", "sample_dpm_adaptive", "sample_dpmpp_2s_ancestral", "sample_dpmpp_sde",
|
||||
"sample_dpmpp_2m"]
|
||||
|
||||
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
|
||||
self.model = model
|
||||
if self.model.parameterization == "v":
|
||||
self.model_wrap = k_diffusion.external.CompVisVDenoiser(self.model, quantize=True)
|
||||
else:
|
||||
self.model_wrap = k_diffusion.external.CompVisDenoiser(self.model, quantize=True)
|
||||
self.model_k = CFGDenoiser(self.model_wrap)
|
||||
self.device = device
|
||||
if scheduler not in self.SCHEDULERS:
|
||||
scheduler = self.SCHEDULERS[0]
|
||||
if sampler not in self.SAMPLERS:
|
||||
sampler = self.SAMPLERS[0]
|
||||
self.scheduler = scheduler
|
||||
self.sampler = sampler
|
||||
self.sigma_min=float(self.model_wrap.sigmas[0])
|
||||
self.sigma_max=float(self.model_wrap.sigmas[-1])
|
||||
self.set_steps(steps, denoise)
|
||||
|
||||
def _calculate_sigmas(self, steps):
|
||||
sigmas = None
|
||||
|
||||
discard_penultimate_sigma = False
|
||||
if self.sampler in ['sample_dpm_2', 'sample_dpm_2_ancestral']:
|
||||
steps += 1
|
||||
discard_penultimate_sigma = True
|
||||
|
||||
if self.scheduler == "karras":
|
||||
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max, device=self.device)
|
||||
elif self.scheduler == "normal":
|
||||
sigmas = self.model_wrap.get_sigmas(steps).to(self.device)
|
||||
elif self.scheduler == "simple":
|
||||
sigmas = simple_scheduler(self.model_wrap, steps).to(self.device)
|
||||
else:
|
||||
print("error invalid scheduler", self.scheduler)
|
||||
|
||||
if discard_penultimate_sigma:
|
||||
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
||||
return sigmas
|
||||
|
||||
def set_steps(self, steps, denoise=None):
|
||||
self.steps = steps
|
||||
if denoise is None:
|
||||
self.sigmas = self._calculate_sigmas(steps)
|
||||
else:
|
||||
new_steps = int(steps/denoise)
|
||||
sigmas = self._calculate_sigmas(new_steps)
|
||||
self.sigmas = sigmas[-(steps + 1):]
|
||||
|
||||
|
||||
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None):
|
||||
sigmas = self.sigmas
|
||||
sigma_min = self.sigma_min
|
||||
|
||||
if last_step is not None:
|
||||
sigma_min = sigmas[last_step]
|
||||
sigmas = sigmas[:last_step + 1]
|
||||
if start_step is not None:
|
||||
sigmas = sigmas[start_step:]
|
||||
|
||||
|
||||
noise *= sigmas[0]
|
||||
if latent_image is not None:
|
||||
noise += latent_image
|
||||
|
||||
if self.model.model.diffusion_model.dtype == torch.float16:
|
||||
precision_scope = torch.autocast
|
||||
else:
|
||||
precision_scope = contextlib.nullcontext
|
||||
|
||||
with precision_scope(self.device):
|
||||
if self.sampler == "sample_dpm_fast":
|
||||
samples = k_diffusion.sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
|
||||
elif self.sampler == "sample_dpm_adaptive":
|
||||
samples = k_diffusion.sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
|
||||
else:
|
||||
samples = getattr(k_diffusion.sampling, self.sampler)(self.model_k, noise, sigmas, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
|
||||
return samples.to(torch.float32)
|
||||
124
comfy/sd.py
Normal file
124
comfy/sd.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import torch
|
||||
|
||||
import sd1_clip
|
||||
import sd2_clip
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.autoencoder import AutoencoderKL
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
|
||||
print(f"Loading model from {ckpt}")
|
||||
|
||||
if ckpt.lower().endswith(".safetensors"):
|
||||
import safetensors.torch
|
||||
sd = safetensors.torch.load_file(ckpt, device="cpu")
|
||||
else:
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
sd = pl_sd["state_dict"]
|
||||
model = instantiate_from_config(config.model)
|
||||
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
|
||||
k = list(sd.keys())
|
||||
for x in k:
|
||||
# print(x)
|
||||
if x.startswith("cond_stage_model.transformer.") and not x.startswith("cond_stage_model.transformer.text_model."):
|
||||
y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
|
||||
sd[y] = sd.pop(x)
|
||||
|
||||
for x in load_state_dict_to:
|
||||
x.load_state_dict(sd, strict=False)
|
||||
|
||||
if len(m) > 0 and verbose:
|
||||
print("missing keys:")
|
||||
print(m)
|
||||
if len(u) > 0 and verbose:
|
||||
print("unexpected keys:")
|
||||
print(u)
|
||||
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
|
||||
class CLIP:
|
||||
def __init__(self, config):
|
||||
self.target_clip = config["target"]
|
||||
if self.target_clip == "ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder":
|
||||
clip = sd2_clip.SD2ClipModel
|
||||
tokenizer = sd2_clip.SD2Tokenizer
|
||||
elif self.target_clip == "ldm.modules.encoders.modules.FrozenCLIPEmbedder":
|
||||
clip = sd1_clip.SD1ClipModel
|
||||
tokenizer = sd1_clip.SD1Tokenizer
|
||||
if "params" in config:
|
||||
self.cond_stage_model = clip(**(config["params"]))
|
||||
else:
|
||||
self.cond_stage_model = clip()
|
||||
self.tokenizer = tokenizer()
|
||||
|
||||
def encode(self, text):
|
||||
tokens = self.tokenizer.tokenize_with_weights(text)
|
||||
cond = self.cond_stage_model.encode_token_weights(tokens)
|
||||
return cond
|
||||
|
||||
|
||||
class VAE:
|
||||
def __init__(self, ckpt_path=None, scale_factor=0.18215, device="cuda", config=None):
|
||||
if config is None:
|
||||
#default SD1.x/SD2.x VAE parameters
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss", ckpt_path=ckpt_path)
|
||||
else:
|
||||
self.first_stage_model = AutoencoderKL(**(config['params']), ckpt_path=ckpt_path)
|
||||
self.first_stage_model = self.first_stage_model.eval()
|
||||
self.scale_factor = scale_factor
|
||||
self.device = device
|
||||
|
||||
def decode(self, samples):
|
||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||
samples = samples.to(self.device)
|
||||
pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples)
|
||||
pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
self.first_stage_model = self.first_stage_model.cpu()
|
||||
pixel_samples = pixel_samples.cpu().movedim(1,-1)
|
||||
return pixel_samples
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||
pixel_samples = pixel_samples.movedim(-1,1).to(self.device)
|
||||
samples = self.first_stage_model.encode(2. * pixel_samples - 1.).sample() * self.scale_factor
|
||||
self.first_stage_model = self.first_stage_model.cpu()
|
||||
samples = samples.cpu()
|
||||
return samples
|
||||
|
||||
|
||||
def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True):
|
||||
config = OmegaConf.load(config_path)
|
||||
model_config_params = config['model']['params']
|
||||
clip_config = model_config_params['cond_stage_config']
|
||||
scale_factor = model_config_params['scale_factor']
|
||||
vae_config = model_config_params['first_stage_config']
|
||||
|
||||
clip = None
|
||||
vae = None
|
||||
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
pass
|
||||
|
||||
w = WeightsLoader()
|
||||
load_state_dict_to = []
|
||||
if output_vae:
|
||||
vae = VAE(scale_factor=scale_factor, config=vae_config)
|
||||
w.first_stage_model = vae.first_stage_model
|
||||
load_state_dict_to = [w]
|
||||
|
||||
if output_clip:
|
||||
clip = CLIP(config=clip_config)
|
||||
w.cond_stage_model = clip.cond_stage_model
|
||||
load_state_dict_to = [w]
|
||||
|
||||
model = load_model_from_config(config, ckpt_path, verbose=False, load_state_dict_to=load_state_dict_to)
|
||||
return (model, clip, vae)
|
||||
178
comfy/sd1_clip.py
Normal file
178
comfy/sd1_clip.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import os
|
||||
|
||||
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
|
||||
import torch
|
||||
|
||||
class ClipTokenWeightEncoder:
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
z_empty = self.encode(self.empty_tokens)
|
||||
output = []
|
||||
for x in token_weight_pairs:
|
||||
tokens = [list(map(lambda a: a[0], x))]
|
||||
z = self.encode(tokens)
|
||||
for i in range(len(z)):
|
||||
for j in range(len(z[i])):
|
||||
weight = x[j][1]
|
||||
z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
|
||||
output += [z]
|
||||
if (len(output) == 0):
|
||||
return self.encode(self.empty_tokens)
|
||||
return torch.cat(output, dim=-2)
|
||||
|
||||
class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
LAYERS = [
|
||||
"last",
|
||||
"pooled",
|
||||
"hidden"
|
||||
]
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
if textmodel_path is not None:
|
||||
self.transformer = CLIPTextModel.from_pretrained(textmodel_path)
|
||||
else:
|
||||
if textmodel_json_config is None:
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
|
||||
config = CLIPTextConfig.from_json_file(textmodel_json_config)
|
||||
self.transformer = CLIPTextModel(config)
|
||||
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.layer = layer
|
||||
self.layer_idx = None
|
||||
self.empty_tokens = [[49406] + [49407] * 76]
|
||||
if layer == "hidden":
|
||||
assert layer_idx is not None
|
||||
assert abs(layer_idx) <= 12
|
||||
self.clip_layer(layer_idx)
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
#self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def clip_layer(self, layer_idx):
|
||||
if abs(layer_idx) >= 12:
|
||||
self.layer = "last"
|
||||
else:
|
||||
self.layer = "hidden"
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def forward(self, tokens):
|
||||
tokens = torch.LongTensor(tokens).to(self.device)
|
||||
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
|
||||
|
||||
if self.layer == "last":
|
||||
z = outputs.last_hidden_state
|
||||
elif self.layer == "pooled":
|
||||
z = outputs.pooler_output[:, None, :]
|
||||
else:
|
||||
z = outputs.hidden_states[self.layer_idx]
|
||||
z = self.transformer.text_model.final_layer_norm(z)
|
||||
|
||||
return z
|
||||
|
||||
def encode(self, tokens):
|
||||
return self(tokens)
|
||||
|
||||
def parse_parentheses(string):
|
||||
result = []
|
||||
current_item = ""
|
||||
nesting_level = 0
|
||||
for char in string:
|
||||
if char == "(":
|
||||
if nesting_level == 0:
|
||||
if current_item:
|
||||
result.append(current_item)
|
||||
current_item = "("
|
||||
else:
|
||||
current_item = "("
|
||||
else:
|
||||
current_item += char
|
||||
nesting_level += 1
|
||||
elif char == ")":
|
||||
nesting_level -= 1
|
||||
if nesting_level == 0:
|
||||
result.append(current_item + ")")
|
||||
current_item = ""
|
||||
else:
|
||||
current_item += char
|
||||
else:
|
||||
current_item += char
|
||||
if current_item:
|
||||
result.append(current_item)
|
||||
return result
|
||||
|
||||
def token_weights(string, current_weight):
|
||||
a = parse_parentheses(string)
|
||||
out = []
|
||||
for x in a:
|
||||
weight = current_weight
|
||||
if len(x) >= 2 and x[-1] == ')' and x[0] == '(':
|
||||
x = x[1:-1]
|
||||
xx = x.rfind(":")
|
||||
weight *= 1.1
|
||||
if xx > 0:
|
||||
try:
|
||||
weight = float(x[xx+1:])
|
||||
x = x[:xx]
|
||||
except:
|
||||
pass
|
||||
out += token_weights(x, weight)
|
||||
else:
|
||||
out += [(x, current_weight)]
|
||||
return out
|
||||
|
||||
def escape_important(text):
|
||||
text = text.replace("\\)", "\0\1")
|
||||
text = text.replace("\\(", "\0\2")
|
||||
return text
|
||||
|
||||
def unescape_important(text):
|
||||
text = text.replace("\0\1", ")")
|
||||
text = text.replace("\0\2", "(")
|
||||
return text
|
||||
|
||||
class SD1Tokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True):
|
||||
if tokenizer_path is None:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
self.max_length = max_length
|
||||
empty = self.tokenizer('')["input_ids"]
|
||||
self.start_token = empty[0]
|
||||
self.end_token = empty[1]
|
||||
self.pad_with_end = pad_with_end
|
||||
vocab = self.tokenizer.get_vocab()
|
||||
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||
|
||||
def tokenize_with_weights(self, text):
|
||||
text = escape_important(text)
|
||||
parsed_weights = token_weights(text, 1.0)
|
||||
|
||||
tokens = []
|
||||
for t in parsed_weights:
|
||||
tt = self.tokenizer(unescape_important(t[0]))["input_ids"][1:-1]
|
||||
for x in tt:
|
||||
tokens += [(x, t[1])]
|
||||
|
||||
out_tokens = []
|
||||
for x in range(0, len(tokens), self.max_length - 2):
|
||||
o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_length - 2 + x, len(tokens))]
|
||||
o_token += [(self.end_token, 1.0)]
|
||||
if self.pad_with_end:
|
||||
o_token +=[(self.end_token, 1.0)] * (self.max_length - len(o_token))
|
||||
else:
|
||||
o_token +=[(0, 1.0)] * (self.max_length - len(o_token))
|
||||
|
||||
out_tokens += [o_token]
|
||||
|
||||
return out_tokens
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
|
||||
25
comfy/sd1_clip_config.json
Normal file
25
comfy/sd1_clip_config.json
Normal file
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"_name_or_path": "openai/clip-vit-large-patch14",
|
||||
"architectures": [
|
||||
"CLIPTextModel"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 0,
|
||||
"dropout": 0.0,
|
||||
"eos_token_id": 2,
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 768,
|
||||
"initializer_factor": 1.0,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 3072,
|
||||
"layer_norm_eps": 1e-05,
|
||||
"max_position_embeddings": 77,
|
||||
"model_type": "clip_text_model",
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12,
|
||||
"pad_token_id": 1,
|
||||
"projection_dim": 768,
|
||||
"torch_dtype": "float32",
|
||||
"transformers_version": "4.24.0",
|
||||
"vocab_size": 49408
|
||||
}
|
||||
48895
comfy/sd1_tokenizer/merges.txt
Normal file
48895
comfy/sd1_tokenizer/merges.txt
Normal file
File diff suppressed because it is too large
Load Diff
24
comfy/sd1_tokenizer/special_tokens_map.json
Normal file
24
comfy/sd1_tokenizer/special_tokens_map.json
Normal file
@@ -0,0 +1,24 @@
|
||||
{
|
||||
"bos_token": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": "<|endoftext|>",
|
||||
"unk_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
34
comfy/sd1_tokenizer/tokenizer_config.json
Normal file
34
comfy/sd1_tokenizer/tokenizer_config.json
Normal file
@@ -0,0 +1,34 @@
|
||||
{
|
||||
"add_prefix_space": false,
|
||||
"bos_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"do_lower_case": true,
|
||||
"eos_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"errors": "replace",
|
||||
"model_max_length": 77,
|
||||
"name_or_path": "openai/clip-vit-large-patch14",
|
||||
"pad_token": "<|endoftext|>",
|
||||
"special_tokens_map_file": "./special_tokens_map.json",
|
||||
"tokenizer_class": "CLIPTokenizer",
|
||||
"unk_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
49410
comfy/sd1_tokenizer/vocab.json
Normal file
49410
comfy/sd1_tokenizer/vocab.json
Normal file
File diff suppressed because it is too large
Load Diff
89
comfy/sd2_clip.py
Normal file
89
comfy/sd2_clip.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import sd1_clip
|
||||
import open_clip
|
||||
import torch
|
||||
|
||||
class SD2ClipModel(torch.nn.Module, sd1_clip.ClipTokenWeightEncoder):
|
||||
"""
|
||||
Uses the OpenCLIP transformer encoder for text
|
||||
"""
|
||||
LAYERS = [
|
||||
#"pooled",
|
||||
"last",
|
||||
"penultimate",
|
||||
"hidden"
|
||||
]
|
||||
#version="laion2b_s32b_b79k"
|
||||
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77,
|
||||
freeze=True, layer="penultimate", layer_idx=None):
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'))
|
||||
del model.visual
|
||||
self.model = model
|
||||
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
self.empty_tokens = [[49406] + [49407] + [0] * 75]
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.layer = layer
|
||||
if self.layer == "last":
|
||||
self.layer_idx = 0
|
||||
elif self.layer == "penultimate":
|
||||
self.layer_idx = 1
|
||||
elif self.layer == "hidden":
|
||||
assert layer_idx is not None
|
||||
assert abs(layer_idx) < 24
|
||||
self.clip_layer(layer_idx)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def freeze(self):
|
||||
self.model = self.model.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def clip_layer(self, layer_idx):
|
||||
#layer_idx should have the same logic as the one for SD1
|
||||
if abs(layer_idx) >= 24:
|
||||
self.layer_idx = 0
|
||||
else:
|
||||
if layer_idx < 0:
|
||||
self.layer_idx = -(layer_idx + 1)
|
||||
else:
|
||||
self.layer_idx = 24 - (layer_idx + 1)
|
||||
|
||||
def forward(self, tokens):
|
||||
tokens = torch.LongTensor(tokens).to(self.device)
|
||||
z = self.encode_with_transformer(tokens)
|
||||
return z
|
||||
|
||||
def encode_with_transformer(self, tokens):
|
||||
x = self.model.token_embedding(tokens) # [batch_size, n_ctx, d_model]
|
||||
x = x + self.model.positional_embedding
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
x = self.model.ln_final(x)
|
||||
return x
|
||||
|
||||
def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
|
||||
for i, r in enumerate(self.model.transformer.resblocks):
|
||||
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
||||
break
|
||||
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint(r, x, attn_mask)
|
||||
else:
|
||||
x = r(x, attn_mask=attn_mask)
|
||||
return x
|
||||
|
||||
def encode(self, tokens):
|
||||
return self(tokens)
|
||||
|
||||
|
||||
|
||||
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, tokenizer_path=None):
|
||||
super().__init__(tokenizer_path, pad_with_end=False)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user