Initial commit.
comfy/k_diffusion/augmentation.py (new file, 105 lines)
@@ -0,0 +1,105 @@
from functools import reduce
import math
import operator

import numpy as np
from skimage import transform
import torch
from torch import nn


def translate2d(tx, ty):
    mat = [[1, 0, tx],
           [0, 1, ty],
           [0, 0, 1]]
    return torch.tensor(mat, dtype=torch.float32)


def scale2d(sx, sy):
    mat = [[sx, 0, 0],
           [ 0, sy, 0],
           [ 0, 0, 1]]
    return torch.tensor(mat, dtype=torch.float32)


def rotate2d(theta):
    mat = [[torch.cos(theta), torch.sin(-theta), 0],
           [torch.sin(theta), torch.cos(theta), 0],
           [ 0, 0, 1]]
    return torch.tensor(mat, dtype=torch.float32)


class KarrasAugmentationPipeline:
    def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8):
        self.a_prob = a_prob
        self.a_scale = a_scale
        self.a_aniso = a_aniso
        self.a_trans = a_trans

    def __call__(self, image):
        h, w = image.size
        mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)]

        # x-flip
        a0 = torch.randint(2, []).float()
        mats.append(scale2d(1 - 2 * a0, 1))
        # y-flip
        do = (torch.rand([]) < self.a_prob).float()
        a1 = torch.randint(2, []).float() * do
        mats.append(scale2d(1, 1 - 2 * a1))
        # scaling
        do = (torch.rand([]) < self.a_prob).float()
        a2 = torch.randn([]) * do
        mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2))
        # rotation
        do = (torch.rand([]) < self.a_prob).float()
        a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do
        mats.append(rotate2d(-a3))
        # anisotropy
        do = (torch.rand([]) < self.a_prob).float()
        a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do
        a5 = torch.randn([]) * do
        mats.append(rotate2d(a4))
        mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5))
        mats.append(rotate2d(-a4))
        # translation
        do = (torch.rand([]) < self.a_prob).float()
        a6 = torch.randn([]) * do
        a7 = torch.randn([]) * do
        mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7))

        # form the transformation matrix and conditioning vector
        mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5))
        mat = reduce(operator.matmul, mats)
        cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7])

        # apply the transformation
        image_orig = np.array(image, dtype=np.float32) / 255
        if image_orig.ndim == 2:
            image_orig = image_orig[..., None]
        tf = transform.AffineTransform(mat.numpy())
        image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True)
        image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1
        image = torch.as_tensor(image).movedim(2, 0) * 2 - 1
        return image, image_orig, cond


class KarrasAugmentWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs):
        if aug_cond is None:
            aug_cond = input.new_zeros([input.shape[0], 9])
        if mapping_cond is None:
            mapping_cond = aug_cond
        else:
            mapping_cond = torch.cat([aug_cond, mapping_cond], dim=1)
        return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs)

    def set_skip_stages(self, skip_stages):
        return self.inner_model.set_skip_stages(skip_stages)

    def set_patch_size(self, patch_size):
        return self.inner_model.set_patch_size(patch_size)
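Usage sketch (reviewer note, not part of the committed file): how the pipeline and wrapper above fit together. The random test image is a stand-in; in training the augmented image, the original image, and the 9-dim conditioning vector come out of the pipeline per sample, and the conditioning vectors are batched into the wrapper's aug_cond.

    import numpy as np
    import torch
    from PIL import Image

    # Stand-in for a dataset image (any PIL RGB image works).
    img = Image.fromarray(np.uint8(np.random.rand(64, 64, 3) * 255))

    pipeline = KarrasAugmentationPipeline(a_prob=0.12)
    aug_img, orig_img, aug_cond = pipeline(img)   # tensors in [-1, 1], plus a 9-dim cond vector

    # With a trained inner model (hypothetical here), the wrapper forwards the
    # augmentation vector as mapping conditioning:
    # model = KarrasAugmentWrapper(inner_model)
    # out = model(aug_img[None].float(), sigma, aug_cond=aug_cond[None])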
comfy/k_diffusion/config.py (new file, 110 lines)
@@ -0,0 +1,110 @@
from functools import partial
import json
import math
import warnings

from jsonmerge import merge

from . import augmentation, layers, models, utils


def load_config(file):
    defaults = {
        'model': {
            'sigma_data': 1.,
            'patch_size': 1,
            'dropout_rate': 0.,
            'augment_wrapper': True,
            'augment_prob': 0.,
            'mapping_cond_dim': 0,
            'unet_cond_dim': 0,
            'cross_cond_dim': 0,
            'cross_attn_depths': None,
            'skip_stages': 0,
            'has_variance': False,
        },
        'dataset': {
            'type': 'imagefolder',
        },
        'optimizer': {
            'type': 'adamw',
            'lr': 1e-4,
            'betas': [0.95, 0.999],
            'eps': 1e-6,
            'weight_decay': 1e-3,
        },
        'lr_sched': {
            'type': 'inverse',
            'inv_gamma': 20000.,
            'power': 1.,
            'warmup': 0.99,
        },
        'ema_sched': {
            'type': 'inverse',
            'power': 0.6667,
            'max_value': 0.9999
        },
    }
    config = json.load(file)
    return merge(defaults, config)


def make_model(config):
    config = config['model']
    assert config['type'] == 'image_v1'
    model = models.ImageDenoiserModelV1(
        config['input_channels'],
        config['mapping_out'],
        config['depths'],
        config['channels'],
        config['self_attn_depths'],
        config['cross_attn_depths'],
        patch_size=config['patch_size'],
        dropout_rate=config['dropout_rate'],
        mapping_cond_dim=config['mapping_cond_dim'] + (9 if config['augment_wrapper'] else 0),
        unet_cond_dim=config['unet_cond_dim'],
        cross_cond_dim=config['cross_cond_dim'],
        skip_stages=config['skip_stages'],
        has_variance=config['has_variance'],
    )
    if config['augment_wrapper']:
        model = augmentation.KarrasAugmentWrapper(model)
    return model


def make_denoiser_wrapper(config):
    config = config['model']
    sigma_data = config.get('sigma_data', 1.)
    has_variance = config.get('has_variance', False)
    if not has_variance:
        return partial(layers.Denoiser, sigma_data=sigma_data)
    return partial(layers.DenoiserWithVariance, sigma_data=sigma_data)


def make_sample_density(config):
    sd_config = config['sigma_sample_density']
    sigma_data = config['sigma_data']
    if sd_config['type'] == 'lognormal':
        loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
        scale = sd_config['std'] if 'std' in sd_config else sd_config['scale']
        return partial(utils.rand_log_normal, loc=loc, scale=scale)
    if sd_config['type'] == 'loglogistic':
        loc = sd_config['loc'] if 'loc' in sd_config else math.log(sigma_data)
        scale = sd_config['scale'] if 'scale' in sd_config else 0.5
        min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
        max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
        return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value)
    if sd_config['type'] == 'loguniform':
        min_value = sd_config['min_value'] if 'min_value' in sd_config else config['sigma_min']
        max_value = sd_config['max_value'] if 'max_value' in sd_config else config['sigma_max']
        return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value)
    if sd_config['type'] == 'v-diffusion':
        min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
        max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
        return partial(utils.rand_v_diffusion, sigma_data=sigma_data, min_value=min_value, max_value=max_value)
    if sd_config['type'] == 'split-lognormal':
        loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
        scale_1 = sd_config['std_1'] if 'std_1' in sd_config else sd_config['scale_1']
        scale_2 = sd_config['std_2'] if 'std_2' in sd_config else sd_config['scale_2']
        return partial(utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2)
    raise ValueError('Unknown sample density type')
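Illustrative sketch (reviewer note, not part of the committed file): a minimal config that load_config merges with the defaults above and that make_model expects for the 'image_v1' type. The specific values here are made up for illustration.

    import io
    import json

    cfg = {
        'model': {
            'type': 'image_v1',
            'input_channels': 3,
            'mapping_out': 256,
            'depths': [2, 2, 4],
            'channels': [128, 256, 512],
            'self_attn_depths': [False, False, True],
            'sigma_data': 0.5,
            'sigma_min': 1e-2,
            'sigma_max': 80,
            'sigma_sample_density': {'type': 'lognormal', 'mean': -1.2, 'std': 1.2},
        },
    }

    config = load_config(io.StringIO(json.dumps(cfg)))       # fills in the defaults above
    model = make_model(config)                                # ImageDenoiserModelV1 in a KarrasAugmentWrapper
    denoiser = make_denoiser_wrapper(config)(model)           # layers.Denoiser with the configured sigma_data
    sample_density = make_sample_density(config['model'])     # callable that draws training sigmas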
comfy/k_diffusion/evaluation.py (new file, 134 lines)
@@ -0,0 +1,134 @@
import math
import os
from pathlib import Path

from cleanfid.inception_torchscript import InceptionV3W
import clip
from resize_right import resize
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from tqdm.auto import trange

from . import utils


class InceptionV3FeatureExtractor(nn.Module):
    def __init__(self, device='cpu'):
        super().__init__()
        path = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache')) / 'k-diffusion'
        url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
        digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4'
        utils.download_file(path / 'inception-2015-12-05.pt', url, digest)
        self.model = InceptionV3W(str(path), resize_inside=False).to(device)
        self.size = (299, 299)

    def forward(self, x):
        if x.shape[2:4] != self.size:
            x = resize(x, out_shape=self.size, pad_mode='reflect')
        if x.shape[1] == 1:
            x = torch.cat([x] * 3, dim=1)
        x = (x * 127.5 + 127.5).clamp(0, 255)
        return self.model(x)


class CLIPFeatureExtractor(nn.Module):
    def __init__(self, name='ViT-L/14@336px', device='cpu'):
        super().__init__()
        self.model = clip.load(name, device=device)[0].eval().requires_grad_(False)
        self.normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                                              std=(0.26862954, 0.26130258, 0.27577711))
        self.size = (self.model.visual.input_resolution, self.model.visual.input_resolution)

    def forward(self, x):
        if x.shape[2:4] != self.size:
            x = resize(x.add(1).div(2), out_shape=self.size, pad_mode='reflect').clamp(0, 1)
        x = self.normalize(x)
        x = self.model.encode_image(x).float()
        x = F.normalize(x) * x.shape[1] ** 0.5
        return x


def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size):
    n_per_proc = math.ceil(n / accelerator.num_processes)
    feats_all = []
    try:
        for i in trange(0, n_per_proc, batch_size, disable=not accelerator.is_main_process):
            cur_batch_size = min(n - i, batch_size)
            samples = sample_fn(cur_batch_size)[:cur_batch_size]
            feats_all.append(accelerator.gather(extractor_fn(samples)))
    except StopIteration:
        pass
    return torch.cat(feats_all)[:n]


def polynomial_kernel(x, y):
    d = x.shape[-1]
    dot = x @ y.transpose(-2, -1)
    return (dot / d + 1) ** 3


def squared_mmd(x, y, kernel=polynomial_kernel):
    m = x.shape[-2]
    n = y.shape[-2]
    kxx = kernel(x, x)
    kyy = kernel(y, y)
    kxy = kernel(x, y)
    kxx_sum = kxx.sum([-1, -2]) - kxx.diagonal(dim1=-1, dim2=-2).sum(-1)
    kyy_sum = kyy.sum([-1, -2]) - kyy.diagonal(dim1=-1, dim2=-2).sum(-1)
    kxy_sum = kxy.sum([-1, -2])
    term_1 = kxx_sum / m / (m - 1)
    term_2 = kyy_sum / n / (n - 1)
    term_3 = kxy_sum * 2 / m / n
    return term_1 + term_2 - term_3


@utils.tf32_mode(matmul=False)
def kid(x, y, max_size=5000):
    x_size, y_size = x.shape[0], y.shape[0]
    n_partitions = math.ceil(max(x_size / max_size, y_size / max_size))
    total_mmd = x.new_zeros([])
    for i in range(n_partitions):
        cur_x = x[round(i * x_size / n_partitions):round((i + 1) * x_size / n_partitions)]
        cur_y = y[round(i * y_size / n_partitions):round((i + 1) * y_size / n_partitions)]
        total_mmd = total_mmd + squared_mmd(cur_x, cur_y)
    return total_mmd / n_partitions


class _MatrixSquareRootEig(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a):
        vals, vecs = torch.linalg.eigh(a)
        ctx.save_for_backward(vals, vecs)
        return vecs @ vals.abs().sqrt().diag_embed() @ vecs.transpose(-2, -1)

    @staticmethod
    def backward(ctx, grad_output):
        vals, vecs = ctx.saved_tensors
        d = vals.abs().sqrt().unsqueeze(-1).repeat_interleave(vals.shape[-1], -1)
        vecs_t = vecs.transpose(-2, -1)
        return vecs @ (vecs_t @ grad_output @ vecs / (d + d.transpose(-2, -1))) @ vecs_t


def sqrtm_eig(a):
    if a.ndim < 2:
        raise RuntimeError('tensor of matrices must have at least 2 dimensions')
    if a.shape[-2] != a.shape[-1]:
        raise RuntimeError('tensor must be batches of square matrices')
    return _MatrixSquareRootEig.apply(a)


@utils.tf32_mode(matmul=False)
def fid(x, y, eps=1e-8):
    x_mean = x.mean(dim=0)
    y_mean = y.mean(dim=0)
    mean_term = (x_mean - y_mean).pow(2).sum()
    x_cov = torch.cov(x.T)
    y_cov = torch.cov(y.T)
    eps_eye = torch.eye(x_cov.shape[0], device=x_cov.device, dtype=x_cov.dtype) * eps
    x_cov = x_cov + eps_eye
    y_cov = y_cov + eps_eye
    x_cov_sqrt = sqrtm_eig(x_cov)
    cov_term = torch.trace(x_cov + y_cov - 2 * sqrtm_eig(x_cov_sqrt @ y_cov @ x_cov_sqrt))
    return mean_term + cov_term
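Usage sketch (reviewer note, not part of the committed file): kid and fid take two sets of feature vectors, one row per image, such as the outputs of the extractors above. Random features are used here only to show the expected shapes; real evaluations would gather features with compute_features.

    import torch

    real_feats = torch.randn([5000, 2048])   # e.g. InceptionV3FeatureExtractor outputs
    fake_feats = torch.randn([5000, 2048])

    print('FID:', fid(real_feats, fake_feats).item())
    print('KID:', kid(real_feats, fake_feats).item())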
comfy/k_diffusion/external.py (new file, 179 lines)
@@ -0,0 +1,179 @@
import math

import torch
from torch import nn

from . import sampling, utils


class VDenoiser(nn.Module):
    """A v-diffusion-pytorch model wrapper for k-diffusion."""

    def __init__(self, inner_model):
        super().__init__()
        self.inner_model = inner_model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return c_skip, c_out, c_in

    def sigma_to_t(self, sigma):
        return sigma.atan() / math.pi * 2

    def t_to_sigma(self, t):
        return (t * math.pi / 2).tan()

    def loss(self, input, noise, sigma, **kwargs):
        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        target = (input - c_skip * noised_input) / c_out
        return (model_output - target).pow(2).flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip


class DiscreteSchedule(nn.Module):
    """A mapping between continuous noise levels (sigmas) and a list of discrete noise
    levels."""

    def __init__(self, sigmas, quantize):
        super().__init__()
        self.register_buffer('sigmas', sigmas)
        self.register_buffer('log_sigmas', sigmas.log())
        self.quantize = quantize

    @property
    def sigma_min(self):
        return self.sigmas[0]

    @property
    def sigma_max(self):
        return self.sigmas[-1]

    def get_sigmas(self, n=None):
        if n is None:
            return sampling.append_zero(self.sigmas.flip(0))
        t_max = len(self.sigmas) - 1
        t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
        return sampling.append_zero(self.t_to_sigma(t))

    def sigma_to_t(self, sigma, quantize=None):
        quantize = self.quantize if quantize is None else quantize
        log_sigma = sigma.log()
        dists = log_sigma - self.log_sigmas[:, None]
        if quantize:
            return dists.abs().argmin(dim=0).view(sigma.shape)
        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1
        low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
        w = (low - log_sigma) / (low - high)
        w = w.clamp(0, 1)
        t = (1 - w) * low_idx + w * high_idx
        return t.view(sigma.shape)

    def t_to_sigma(self, t):
        t = t.float()
        low_idx = t.floor().long()
        high_idx = t.ceil().long()
        w = t - low_idx if t.device.type == 'mps' else t.frac()
        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
        return log_sigma.exp()


class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
    """A wrapper for discrete schedule DDPM models that output eps (the predicted
    noise)."""

    def __init__(self, model, alphas_cumprod, quantize):
        super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
        self.inner_model = model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        c_out = -sigma
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return c_out, c_in

    def get_eps(self, *args, **kwargs):
        return self.inner_model(*args, **kwargs)

    def loss(self, input, noise, sigma, **kwargs):
        c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        return (eps - noise).pow(2).flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
        return input + eps * c_out


class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
    """A wrapper for OpenAI diffusion models."""

    def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
        alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32)
        super().__init__(model, alphas_cumprod, quantize=quantize)
        self.has_learned_sigmas = has_learned_sigmas

    def get_eps(self, *args, **kwargs):
        model_output = self.inner_model(*args, **kwargs)
        if self.has_learned_sigmas:
            return model_output.chunk(2, dim=1)[0]
        return model_output


class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
    """A wrapper for CompVis diffusion models."""

    def __init__(self, model, quantize=False, device='cpu'):
        super().__init__(model, model.alphas_cumprod, quantize=quantize)

    def get_eps(self, *args, **kwargs):
        return self.inner_model.apply_model(*args, **kwargs)


class DiscreteVDDPMDenoiser(DiscreteSchedule):
    """A wrapper for discrete schedule DDPM models that output v."""

    def __init__(self, model, alphas_cumprod, quantize):
        super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
        self.inner_model = model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return c_skip, c_out, c_in

    def get_v(self, *args, **kwargs):
        return self.inner_model(*args, **kwargs)

    def loss(self, input, noise, sigma, **kwargs):
        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        target = (input - c_skip * noised_input) / c_out
        return (model_output - target).pow(2).flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip


class CompVisVDenoiser(DiscreteVDDPMDenoiser):
    """A wrapper for CompVis diffusion models that output v."""

    def __init__(self, model, quantize=False, device='cpu'):
        super().__init__(model, model.alphas_cumprod, quantize=quantize)

    def get_v(self, x, t, cond, **kwargs):
        return self.inner_model.apply_model(x, t, cond)
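Usage sketch (reviewer note, not part of the committed file): the discrete wrappers turn a fixed DDPM alphas_cumprod schedule into continuous sigmas. The eps model and the linear-beta schedule below are toy stand-ins; for a real CompVis/latent-diffusion checkpoint, CompVisDenoiser(model) does the same using model.alphas_cumprod and model.apply_model.

    import torch
    from torch import nn

    class ToyEpsModel(nn.Module):
        """Stand-in for a DDPM that predicts eps from (x, t); purely illustrative."""
        def forward(self, x, t):
            return torch.zeros_like(x)

    # Hypothetical linear-beta DDPM schedule, just to get a plausible alphas_cumprod.
    betas = torch.linspace(1e-4, 2e-2, 1000)
    alphas_cumprod = torch.cumprod(1 - betas, dim=0)

    denoiser = DiscreteEpsDDPMDenoiser(ToyEpsModel(), alphas_cumprod, quantize=False)
    sigmas = denoiser.get_sigmas(10)                  # 10 sigmas from sigma_max down, plus a final 0
    x = torch.randn([1, 3, 32, 32]) * sigmas[0]
    out = denoiser(x, sigmas[0] * x.new_ones([1]))    # scale by c_in, predict eps, rescale by c_out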
99
comfy/k_diffusion/gns.py
Normal file
99
comfy/k_diffusion/gns.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class DDPGradientStatsHook:
|
||||
def __init__(self, ddp_module):
|
||||
try:
|
||||
ddp_module.register_comm_hook(self, self._hook_fn)
|
||||
except AttributeError:
|
||||
raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules')
|
||||
self._clear_state()
|
||||
|
||||
def _clear_state(self):
|
||||
self.bucket_sq_norms_small_batch = []
|
||||
self.bucket_sq_norms_large_batch = []
|
||||
|
||||
@staticmethod
|
||||
def _hook_fn(self, bucket):
|
||||
buf = bucket.buffer()
|
||||
self.bucket_sq_norms_small_batch.append(buf.pow(2).sum())
|
||||
fut = torch.distributed.all_reduce(buf, op=torch.distributed.ReduceOp.AVG, async_op=True).get_future()
|
||||
def callback(fut):
|
||||
buf = fut.value()[0]
|
||||
self.bucket_sq_norms_large_batch.append(buf.pow(2).sum())
|
||||
return buf
|
||||
return fut.then(callback)
|
||||
|
||||
def get_stats(self):
|
||||
sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch)
|
||||
sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch)
|
||||
self._clear_state()
|
||||
stats = torch.stack([sq_norm_small_batch, sq_norm_large_batch])
|
||||
torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.AVG)
|
||||
return stats[0].item(), stats[1].item()
|
||||
|
||||
|
||||
class GradientNoiseScale:
|
||||
"""Calculates the gradient noise scale (1 / SNR), or critical batch size,
|
||||
from _An Empirical Model of Large-Batch Training_,
|
||||
https://arxiv.org/abs/1812.06162).
|
||||
|
||||
Args:
|
||||
beta (float): The decay factor for the exponential moving averages used to
|
||||
calculate the gradient noise scale.
|
||||
Default: 0.9998
|
||||
eps (float): Added for numerical stability.
|
||||
Default: 1e-8
|
||||
"""
|
||||
|
||||
def __init__(self, beta=0.9998, eps=1e-8):
|
||||
self.beta = beta
|
||||
self.eps = eps
|
||||
self.ema_sq_norm = 0.
|
||||
self.ema_var = 0.
|
||||
self.beta_cumprod = 1.
|
||||
self.gradient_noise_scale = float('nan')
|
||||
|
||||
def state_dict(self):
|
||||
"""Returns the state of the object as a :class:`dict`."""
|
||||
return dict(self.__dict__.items())
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Loads the object's state.
|
||||
Args:
|
||||
state_dict (dict): object state. Should be an object returned
|
||||
from a call to :meth:`state_dict`.
|
||||
"""
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch):
|
||||
"""Updates the state with a new batch's gradient statistics, and returns the
|
||||
current gradient noise scale.
|
||||
|
||||
Args:
|
||||
sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or
|
||||
per sample gradients.
|
||||
sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or
|
||||
per sample gradients.
|
||||
n_small_batch (int): The batch size of the individual microbatch or per sample
|
||||
gradients (1 if per sample).
|
||||
n_large_batch (int): The total batch size of the mean of the microbatch or
|
||||
per sample gradients.
|
||||
"""
|
||||
est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch)
|
||||
est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch)
|
||||
self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm
|
||||
self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var
|
||||
self.beta_cumprod *= self.beta
|
||||
self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps)
|
||||
return self.gradient_noise_scale
|
||||
|
||||
def get_gns(self):
|
||||
"""Returns the current gradient noise scale."""
|
||||
return self.gradient_noise_scale
|
||||
|
||||
def get_stats(self):
|
||||
"""Returns the current (debiased) estimates of the squared mean gradient
|
||||
and gradient variance."""
|
||||
return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod)
|
||||
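Worked sketch (reviewer note, not part of the committed file): the expected call pattern for GradientNoiseScale. The numbers are made up; under DDP the two squared norms would come from DDPGradientStatsHook.get_stats().

    gns = GradientNoiseScale(beta=0.9998)

    # Suppose the mean squared 2-norm of per-microbatch gradients (microbatch size 8)
    # is 2.5e-3 and the squared 2-norm of the averaged gradient (total batch 64) is 1.1e-3.
    scale = gns.update(sq_norm_small_batch=2.5e-3, sq_norm_large_batch=1.1e-3,
                       n_small_batch=8, n_large_batch=64)
    sq_mean, variance = gns.get_stats()   # debiased EMA estimates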
comfy/k_diffusion/layers.py (new file, 246 lines)
@@ -0,0 +1,246 @@
import math

from einops import rearrange, repeat
import torch
from torch import nn
from torch.nn import functional as F

from . import utils

# Karras et al. preconditioned denoiser

class Denoiser(nn.Module):
    """A Karras et al. preconditioner for denoising diffusion models."""

    def __init__(self, inner_model, sigma_data=1.):
        super().__init__()
        self.inner_model = inner_model
        self.sigma_data = sigma_data

    def get_scalings(self, sigma):
        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return c_skip, c_out, c_in

    def loss(self, input, noise, sigma, **kwargs):
        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        model_output = self.inner_model(noised_input * c_in, sigma, **kwargs)
        target = (input - c_skip * noised_input) / c_out
        return (model_output - target).pow(2).flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        return self.inner_model(input * c_in, sigma, **kwargs) * c_out + input * c_skip


class DenoiserWithVariance(Denoiser):
    def loss(self, input, noise, sigma, **kwargs):
        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        model_output, logvar = self.inner_model(noised_input * c_in, sigma, return_variance=True, **kwargs)
        logvar = utils.append_dims(logvar, model_output.ndim)
        target = (input - c_skip * noised_input) / c_out
        losses = ((model_output - target) ** 2 / logvar.exp() + logvar) / 2
        return losses.flatten(1).mean(1)


# Residual blocks

class ResidualBlock(nn.Module):
    def __init__(self, *main, skip=None):
        super().__init__()
        self.main = nn.Sequential(*main)
        self.skip = skip if skip else nn.Identity()

    def forward(self, input):
        return self.main(input) + self.skip(input)


# Noise level (and other) conditioning

class ConditionedModule(nn.Module):
    pass


class UnconditionedModule(ConditionedModule):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, input, cond=None):
        return self.module(input)


class ConditionedSequential(nn.Sequential, ConditionedModule):
    def forward(self, input, cond):
        for module in self:
            if isinstance(module, ConditionedModule):
                input = module(input, cond)
            else:
                input = module(input)
        return input


class ConditionedResidualBlock(ConditionedModule):
    def __init__(self, *main, skip=None):
        super().__init__()
        self.main = ConditionedSequential(*main)
        self.skip = skip if skip else nn.Identity()

    def forward(self, input, cond):
        skip = self.skip(input, cond) if isinstance(self.skip, ConditionedModule) else self.skip(input)
        return self.main(input, cond) + skip


class AdaGN(ConditionedModule):
    def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key='cond'):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps
        self.cond_key = cond_key
        self.mapper = nn.Linear(feats_in, c_out * 2)

    def forward(self, input, cond):
        weight, bias = self.mapper(cond[self.cond_key]).chunk(2, dim=-1)
        input = F.group_norm(input, self.num_groups, eps=self.eps)
        return torch.addcmul(utils.append_dims(bias, input.ndim), input, utils.append_dims(weight, input.ndim) + 1)


# Attention

class SelfAttention2d(ConditionedModule):
    def __init__(self, c_in, n_head, norm, dropout_rate=0.):
        super().__init__()
        assert c_in % n_head == 0
        self.norm_in = norm(c_in)
        self.n_head = n_head
        self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1)
        self.out_proj = nn.Conv2d(c_in, c_in, 1)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, input, cond):
        n, c, h, w = input.shape
        qkv = self.qkv_proj(self.norm_in(input, cond))
        qkv = qkv.view([n, self.n_head * 3, c // self.n_head, h * w]).transpose(2, 3)
        q, k, v = qkv.chunk(3, dim=1)
        scale = k.shape[3] ** -0.25
        att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
        att = self.dropout(att)
        y = (att @ v).transpose(2, 3).contiguous().view([n, c, h, w])
        return input + self.out_proj(y)


class CrossAttention2d(ConditionedModule):
    def __init__(self, c_dec, c_enc, n_head, norm_dec, dropout_rate=0.,
                 cond_key='cross', cond_key_padding='cross_padding'):
        super().__init__()
        assert c_dec % n_head == 0
        self.cond_key = cond_key
        self.cond_key_padding = cond_key_padding
        self.norm_enc = nn.LayerNorm(c_enc)
        self.norm_dec = norm_dec(c_dec)
        self.n_head = n_head
        self.q_proj = nn.Conv2d(c_dec, c_dec, 1)
        self.kv_proj = nn.Linear(c_enc, c_dec * 2)
        self.out_proj = nn.Conv2d(c_dec, c_dec, 1)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, input, cond):
        n, c, h, w = input.shape
        q = self.q_proj(self.norm_dec(input, cond))
        q = q.view([n, self.n_head, c // self.n_head, h * w]).transpose(2, 3)
        kv = self.kv_proj(self.norm_enc(cond[self.cond_key]))
        kv = kv.view([n, -1, self.n_head * 2, c // self.n_head]).transpose(1, 2)
        k, v = kv.chunk(2, dim=1)
        scale = k.shape[3] ** -0.25
        att = ((q * scale) @ (k.transpose(2, 3) * scale))
        att = att - (cond[self.cond_key_padding][:, None, None, :]) * 10000
        att = att.softmax(3)
        att = self.dropout(att)
        y = (att @ v).transpose(2, 3)
        y = y.contiguous().view([n, c, h, w])
        return input + self.out_proj(y)


# Downsampling/upsampling

_kernels = {
    'linear':
        [1 / 8, 3 / 8, 3 / 8, 1 / 8],
    'cubic':
        [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
         0.43359375, 0.11328125, -0.03515625, -0.01171875],
    'lanczos3':
        [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
         -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
         0.44638532400131226, 0.13550527393817902, -0.066637322306633,
         -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
}
_kernels['bilinear'] = _kernels['linear']
_kernels['bicubic'] = _kernels['cubic']


class Downsample2d(nn.Module):
    def __init__(self, kernel='linear', pad_mode='reflect'):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([_kernels[kernel]])
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer('kernel', kernel_1d.T @ kernel_1d)

    def forward(self, x):
        x = F.pad(x, (self.pad,) * 4, self.pad_mode)
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(x.shape[1], device=x.device)
        weight[indices, indices] = self.kernel.to(weight)
        return F.conv2d(x, weight, stride=2)


class Upsample2d(nn.Module):
    def __init__(self, kernel='linear', pad_mode='reflect'):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([_kernels[kernel]]) * 2
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer('kernel', kernel_1d.T @ kernel_1d)

    def forward(self, x):
        x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(x.shape[1], device=x.device)
        weight[indices, indices] = self.kernel.to(weight)
        return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)


# Embeddings

class FourierFeatures(nn.Module):
    def __init__(self, in_features, out_features, std=1.):
        super().__init__()
        assert out_features % 2 == 0
        self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std)

    def forward(self, input):
        f = 2 * math.pi * input @ self.weight.T
        return torch.cat([f.cos(), f.sin()], dim=-1)


# U-Nets

class UNet(ConditionedModule):
    def __init__(self, d_blocks, u_blocks, skip_stages=0):
        super().__init__()
        self.d_blocks = nn.ModuleList(d_blocks)
        self.u_blocks = nn.ModuleList(u_blocks)
        self.skip_stages = skip_stages

    def forward(self, input, cond):
        skips = []
        for block in self.d_blocks[self.skip_stages:]:
            input = block(input, cond)
            skips.append(input)
        for i, (block, skip) in enumerate(zip(self.u_blocks, reversed(skips))):
            input = block(input, cond, skip if i > 0 else None)
        return input
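Usage sketch (reviewer note, not part of the committed file): the Denoiser above implements the Karras et al. preconditioning D(x, sigma) = c_skip * x + c_out * F(c_in * x, sigma), with the scalings computed in get_scalings. The inner model here is a trivial stand-in just to show the calling convention.

    import torch
    from torch import nn

    class ToyInnerModel(nn.Module):
        """Stand-in for F(c_in * x, sigma); purely illustrative."""
        def forward(self, x, sigma):
            return torch.zeros_like(x)

    denoiser = Denoiser(ToyInnerModel(), sigma_data=0.5)
    sigma = torch.tensor([1.0, 10.0])
    c_skip, c_out, c_in = denoiser.get_scalings(sigma)
    # At large sigma, c_skip -> 0 and the output comes mostly from the network;
    # at small sigma, c_skip -> 1 and the denoiser mostly passes the input through.

    x = torch.randn([2, 3, 8, 8])
    noise = torch.randn_like(x)
    per_sample_loss = denoiser.loss(x, noise, sigma)   # preconditioned MSE, shape [2]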
comfy/k_diffusion/models/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .image_v1 import ImageDenoiserModelV1
comfy/k_diffusion/models/image_v1.py (new file, 156 lines)
@@ -0,0 +1,156 @@
import math

import torch
from torch import nn
from torch.nn import functional as F

from .. import layers, utils


def orthogonal_(module):
    nn.init.orthogonal_(module.weight)
    return module


class ResConvBlock(layers.ConditionedResidualBlock):
    def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.):
        skip = None if c_in == c_out else orthogonal_(nn.Conv2d(c_in, c_out, 1, bias=False))
        super().__init__(
            layers.AdaGN(feats_in, c_in, max(1, c_in // group_size)),
            nn.GELU(),
            nn.Conv2d(c_in, c_mid, 3, padding=1),
            nn.Dropout2d(dropout_rate, inplace=True),
            layers.AdaGN(feats_in, c_mid, max(1, c_mid // group_size)),
            nn.GELU(),
            nn.Conv2d(c_mid, c_out, 3, padding=1),
            nn.Dropout2d(dropout_rate, inplace=True),
            skip=skip)


class DBlock(layers.ConditionedSequential):
    def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., downsample=False, self_attn=False, cross_attn=False, c_enc=0):
        modules = [nn.Identity()]
        for i in range(n_layers):
            my_c_in = c_in if i == 0 else c_mid
            my_c_out = c_mid if i < n_layers - 1 else c_out
            modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate))
            if self_attn:
                norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
                modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
            if cross_attn:
                norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
                modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate))
        super().__init__(*modules)
        self.set_downsample(downsample)

    def set_downsample(self, downsample):
        self[0] = layers.Downsample2d() if downsample else nn.Identity()
        return self


class UBlock(layers.ConditionedSequential):
    def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., upsample=False, self_attn=False, cross_attn=False, c_enc=0):
        modules = []
        for i in range(n_layers):
            my_c_in = c_in if i == 0 else c_mid
            my_c_out = c_mid if i < n_layers - 1 else c_out
            modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate))
            if self_attn:
                norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
                modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
            if cross_attn:
                norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
                modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate))
        modules.append(nn.Identity())
        super().__init__(*modules)
        self.set_upsample(upsample)

    def forward(self, input, cond, skip=None):
        if skip is not None:
            input = torch.cat([input, skip], dim=1)
        return super().forward(input, cond)

    def set_upsample(self, upsample):
        self[-1] = layers.Upsample2d() if upsample else nn.Identity()
        return self


class MappingNet(nn.Sequential):
    def __init__(self, feats_in, feats_out, n_layers=2):
        layers = []
        for i in range(n_layers):
            layers.append(orthogonal_(nn.Linear(feats_in if i == 0 else feats_out, feats_out)))
            layers.append(nn.GELU())
        super().__init__(*layers)


class ImageDenoiserModelV1(nn.Module):
    def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, cross_attn_depths=None, mapping_cond_dim=0, unet_cond_dim=0, cross_cond_dim=0, dropout_rate=0., patch_size=1, skip_stages=0, has_variance=False):
        super().__init__()
        self.c_in = c_in
        self.channels = channels
        self.unet_cond_dim = unet_cond_dim
        self.patch_size = patch_size
        self.has_variance = has_variance
        self.timestep_embed = layers.FourierFeatures(1, feats_in)
        if mapping_cond_dim > 0:
            self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False)
        self.mapping = MappingNet(feats_in, feats_in)
        self.proj_in = nn.Conv2d((c_in + unet_cond_dim) * self.patch_size ** 2, channels[max(0, skip_stages - 1)], 1)
        self.proj_out = nn.Conv2d(channels[max(0, skip_stages - 1)], c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1)
        nn.init.zeros_(self.proj_out.weight)
        nn.init.zeros_(self.proj_out.bias)
        if cross_cond_dim == 0:
            cross_attn_depths = [False] * len(self_attn_depths)
        d_blocks, u_blocks = [], []
        for i in range(len(depths)):
            my_c_in = channels[max(0, i - 1)]
            d_blocks.append(DBlock(depths[i], feats_in, my_c_in, channels[i], channels[i], downsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate))
        for i in range(len(depths)):
            my_c_in = channels[i] * 2 if i < len(depths) - 1 else channels[i]
            my_c_out = channels[max(0, i - 1)]
            u_blocks.append(UBlock(depths[i], feats_in, my_c_in, channels[i], my_c_out, upsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate))
        self.u_net = layers.UNet(d_blocks, reversed(u_blocks), skip_stages=skip_stages)

    def forward(self, input, sigma, mapping_cond=None, unet_cond=None, cross_cond=None, cross_cond_padding=None, return_variance=False):
        c_noise = sigma.log() / 4
        timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2))
        mapping_cond_embed = torch.zeros_like(timestep_embed) if mapping_cond is None else self.mapping_cond(mapping_cond)
        mapping_out = self.mapping(timestep_embed + mapping_cond_embed)
        cond = {'cond': mapping_out}
        if unet_cond is not None:
            input = torch.cat([input, unet_cond], dim=1)
        if cross_cond is not None:
            cond['cross'] = cross_cond
            cond['cross_padding'] = cross_cond_padding
        if self.patch_size > 1:
            input = F.pixel_unshuffle(input, self.patch_size)
        input = self.proj_in(input)
        input = self.u_net(input, cond)
        input = self.proj_out(input)
        if self.has_variance:
            input, logvar = input[:, :-1], input[:, -1].flatten(1).mean(1)
        if self.patch_size > 1:
            input = F.pixel_shuffle(input, self.patch_size)
        if self.has_variance and return_variance:
            return input, logvar
        return input

    def set_skip_stages(self, skip_stages):
        self.proj_in = nn.Conv2d(self.proj_in.in_channels, self.channels[max(0, skip_stages - 1)], 1)
        self.proj_out = nn.Conv2d(self.channels[max(0, skip_stages - 1)], self.proj_out.out_channels, 1)
        nn.init.zeros_(self.proj_out.weight)
        nn.init.zeros_(self.proj_out.bias)
        self.u_net.skip_stages = skip_stages
        for i, block in enumerate(self.u_net.d_blocks):
            block.set_downsample(i > skip_stages)
        for i, block in enumerate(reversed(self.u_net.u_blocks)):
            block.set_upsample(i > skip_stages)
        return self

    def set_patch_size(self, patch_size):
        self.patch_size = patch_size
        self.proj_in = nn.Conv2d((self.c_in + self.unet_cond_dim) * self.patch_size ** 2, self.channels[max(0, self.u_net.skip_stages - 1)], 1)
        self.proj_out = nn.Conv2d(self.channels[max(0, self.u_net.skip_stages - 1)], self.c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1)
        nn.init.zeros_(self.proj_out.weight)
        nn.init.zeros_(self.proj_out.bias)
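Constructor sketch (reviewer note, not part of the committed file): building a small ImageDenoiserModelV1 and running a forward pass. The sizes are arbitrary and chosen only so the shapes work out; in practice the model is built from a config via config.make_model and wrapped in layers.Denoiser for preconditioning.

    import torch

    model = ImageDenoiserModelV1(
        c_in=3,                       # RGB input
        feats_in=128,                 # mapping network width
        depths=[2, 2],                # ResConvBlocks per stage
        channels=[64, 128],
        self_attn_depths=[False, True],
    )

    x = torch.randn([2, 3, 32, 32])
    sigma = torch.full([2], 1.0)
    out = model(x, sigma)             # same shape as x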
comfy/k_diffusion/sampling.py (new file, 607 lines)
@@ -0,0 +1,607 @@
import math

from scipy import integrate
import torch
from torch import nn
from torchdiffeq import odeint
import torchsde
from tqdm.auto import trange, tqdm

from . import utils


def append_zero(x):
    return torch.cat([x, x.new_zeros([1])])


def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
    """Constructs the noise schedule of Karras et al. (2022)."""
    ramp = torch.linspace(0, 1, n, device=device)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return append_zero(sigmas).to(device)


def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
    """Constructs an exponential noise schedule."""
    sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp()
    return append_zero(sigmas)


def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
    """Constructs a polynomial in log sigma noise schedule."""
    ramp = torch.linspace(1, 0, n, device=device) ** rho
    sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))
    return append_zero(sigmas)


def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
    """Constructs a continuous VP noise schedule."""
    t = torch.linspace(1, eps_s, n, device=device)
    sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
    return append_zero(sigmas)


def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative."""
    return (x - denoised) / utils.append_dims(sigma, x.ndim)


def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step."""
    if not eta:
        return sigma_to, 0.
    sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up


def default_noise_sampler(x):
    return lambda sigma, sigma_next: torch.randn_like(x)


class BatchedBrownianTree:
    """A wrapper around torchsde.BrownianTree that enables batches of entropy."""

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        t0, t1, self.sign = self.sort(t0, t1)
        w0 = kwargs.get('w0', torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2 ** 63 - 1, []).item()
        self.batched = True
        try:
            assert len(seed) == x.shape[0]
            w0 = w0[0]
        except TypeError:
            seed = [seed]
            self.batched = False
        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]

    @staticmethod
    def sort(a, b):
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        t0, t1, sign = self.sort(t0, t1)
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
        return w if self.batched else w[0]


class BrownianTreeNoiseSampler:
    """A noise sampler backed by a torchsde.BrownianTree.

    Args:
        x (Tensor): The tensor whose shape, device and dtype to use to generate
            random samples.
        sigma_min (float): The low end of the valid interval.
        sigma_max (float): The high end of the valid interval.
        seed (int or List[int]): The random seed. If a list of seeds is
            supplied instead of a single integer, then the noise sampler will
            use one BrownianTree per batch item, each with its own seed.
        transform (callable): A function that maps sigma to the sampler's
            internal timestep.
    """

    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
        self.transform = transform
        t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, t0, t1, seed)

    def __call__(self, sigma, sigma_next):
        t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()


@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        # Euler method
        x = x + d * dt
    return x
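Usage sketch (reviewer note, not part of the committed file): all of the samplers in this file share the same calling convention, model(x, sigma, **extra_args) -> denoised x, where the model is a preconditioned denoiser such as layers.Denoiser or one of the external.py wrappers. The denoiser and the conditioning dict below are assumed to exist.

    import torch

    sigmas = get_sigmas_karras(n=50, sigma_min=1e-2, sigma_max=80.0, rho=7.)
    x = torch.randn([4, 3, 64, 64]) * sigmas[0]           # start from noise at sigma_max
    # samples = sample_euler(denoiser, x, sigmas)
    # samples = sample_dpm_2_ancestral(denoiser, x, sigmas, eta=1., extra_args=extra_args)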
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
if gamma > 0:
|
||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
if sigmas[i + 1] == 0:
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
else:
|
||||
# Heun's method
|
||||
x_2 = x + d * dt
|
||||
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
|
||||
d_prime = (d + d_2) / 2
|
||||
x = x + d_prime * dt
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
if gamma > 0:
|
||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
# Euler method
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
x = x + d * dt
|
||||
else:
|
||||
# DPM-Solver-2
|
||||
sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
|
||||
dt_1 = sigma_mid - sigma_hat
|
||||
dt_2 = sigmas[i + 1] - sigma_hat
|
||||
x_2 = x + d * dt_1
|
||||
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, sigma_mid, denoised_2)
|
||||
x = x + d_2 * dt_2
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
if sigma_down == 0:
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
else:
|
||||
# DPM-Solver-2
|
||||
sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
|
||||
dt_1 = sigma_mid - sigmas[i]
|
||||
dt_2 = sigma_down - sigmas[i]
|
||||
x_2 = x + d * dt_1
|
||||
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, sigma_mid, denoised_2)
|
||||
x = x + d_2 * dt_2
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
|
||||
def linear_multistep_coeff(order, t, i, j):
|
||||
if order - 1 > i:
|
||||
raise ValueError(f'Order {order} too high for step {i}')
|
||||
def fn(tau):
|
||||
prod = 1.
|
||||
for k in range(order):
|
||||
if j == k:
|
||||
continue
|
||||
prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
|
||||
return prod
|
||||
return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigmas_cpu = sigmas.detach().cpu().numpy()
|
||||
ds = []
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
ds.append(d)
|
||||
if len(ds) > order:
|
||||
ds.pop(0)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
cur_order = min(i + 1, order)
|
||||
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
|
||||
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
v = torch.randint_like(x, 2) * 2 - 1
|
||||
fevals = 0
|
||||
def ode_fn(sigma, x):
|
||||
nonlocal fevals
|
||||
with torch.enable_grad():
|
||||
x = x[0].detach().requires_grad_()
|
||||
denoised = model(x, sigma * s_in, **extra_args)
|
||||
d = to_d(x, sigma, denoised)
|
||||
fevals += 1
|
||||
grad = torch.autograd.grad((d * v).sum(), x)[0]
|
||||
d_ll = (v * grad).flatten(1).sum(1)
|
||||
return d.detach(), d_ll
|
||||
x_min = x, x.new_zeros([x.shape[0]])
|
||||
t = x.new_tensor([sigma_min, sigma_max])
|
||||
sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
|
||||
latent, delta_ll = sol[0][-1], sol[1][-1]
|
||||
ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
|
||||
return ll_prior + delta_ll, {'fevals': fevals}
|
||||
|
||||
|
||||
class PIDStepSizeController:
    """A PID controller for ODE adaptive step size control."""
    def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
        self.h = h
        self.b1 = (pcoeff + icoeff + dcoeff) / order
        self.b2 = -(pcoeff + 2 * dcoeff) / order
        self.b3 = dcoeff / order
        self.accept_safety = accept_safety
        self.eps = eps
        self.errs = []

    def limiter(self, x):
        return 1 + math.atan(x - 1)

    def propose_step(self, error):
        inv_error = 1 / (float(error) + self.eps)
        if not self.errs:
            self.errs = [inv_error, inv_error, inv_error]
        self.errs[0] = inv_error
        factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3
        factor = self.limiter(factor)
        accept = factor >= self.accept_safety
        if accept:
            self.errs[2] = self.errs[1]
            self.errs[1] = self.errs[0]
        self.h *= factor
        return accept


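A minimal standalone sketch of how PIDStepSizeController behaves: feed it a local error estimate each iteration, it rescales its step size h and reports whether the step should be accepted. The error values below are invented for illustration.

pid = PIDStepSizeController(h=0.05, pcoeff=0., icoeff=1., dcoeff=0., order=2)
for err in [0.3, 0.9, 1.4, 0.7]:
    accepted = pid.propose_step(err)
    print(f'error={err:.2f} accepted={accepted} next h={pid.h:.4f}')
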
class DPMSolver(nn.Module):
    """DPM-Solver. See https://arxiv.org/abs/2206.00927."""

    def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
        super().__init__()
        self.model = model
        self.extra_args = {} if extra_args is None else extra_args
        self.eps_callback = eps_callback
        self.info_callback = info_callback

    def t(self, sigma):
        return -sigma.log()

    def sigma(self, t):
        return t.neg().exp()

    def eps(self, eps_cache, key, x, t, *args, **kwargs):
        if key in eps_cache:
            return eps_cache[key], eps_cache
        sigma = self.sigma(t) * x.new_ones([x.shape[0]])
        eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t)
        if self.eps_callback is not None:
            self.eps_callback()
        return eps, {key: eps, **eps_cache}

    def dpm_solver_1_step(self, x, t, t_next, eps_cache=None):
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        x_1 = x - self.sigma(t_next) * h.expm1() * eps
        return x_1, eps_cache

    def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None):
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        s1 = t + r1 * h
        u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
        eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
        x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps)
        return x_2, eps_cache

    def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        s1 = t + r1 * h
        s2 = t + r2 * h
        u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
        eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
        u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps)
        eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2)
        x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps)
        return x_3, eps_cache

    def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
        noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
        if not t_end > t_start and eta:
            raise ValueError('eta must be 0 for reverse sampling')

        m = math.floor(nfe / 3) + 1
        ts = torch.linspace(t_start, t_end, m + 1, device=x.device)

        if nfe % 3 == 0:
            orders = [3] * (m - 2) + [2, 1]
        else:
            orders = [3] * (m - 1) + [nfe % 3]

        for i in range(len(orders)):
            eps_cache = {}
            t, t_next = ts[i], ts[i + 1]
            if eta:
                sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
                t_next_ = torch.minimum(t_end, self.t(sd))
                su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
            else:
                t_next_, su = t_next, 0.

            eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
            denoised = x - self.sigma(t) * eps
            if self.info_callback is not None:
                self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})

            if orders[i] == 1:
                x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache)
            elif orders[i] == 2:
                x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache)
            else:
                x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)

            x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))

        return x

    def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
        noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
        if order not in {2, 3}:
            raise ValueError('order should be 2 or 3')
        forward = t_end > t_start
        if not forward and eta:
            raise ValueError('eta must be 0 for reverse sampling')
        h_init = abs(h_init) * (1 if forward else -1)
        atol = torch.tensor(atol)
        rtol = torch.tensor(rtol)
        s = t_start
        x_prev = x
        accept = True
        pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
        info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}

        while s < t_end - 1e-5 if forward else s > t_end + 1e-5:
            eps_cache = {}
            t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h)
            if eta:
                sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
                t_ = torch.minimum(t_end, self.t(sd))
                su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
            else:
                t_, su = t, 0.

            eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
            denoised = x - self.sigma(s) * eps

            if order == 2:
                x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
                x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
            else:
                x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
                x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
            delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
            error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
            accept = pid.propose_step(error)
            if accept:
                x_prev = x_low
                x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
                s = t
                info['n_accept'] += 1
            else:
                info['n_reject'] += 1
            info['nfe'] += order
            info['steps'] += 1

            if self.info_callback is not None:
                self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})

        return x, info


@torch.no_grad()
def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
    """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927."""
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')
    with tqdm(total=n, disable=disable) as pbar:
        dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
        if callback is not None:
            dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
        return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler)


@torch.no_grad()
def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
    """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927."""
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')
    with tqdm(disable=disable) as pbar:
        dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
        if callback is not None:
            dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
        x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
    if return_info:
        return x, info
    return x


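A hedged call sketch for sample_dpm_adaptive, reusing the hypothetical toy_denoiser from the sample_lms sketch above; the sigma range and tolerances are illustrative only, not recommendations.

x = torch.randn(1, 3, 64, 64) * 80.0
samples, info = sample_dpm_adaptive(toy_denoiser, x, sigma_min=0.03, sigma_max=80.0,
                                     rtol=0.05, atol=0.0078, return_info=True)
print(info['steps'], info['nfe'], info['n_accept'], info['n_reject'])
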
@torch.no_grad()
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigma_down == 0:
            # Euler method
            d = to_d(x, sigmas[i], denoised)
            dt = sigma_down - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver++(2S)
            t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
            r = 1 / 2
            h = t_next - t
            s = t + r * h
            x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
        # Noise addition
        if sigmas[i + 1] > 0:
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x


@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
    """DPM-Solver++ (stochastic)."""
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Euler method
            d = to_d(x, sigmas[i], denoised)
            dt = sigmas[i + 1] - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver++
            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
            h = t_next - t
            s = t + h * r
            fac = 1 / (2 * r)

            # Step 1
            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
            s_ = t_fn(sd)
            x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
            x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)

            # Step 2
            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
            t_next_ = t_fn(sd)
            denoised_d = (1 - fac) * denoised + fac * denoised_2
            x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
            x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
    return x


@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    old_denoised = None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        if old_denoised is None or sigmas[i + 1] == 0:
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        else:
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
        old_denoised = denoised
    return x
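
DPM-Solver++(2M) takes no eta or noise_sampler arguments, so repeated calls with the same inputs give the same samples. A sketch, again assuming get_sigmas_karras from earlier in this module and the hypothetical toy_denoiser from above:

sigmas = get_sigmas_karras(n=20, sigma_min=0.03, sigma_max=80.0)
x = torch.randn(1, 3, 64, 64) * sigmas[0]
samples = sample_dpmpp_2m(toy_denoiser, x, sigmas)
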
332
comfy/k_diffusion/utils.py
Normal file
@@ -0,0 +1,332 @@
from contextlib import contextmanager
import hashlib
import math
from pathlib import Path
import shutil
import urllib.request
import warnings

from PIL import Image
import torch
from torch import nn, optim
from torch.utils import data
from torchvision.transforms import functional as TF


def from_pil_image(x):
    """Converts from a PIL image to a tensor."""
    x = TF.to_tensor(x)
    if x.ndim == 2:
        x = x[..., None]
    return x * 2 - 1


def to_pil_image(x):
    """Converts from a tensor to a PIL image."""
    if x.ndim == 4:
        assert x.shape[0] == 1
        x = x[0]
    if x.shape[0] == 1:
        x = x[0]
    return TF.to_pil_image((x.clamp(-1, 1) + 1) / 2)


def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
    """Apply passed in transforms for HuggingFace Datasets."""
    images = [transform(image.convert(mode)) for image in examples[image_key]]
    return {image_key: images}


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    expanded = x[(...,) + (None,) * dims_to_append]
    # MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
    # https://github.com/pytorch/pytorch/issues/84364
    return expanded.detach().clone() if expanded.device.type == 'mps' else expanded


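A small illustration of append_dims: broadcasting a per-sample sigma vector against a batch of images.

sigma = torch.tensor([1.0, 2.0])         # shape [2]
x = torch.randn(2, 3, 8, 8)              # shape [2, 3, 8, 8]
scaled = x * append_dims(sigma, x.ndim)  # sigma is viewed as shape [2, 1, 1, 1]
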
def n_params(module):
    """Returns the number of parameters in a module."""
    return sum(p.numel() for p in module.parameters())


def download_file(path, url, digest=None):
    """Downloads a file if it does not exist, optionally checking its SHA-256 hash."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    if not path.exists():
        with urllib.request.urlopen(url) as response, open(path, 'wb') as f:
            shutil.copyfileobj(response, f)
    if digest is not None:
        file_digest = hashlib.sha256(open(path, 'rb').read()).hexdigest()
        if digest != file_digest:
            raise OSError(f'hash of {path} (url: {url}) failed to validate')
    return path


@contextmanager
def train_mode(model, mode=True):
    """A context manager that places a model into training mode and restores
    the previous mode on exit."""
    modes = [module.training for module in model.modules()]
    try:
        yield model.train(mode)
    finally:
        for i, module in enumerate(model.modules()):
            module.training = modes[i]


def eval_mode(model):
    """A context manager that places a model into evaluation mode and restores
    the previous mode on exit."""
    return train_mode(model, False)


@torch.no_grad()
def ema_update(model, averaged_model, decay):
    """Incorporates updated model parameters into an exponential moving averaged
    version of a model. It should be called after each optimizer step."""
    model_params = dict(model.named_parameters())
    averaged_params = dict(averaged_model.named_parameters())
    assert model_params.keys() == averaged_params.keys()

    for name, param in model_params.items():
        averaged_params[name].mul_(decay).add_(param, alpha=1 - decay)

    model_buffers = dict(model.named_buffers())
    averaged_buffers = dict(averaged_model.named_buffers())
    assert model_buffers.keys() == averaged_buffers.keys()

    for name, buf in model_buffers.items():
        averaged_buffers[name].copy_(buf)


class EMAWarmup:
    """Implements an EMA warmup using an inverse decay schedule.
    If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
    good values for models you plan to train for a million or more steps (reaches decay
    factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
    you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
    215.4k steps).
    Args:
        inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
        power (float): Exponential factor of EMA warmup. Default: 1.
        min_value (float): The minimum EMA decay rate. Default: 0.
        max_value (float): The maximum EMA decay rate. Default: 1.
        start_at (int): The epoch to start averaging at. Default: 0.
        last_epoch (int): The index of last epoch. Default: 0.
    """

    def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0,
                 last_epoch=0):
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value
        self.start_at = start_at
        self.last_epoch = last_epoch

    def state_dict(self):
        """Returns the state of the class as a :class:`dict`."""
        return dict(self.__dict__.items())

    def load_state_dict(self, state_dict):
        """Loads the class's state.
        Args:
            state_dict (dict): scaler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_value(self):
        """Gets the current EMA decay rate."""
        epoch = max(0, self.last_epoch - self.start_at)
        value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
        return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value))

    def step(self):
        """Updates the step count."""
        self.last_epoch += 1


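A self-contained toy sketch (a linear layer fit to the identity) showing how ema_update and EMAWarmup are intended to compose in a training loop; the model, optimizer, and loss here are placeholders, not part of this module.

model = nn.Linear(4, 4)
model_ema = nn.Linear(4, 4)
model_ema.load_state_dict(model.state_dict())
opt = optim.SGD(model.parameters(), lr=1e-2)
ema_sched = EMAWarmup(power=2 / 3, max_value=0.9999)
for step in range(10):
    x = torch.randn(16, 4)
    loss = (model(x) - x).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema_update(model, model_ema, ema_sched.get_value())  # after each optimizer step
    ema_sched.step()
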
class InverseLR(optim.lr_scheduler._LRScheduler):
    """Implements an inverse decay learning rate schedule with an optional exponential
    warmup. When last_epoch=-1, sets initial lr as lr.
    inv_gamma is the number of steps/epochs required for the learning rate to decay to
    (1 / 2)**power of its original value.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
        power (float): Exponential factor of learning rate decay. Default: 1.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
            Default: 0.
        min_lr (float): The minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,
                 last_epoch=-1, verbose=False):
        self.inv_gamma = inv_gamma
        self.power = power
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.warmup = warmup
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        warmup = 1 - self.warmup ** (self.last_epoch + 1)
        lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
        return [warmup * max(self.min_lr, base_lr * lr_mult)
                for base_lr in self.base_lrs]


class ExponentialLR(optim.lr_scheduler._LRScheduler):
    """Implements an exponential learning rate schedule with an optional exponential
    warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate
    continuously by decay (default 0.5) every num_steps steps.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        num_steps (float): The number of steps to decay the learning rate by decay in.
        decay (float): The factor by which to decay the learning rate every num_steps
            steps. Default: 0.5.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
            Default: 0.
        min_lr (float): The minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,
                 last_epoch=-1, verbose=False):
        self.num_steps = num_steps
        self.decay = decay
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.warmup = warmup
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        warmup = 1 - self.warmup ** (self.last_epoch + 1)
        lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch
        return [warmup * max(self.min_lr, base_lr * lr_mult)
                for base_lr in self.base_lrs]


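A brief usage sketch for the schedulers; the optimizer, parameter, and hyperparameters below are illustrative only. InverseLR decays the learning rate polynomially toward zero, while ExponentialLR multiplies it by decay (0.5 by default) every num_steps steps.

opt = optim.Adam([nn.Parameter(torch.zeros(1))], lr=1e-4)
sched = InverseLR(opt, inv_gamma=20000, power=1., warmup=0.99)
for step in range(10):
    opt.step()
    sched.step()
print(sched.get_last_lr())
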
def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
    """Draws samples from a lognormal distribution."""
    return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp()


def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
    """Draws samples from an optionally truncated log-logistic distribution."""
    min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64)
    max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64)
    min_cdf = min_value.log().sub(loc).div(scale).sigmoid()
    max_cdf = max_value.log().sub(loc).div(scale).sigmoid()
    u = torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf
    return u.logit().mul(scale).add(loc).exp().to(dtype)


def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
    """Draws samples from a log-uniform distribution."""
    min_value = math.log(min_value)
    max_value = math.log(max_value)
    return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp()


def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
    """Draws samples from a truncated v-diffusion training timestep distribution."""
    min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi
    max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi
    u = torch.rand(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf
    return torch.tan(u * math.pi / 2) * sigma_data


def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32):
    """Draws samples from a split lognormal distribution."""
    n = torch.randn(shape, device=device, dtype=dtype).abs()
    u = torch.rand(shape, device=device, dtype=dtype)
    n_left = n * -scale_1 + loc
    n_right = n * scale_2 + loc
    ratio = scale_1 / (scale_1 + scale_2)
    return torch.where(u < ratio, n_left, n_right).exp()


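These helpers are typically used to draw per-sample noise levels during training. An illustrative draw follows; the loc/scale and range values are made up, not recommendations from this module.

sigma_lognormal = rand_log_normal([16], loc=-1.2, scale=1.2)
sigma_loguniform = rand_log_uniform([16], min_value=1e-2, max_value=80.0)
sigma_v = rand_v_diffusion([16], sigma_data=1.0, min_value=1e-3, max_value=80.0)
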
class FolderOfImages(data.Dataset):
    """Recursively finds all images in a directory. It does not support
    classes/targets."""

    IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}

    def __init__(self, root, transform=None):
        super().__init__()
        self.root = Path(root)
        self.transform = nn.Identity() if transform is None else transform
        self.paths = sorted(path for path in self.root.rglob('*') if path.suffix.lower() in self.IMG_EXTENSIONS)

    def __repr__(self):
        return f'FolderOfImages(root="{self.root}", len: {len(self)})'

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, key):
        path = self.paths[key]
        with open(path, 'rb') as f:
            image = Image.open(f).convert('RGB')
        image = self.transform(image)
        return image,


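A usage sketch for FolderOfImages; the directory path is hypothetical, and the transform reuses from_pil_image defined above so the loader yields tensors in [-1, 1]. This assumes the images share a common size; otherwise add a resize/crop transform before from_pil_image.

dataset = FolderOfImages('/path/to/images', transform=from_pil_image)  # hypothetical path
loader = data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)
batch, = next(iter(loader))  # __getitem__ returns a 1-tuple, so unpack the single tensor
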
class CSVLogger:
    def __init__(self, filename, columns):
        self.filename = Path(filename)
        self.columns = columns
        if self.filename.exists():
            self.file = open(self.filename, 'a')
        else:
            self.file = open(self.filename, 'w')
            self.write(*self.columns)

    def write(self, *args):
        print(*args, sep=',', file=self.file, flush=True)


@contextmanager
def tf32_mode(cudnn=None, matmul=None):
    """A context manager that sets whether TF32 is allowed on cuDNN or matmul."""
    cudnn_old = torch.backends.cudnn.allow_tf32
    matmul_old = torch.backends.cuda.matmul.allow_tf32
    try:
        if cudnn is not None:
            torch.backends.cudnn.allow_tf32 = cudnn
        if matmul is not None:
            torch.backends.cuda.matmul.allow_tf32 = matmul
        yield
    finally:
        if cudnn is not None:
            torch.backends.cudnn.allow_tf32 = cudnn_old
        if matmul is not None:
            torch.backends.cuda.matmul.allow_tf32 = matmul_old
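
A usage sketch for tf32_mode; the matrix multiply inside is arbitrary and only illustrates that the TF32 flags are restored to their previous values on exit.

print(torch.backends.cuda.matmul.allow_tf32)
with tf32_mode(cudnn=True, matmul=True):
    a = torch.randn(64, 64)
    b = a @ a
print(torch.backends.cuda.matmul.allow_tf32)  # restored to its previous value
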