Initial commit.

This commit is contained in:
comfyanonymous
2023-01-03 01:53:32 -05:00
commit 220afe3310
77 changed files with 129040 additions and 0 deletions

View File

@@ -0,0 +1,105 @@
from functools import reduce
import math
import operator
import numpy as np
from skimage import transform
import torch
from torch import nn
def translate2d(tx, ty):
mat = [[1, 0, tx],
[0, 1, ty],
[0, 0, 1]]
return torch.tensor(mat, dtype=torch.float32)
def scale2d(sx, sy):
mat = [[sx, 0, 0],
[ 0, sy, 0],
[ 0, 0, 1]]
return torch.tensor(mat, dtype=torch.float32)
def rotate2d(theta):
mat = [[torch.cos(theta), torch.sin(-theta), 0],
[torch.sin(theta), torch.cos(theta), 0],
[ 0, 0, 1]]
return torch.tensor(mat, dtype=torch.float32)
class KarrasAugmentationPipeline:
def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8):
self.a_prob = a_prob
self.a_scale = a_scale
self.a_aniso = a_aniso
self.a_trans = a_trans
def __call__(self, image):
h, w = image.size
mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)]
# x-flip
a0 = torch.randint(2, []).float()
mats.append(scale2d(1 - 2 * a0, 1))
# y-flip
do = (torch.rand([]) < self.a_prob).float()
a1 = torch.randint(2, []).float() * do
mats.append(scale2d(1, 1 - 2 * a1))
# scaling
do = (torch.rand([]) < self.a_prob).float()
a2 = torch.randn([]) * do
mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2))
# rotation
do = (torch.rand([]) < self.a_prob).float()
a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do
mats.append(rotate2d(-a3))
# anisotropy
do = (torch.rand([]) < self.a_prob).float()
a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do
a5 = torch.randn([]) * do
mats.append(rotate2d(a4))
mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5))
mats.append(rotate2d(-a4))
# translation
do = (torch.rand([]) < self.a_prob).float()
a6 = torch.randn([]) * do
a7 = torch.randn([]) * do
mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7))
# form the transformation matrix and conditioning vector
mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5))
mat = reduce(operator.matmul, mats)
cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7])
# apply the transformation
image_orig = np.array(image, dtype=np.float32) / 255
if image_orig.ndim == 2:
image_orig = image_orig[..., None]
tf = transform.AffineTransform(mat.numpy())
image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True)
image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1
image = torch.as_tensor(image).movedim(2, 0) * 2 - 1
return image, image_orig, cond
class KarrasAugmentWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs):
if aug_cond is None:
aug_cond = input.new_zeros([input.shape[0], 9])
if mapping_cond is None:
mapping_cond = aug_cond
else:
mapping_cond = torch.cat([aug_cond, mapping_cond], dim=1)
return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs)
def set_skip_stages(self, skip_stages):
return self.inner_model.set_skip_stages(skip_stages)
def set_patch_size(self, patch_size):
return self.inner_model.set_patch_size(patch_size)

110
comfy/k_diffusion/config.py Normal file
View File

@@ -0,0 +1,110 @@
from functools import partial
import json
import math
import warnings
from jsonmerge import merge
from . import augmentation, layers, models, utils
def load_config(file):
defaults = {
'model': {
'sigma_data': 1.,
'patch_size': 1,
'dropout_rate': 0.,
'augment_wrapper': True,
'augment_prob': 0.,
'mapping_cond_dim': 0,
'unet_cond_dim': 0,
'cross_cond_dim': 0,
'cross_attn_depths': None,
'skip_stages': 0,
'has_variance': False,
},
'dataset': {
'type': 'imagefolder',
},
'optimizer': {
'type': 'adamw',
'lr': 1e-4,
'betas': [0.95, 0.999],
'eps': 1e-6,
'weight_decay': 1e-3,
},
'lr_sched': {
'type': 'inverse',
'inv_gamma': 20000.,
'power': 1.,
'warmup': 0.99,
},
'ema_sched': {
'type': 'inverse',
'power': 0.6667,
'max_value': 0.9999
},
}
config = json.load(file)
return merge(defaults, config)
def make_model(config):
config = config['model']
assert config['type'] == 'image_v1'
model = models.ImageDenoiserModelV1(
config['input_channels'],
config['mapping_out'],
config['depths'],
config['channels'],
config['self_attn_depths'],
config['cross_attn_depths'],
patch_size=config['patch_size'],
dropout_rate=config['dropout_rate'],
mapping_cond_dim=config['mapping_cond_dim'] + (9 if config['augment_wrapper'] else 0),
unet_cond_dim=config['unet_cond_dim'],
cross_cond_dim=config['cross_cond_dim'],
skip_stages=config['skip_stages'],
has_variance=config['has_variance'],
)
if config['augment_wrapper']:
model = augmentation.KarrasAugmentWrapper(model)
return model
def make_denoiser_wrapper(config):
config = config['model']
sigma_data = config.get('sigma_data', 1.)
has_variance = config.get('has_variance', False)
if not has_variance:
return partial(layers.Denoiser, sigma_data=sigma_data)
return partial(layers.DenoiserWithVariance, sigma_data=sigma_data)
def make_sample_density(config):
sd_config = config['sigma_sample_density']
sigma_data = config['sigma_data']
if sd_config['type'] == 'lognormal':
loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
scale = sd_config['std'] if 'std' in sd_config else sd_config['scale']
return partial(utils.rand_log_normal, loc=loc, scale=scale)
if sd_config['type'] == 'loglogistic':
loc = sd_config['loc'] if 'loc' in sd_config else math.log(sigma_data)
scale = sd_config['scale'] if 'scale' in sd_config else 0.5
min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value)
if sd_config['type'] == 'loguniform':
min_value = sd_config['min_value'] if 'min_value' in sd_config else config['sigma_min']
max_value = sd_config['max_value'] if 'max_value' in sd_config else config['sigma_max']
return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value)
if sd_config['type'] == 'v-diffusion':
min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
return partial(utils.rand_v_diffusion, sigma_data=sigma_data, min_value=min_value, max_value=max_value)
if sd_config['type'] == 'split-lognormal':
loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
scale_1 = sd_config['std_1'] if 'std_1' in sd_config else sd_config['scale_1']
scale_2 = sd_config['std_2'] if 'std_2' in sd_config else sd_config['scale_2']
return partial(utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2)
raise ValueError('Unknown sample density type')

View File

@@ -0,0 +1,134 @@
import math
import os
from pathlib import Path
from cleanfid.inception_torchscript import InceptionV3W
import clip
from resize_right import resize
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from tqdm.auto import trange
from . import utils
class InceptionV3FeatureExtractor(nn.Module):
def __init__(self, device='cpu'):
super().__init__()
path = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache')) / 'k-diffusion'
url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4'
utils.download_file(path / 'inception-2015-12-05.pt', url, digest)
self.model = InceptionV3W(str(path), resize_inside=False).to(device)
self.size = (299, 299)
def forward(self, x):
if x.shape[2:4] != self.size:
x = resize(x, out_shape=self.size, pad_mode='reflect')
if x.shape[1] == 1:
x = torch.cat([x] * 3, dim=1)
x = (x * 127.5 + 127.5).clamp(0, 255)
return self.model(x)
class CLIPFeatureExtractor(nn.Module):
def __init__(self, name='ViT-L/14@336px', device='cpu'):
super().__init__()
self.model = clip.load(name, device=device)[0].eval().requires_grad_(False)
self.normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711))
self.size = (self.model.visual.input_resolution, self.model.visual.input_resolution)
def forward(self, x):
if x.shape[2:4] != self.size:
x = resize(x.add(1).div(2), out_shape=self.size, pad_mode='reflect').clamp(0, 1)
x = self.normalize(x)
x = self.model.encode_image(x).float()
x = F.normalize(x) * x.shape[1] ** 0.5
return x
def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size):
n_per_proc = math.ceil(n / accelerator.num_processes)
feats_all = []
try:
for i in trange(0, n_per_proc, batch_size, disable=not accelerator.is_main_process):
cur_batch_size = min(n - i, batch_size)
samples = sample_fn(cur_batch_size)[:cur_batch_size]
feats_all.append(accelerator.gather(extractor_fn(samples)))
except StopIteration:
pass
return torch.cat(feats_all)[:n]
def polynomial_kernel(x, y):
d = x.shape[-1]
dot = x @ y.transpose(-2, -1)
return (dot / d + 1) ** 3
def squared_mmd(x, y, kernel=polynomial_kernel):
m = x.shape[-2]
n = y.shape[-2]
kxx = kernel(x, x)
kyy = kernel(y, y)
kxy = kernel(x, y)
kxx_sum = kxx.sum([-1, -2]) - kxx.diagonal(dim1=-1, dim2=-2).sum(-1)
kyy_sum = kyy.sum([-1, -2]) - kyy.diagonal(dim1=-1, dim2=-2).sum(-1)
kxy_sum = kxy.sum([-1, -2])
term_1 = kxx_sum / m / (m - 1)
term_2 = kyy_sum / n / (n - 1)
term_3 = kxy_sum * 2 / m / n
return term_1 + term_2 - term_3
@utils.tf32_mode(matmul=False)
def kid(x, y, max_size=5000):
x_size, y_size = x.shape[0], y.shape[0]
n_partitions = math.ceil(max(x_size / max_size, y_size / max_size))
total_mmd = x.new_zeros([])
for i in range(n_partitions):
cur_x = x[round(i * x_size / n_partitions):round((i + 1) * x_size / n_partitions)]
cur_y = y[round(i * y_size / n_partitions):round((i + 1) * y_size / n_partitions)]
total_mmd = total_mmd + squared_mmd(cur_x, cur_y)
return total_mmd / n_partitions
class _MatrixSquareRootEig(torch.autograd.Function):
@staticmethod
def forward(ctx, a):
vals, vecs = torch.linalg.eigh(a)
ctx.save_for_backward(vals, vecs)
return vecs @ vals.abs().sqrt().diag_embed() @ vecs.transpose(-2, -1)
@staticmethod
def backward(ctx, grad_output):
vals, vecs = ctx.saved_tensors
d = vals.abs().sqrt().unsqueeze(-1).repeat_interleave(vals.shape[-1], -1)
vecs_t = vecs.transpose(-2, -1)
return vecs @ (vecs_t @ grad_output @ vecs / (d + d.transpose(-2, -1))) @ vecs_t
def sqrtm_eig(a):
if a.ndim < 2:
raise RuntimeError('tensor of matrices must have at least 2 dimensions')
if a.shape[-2] != a.shape[-1]:
raise RuntimeError('tensor must be batches of square matrices')
return _MatrixSquareRootEig.apply(a)
@utils.tf32_mode(matmul=False)
def fid(x, y, eps=1e-8):
x_mean = x.mean(dim=0)
y_mean = y.mean(dim=0)
mean_term = (x_mean - y_mean).pow(2).sum()
x_cov = torch.cov(x.T)
y_cov = torch.cov(y.T)
eps_eye = torch.eye(x_cov.shape[0], device=x_cov.device, dtype=x_cov.dtype) * eps
x_cov = x_cov + eps_eye
y_cov = y_cov + eps_eye
x_cov_sqrt = sqrtm_eig(x_cov)
cov_term = torch.trace(x_cov + y_cov - 2 * sqrtm_eig(x_cov_sqrt @ y_cov @ x_cov_sqrt))
return mean_term + cov_term

View File

@@ -0,0 +1,179 @@
import math
import torch
from torch import nn
from . import sampling, utils
class VDenoiser(nn.Module):
"""A v-diffusion-pytorch model wrapper for k-diffusion."""
def __init__(self, inner_model):
super().__init__()
self.inner_model = inner_model
self.sigma_data = 1.
def get_scalings(self, sigma):
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_skip, c_out, c_in
def sigma_to_t(self, sigma):
return sigma.atan() / math.pi * 2
def t_to_sigma(self, t):
return (t * math.pi / 2).tan()
def loss(self, input, noise, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
target = (input - c_skip * noised_input) / c_out
return (model_output - target).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
class DiscreteSchedule(nn.Module):
"""A mapping between continuous noise levels (sigmas) and a list of discrete noise
levels."""
def __init__(self, sigmas, quantize):
super().__init__()
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())
self.quantize = quantize
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def get_sigmas(self, n=None):
if n is None:
return sampling.append_zero(self.sigmas.flip(0))
t_max = len(self.sigmas) - 1
t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
return sampling.append_zero(self.t_to_sigma(t))
def sigma_to_t(self, sigma, quantize=None):
quantize = self.quantize if quantize is None else quantize
log_sigma = sigma.log()
dists = log_sigma - self.log_sigmas[:, None]
if quantize:
return dists.abs().argmin(dim=0).view(sigma.shape)
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
w = (low - log_sigma) / (low - high)
w = w.clamp(0, 1)
t = (1 - w) * low_idx + w * high_idx
return t.view(sigma.shape)
def t_to_sigma(self, t):
t = t.float()
low_idx = t.floor().long()
high_idx = t.ceil().long()
w = t-low_idx if t.device.type == 'mps' else t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp()
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
noise)."""
def __init__(self, model, alphas_cumprod, quantize):
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
self.inner_model = model
self.sigma_data = 1.
def get_scalings(self, sigma):
c_out = -sigma
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_out, c_in
def get_eps(self, *args, **kwargs):
return self.inner_model(*args, **kwargs)
def loss(self, input, noise, sigma, **kwargs):
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
return (eps - noise).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
return input + eps * c_out
class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
"""A wrapper for OpenAI diffusion models."""
def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32)
super().__init__(model, alphas_cumprod, quantize=quantize)
self.has_learned_sigmas = has_learned_sigmas
def get_eps(self, *args, **kwargs):
model_output = self.inner_model(*args, **kwargs)
if self.has_learned_sigmas:
return model_output.chunk(2, dim=1)[0]
return model_output
class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
"""A wrapper for CompVis diffusion models."""
def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_eps(self, *args, **kwargs):
return self.inner_model.apply_model(*args, **kwargs)
class DiscreteVDDPMDenoiser(DiscreteSchedule):
"""A wrapper for discrete schedule DDPM models that output v."""
def __init__(self, model, alphas_cumprod, quantize):
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
self.inner_model = model
self.sigma_data = 1.
def get_scalings(self, sigma):
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_skip, c_out, c_in
def get_v(self, *args, **kwargs):
return self.inner_model(*args, **kwargs)
def loss(self, input, noise, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
target = (input - c_skip * noised_input) / c_out
return (model_output - target).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
class CompVisVDenoiser(DiscreteVDDPMDenoiser):
"""A wrapper for CompVis diffusion models that output v."""
def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_v(self, x, t, cond, **kwargs):
return self.inner_model.apply_model(x, t, cond)

99
comfy/k_diffusion/gns.py Normal file
View File

@@ -0,0 +1,99 @@
import torch
from torch import nn
class DDPGradientStatsHook:
def __init__(self, ddp_module):
try:
ddp_module.register_comm_hook(self, self._hook_fn)
except AttributeError:
raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules')
self._clear_state()
def _clear_state(self):
self.bucket_sq_norms_small_batch = []
self.bucket_sq_norms_large_batch = []
@staticmethod
def _hook_fn(self, bucket):
buf = bucket.buffer()
self.bucket_sq_norms_small_batch.append(buf.pow(2).sum())
fut = torch.distributed.all_reduce(buf, op=torch.distributed.ReduceOp.AVG, async_op=True).get_future()
def callback(fut):
buf = fut.value()[0]
self.bucket_sq_norms_large_batch.append(buf.pow(2).sum())
return buf
return fut.then(callback)
def get_stats(self):
sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch)
sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch)
self._clear_state()
stats = torch.stack([sq_norm_small_batch, sq_norm_large_batch])
torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.AVG)
return stats[0].item(), stats[1].item()
class GradientNoiseScale:
"""Calculates the gradient noise scale (1 / SNR), or critical batch size,
from _An Empirical Model of Large-Batch Training_,
https://arxiv.org/abs/1812.06162).
Args:
beta (float): The decay factor for the exponential moving averages used to
calculate the gradient noise scale.
Default: 0.9998
eps (float): Added for numerical stability.
Default: 1e-8
"""
def __init__(self, beta=0.9998, eps=1e-8):
self.beta = beta
self.eps = eps
self.ema_sq_norm = 0.
self.ema_var = 0.
self.beta_cumprod = 1.
self.gradient_noise_scale = float('nan')
def state_dict(self):
"""Returns the state of the object as a :class:`dict`."""
return dict(self.__dict__.items())
def load_state_dict(self, state_dict):
"""Loads the object's state.
Args:
state_dict (dict): object state. Should be an object returned
from a call to :meth:`state_dict`.
"""
self.__dict__.update(state_dict)
def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch):
"""Updates the state with a new batch's gradient statistics, and returns the
current gradient noise scale.
Args:
sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or
per sample gradients.
sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or
per sample gradients.
n_small_batch (int): The batch size of the individual microbatch or per sample
gradients (1 if per sample).
n_large_batch (int): The total batch size of the mean of the microbatch or
per sample gradients.
"""
est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch)
est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch)
self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm
self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var
self.beta_cumprod *= self.beta
self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps)
return self.gradient_noise_scale
def get_gns(self):
"""Returns the current gradient noise scale."""
return self.gradient_noise_scale
def get_stats(self):
"""Returns the current (debiased) estimates of the squared mean gradient
and gradient variance."""
return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod)

246
comfy/k_diffusion/layers.py Normal file
View File

@@ -0,0 +1,246 @@
import math
from einops import rearrange, repeat
import torch
from torch import nn
from torch.nn import functional as F
from . import utils
# Karras et al. preconditioned denoiser
class Denoiser(nn.Module):
"""A Karras et al. preconditioner for denoising diffusion models."""
def __init__(self, inner_model, sigma_data=1.):
super().__init__()
self.inner_model = inner_model
self.sigma_data = sigma_data
def get_scalings(self, sigma):
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_skip, c_out, c_in
def loss(self, input, noise, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output = self.inner_model(noised_input * c_in, sigma, **kwargs)
target = (input - c_skip * noised_input) / c_out
return (model_output - target).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
return self.inner_model(input * c_in, sigma, **kwargs) * c_out + input * c_skip
class DenoiserWithVariance(Denoiser):
def loss(self, input, noise, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output, logvar = self.inner_model(noised_input * c_in, sigma, return_variance=True, **kwargs)
logvar = utils.append_dims(logvar, model_output.ndim)
target = (input - c_skip * noised_input) / c_out
losses = ((model_output - target) ** 2 / logvar.exp() + logvar) / 2
return losses.flatten(1).mean(1)
# Residual blocks
class ResidualBlock(nn.Module):
def __init__(self, *main, skip=None):
super().__init__()
self.main = nn.Sequential(*main)
self.skip = skip if skip else nn.Identity()
def forward(self, input):
return self.main(input) + self.skip(input)
# Noise level (and other) conditioning
class ConditionedModule(nn.Module):
pass
class UnconditionedModule(ConditionedModule):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, input, cond=None):
return self.module(input)
class ConditionedSequential(nn.Sequential, ConditionedModule):
def forward(self, input, cond):
for module in self:
if isinstance(module, ConditionedModule):
input = module(input, cond)
else:
input = module(input)
return input
class ConditionedResidualBlock(ConditionedModule):
def __init__(self, *main, skip=None):
super().__init__()
self.main = ConditionedSequential(*main)
self.skip = skip if skip else nn.Identity()
def forward(self, input, cond):
skip = self.skip(input, cond) if isinstance(self.skip, ConditionedModule) else self.skip(input)
return self.main(input, cond) + skip
class AdaGN(ConditionedModule):
def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key='cond'):
super().__init__()
self.num_groups = num_groups
self.eps = eps
self.cond_key = cond_key
self.mapper = nn.Linear(feats_in, c_out * 2)
def forward(self, input, cond):
weight, bias = self.mapper(cond[self.cond_key]).chunk(2, dim=-1)
input = F.group_norm(input, self.num_groups, eps=self.eps)
return torch.addcmul(utils.append_dims(bias, input.ndim), input, utils.append_dims(weight, input.ndim) + 1)
# Attention
class SelfAttention2d(ConditionedModule):
def __init__(self, c_in, n_head, norm, dropout_rate=0.):
super().__init__()
assert c_in % n_head == 0
self.norm_in = norm(c_in)
self.n_head = n_head
self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1)
self.out_proj = nn.Conv2d(c_in, c_in, 1)
self.dropout = nn.Dropout(dropout_rate)
def forward(self, input, cond):
n, c, h, w = input.shape
qkv = self.qkv_proj(self.norm_in(input, cond))
qkv = qkv.view([n, self.n_head * 3, c // self.n_head, h * w]).transpose(2, 3)
q, k, v = qkv.chunk(3, dim=1)
scale = k.shape[3] ** -0.25
att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
att = self.dropout(att)
y = (att @ v).transpose(2, 3).contiguous().view([n, c, h, w])
return input + self.out_proj(y)
class CrossAttention2d(ConditionedModule):
def __init__(self, c_dec, c_enc, n_head, norm_dec, dropout_rate=0.,
cond_key='cross', cond_key_padding='cross_padding'):
super().__init__()
assert c_dec % n_head == 0
self.cond_key = cond_key
self.cond_key_padding = cond_key_padding
self.norm_enc = nn.LayerNorm(c_enc)
self.norm_dec = norm_dec(c_dec)
self.n_head = n_head
self.q_proj = nn.Conv2d(c_dec, c_dec, 1)
self.kv_proj = nn.Linear(c_enc, c_dec * 2)
self.out_proj = nn.Conv2d(c_dec, c_dec, 1)
self.dropout = nn.Dropout(dropout_rate)
def forward(self, input, cond):
n, c, h, w = input.shape
q = self.q_proj(self.norm_dec(input, cond))
q = q.view([n, self.n_head, c // self.n_head, h * w]).transpose(2, 3)
kv = self.kv_proj(self.norm_enc(cond[self.cond_key]))
kv = kv.view([n, -1, self.n_head * 2, c // self.n_head]).transpose(1, 2)
k, v = kv.chunk(2, dim=1)
scale = k.shape[3] ** -0.25
att = ((q * scale) @ (k.transpose(2, 3) * scale))
att = att - (cond[self.cond_key_padding][:, None, None, :]) * 10000
att = att.softmax(3)
att = self.dropout(att)
y = (att @ v).transpose(2, 3)
y = y.contiguous().view([n, c, h, w])
return input + self.out_proj(y)
# Downsampling/upsampling
_kernels = {
'linear':
[1 / 8, 3 / 8, 3 / 8, 1 / 8],
'cubic':
[-0.01171875, -0.03515625, 0.11328125, 0.43359375,
0.43359375, 0.11328125, -0.03515625, -0.01171875],
'lanczos3':
[0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
-0.066637322306633, 0.13550527393817902, 0.44638532400131226,
0.44638532400131226, 0.13550527393817902, -0.066637322306633,
-0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
}
_kernels['bilinear'] = _kernels['linear']
_kernels['bicubic'] = _kernels['cubic']
class Downsample2d(nn.Module):
def __init__(self, kernel='linear', pad_mode='reflect'):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([_kernels[kernel]])
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer('kernel', kernel_1d.T @ kernel_1d)
def forward(self, x):
x = F.pad(x, (self.pad,) * 4, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
return F.conv2d(x, weight, stride=2)
class Upsample2d(nn.Module):
def __init__(self, kernel='linear', pad_mode='reflect'):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([_kernels[kernel]]) * 2
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer('kernel', kernel_1d.T @ kernel_1d)
def forward(self, x):
x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
# Embeddings
class FourierFeatures(nn.Module):
def __init__(self, in_features, out_features, std=1.):
super().__init__()
assert out_features % 2 == 0
self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std)
def forward(self, input):
f = 2 * math.pi * input @ self.weight.T
return torch.cat([f.cos(), f.sin()], dim=-1)
# U-Nets
class UNet(ConditionedModule):
def __init__(self, d_blocks, u_blocks, skip_stages=0):
super().__init__()
self.d_blocks = nn.ModuleList(d_blocks)
self.u_blocks = nn.ModuleList(u_blocks)
self.skip_stages = skip_stages
def forward(self, input, cond):
skips = []
for block in self.d_blocks[self.skip_stages:]:
input = block(input, cond)
skips.append(input)
for i, (block, skip) in enumerate(zip(self.u_blocks, reversed(skips))):
input = block(input, cond, skip if i > 0 else None)
return input

View File

@@ -0,0 +1 @@
from .image_v1 import ImageDenoiserModelV1

View File

@@ -0,0 +1,156 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
from .. import layers, utils
def orthogonal_(module):
nn.init.orthogonal_(module.weight)
return module
class ResConvBlock(layers.ConditionedResidualBlock):
def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.):
skip = None if c_in == c_out else orthogonal_(nn.Conv2d(c_in, c_out, 1, bias=False))
super().__init__(
layers.AdaGN(feats_in, c_in, max(1, c_in // group_size)),
nn.GELU(),
nn.Conv2d(c_in, c_mid, 3, padding=1),
nn.Dropout2d(dropout_rate, inplace=True),
layers.AdaGN(feats_in, c_mid, max(1, c_mid // group_size)),
nn.GELU(),
nn.Conv2d(c_mid, c_out, 3, padding=1),
nn.Dropout2d(dropout_rate, inplace=True),
skip=skip)
class DBlock(layers.ConditionedSequential):
def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., downsample=False, self_attn=False, cross_attn=False, c_enc=0):
modules = [nn.Identity()]
for i in range(n_layers):
my_c_in = c_in if i == 0 else c_mid
my_c_out = c_mid if i < n_layers - 1 else c_out
modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate))
if self_attn:
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
if cross_attn:
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate))
super().__init__(*modules)
self.set_downsample(downsample)
def set_downsample(self, downsample):
self[0] = layers.Downsample2d() if downsample else nn.Identity()
return self
class UBlock(layers.ConditionedSequential):
def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., upsample=False, self_attn=False, cross_attn=False, c_enc=0):
modules = []
for i in range(n_layers):
my_c_in = c_in if i == 0 else c_mid
my_c_out = c_mid if i < n_layers - 1 else c_out
modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate))
if self_attn:
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
if cross_attn:
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate))
modules.append(nn.Identity())
super().__init__(*modules)
self.set_upsample(upsample)
def forward(self, input, cond, skip=None):
if skip is not None:
input = torch.cat([input, skip], dim=1)
return super().forward(input, cond)
def set_upsample(self, upsample):
self[-1] = layers.Upsample2d() if upsample else nn.Identity()
return self
class MappingNet(nn.Sequential):
def __init__(self, feats_in, feats_out, n_layers=2):
layers = []
for i in range(n_layers):
layers.append(orthogonal_(nn.Linear(feats_in if i == 0 else feats_out, feats_out)))
layers.append(nn.GELU())
super().__init__(*layers)
class ImageDenoiserModelV1(nn.Module):
def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, cross_attn_depths=None, mapping_cond_dim=0, unet_cond_dim=0, cross_cond_dim=0, dropout_rate=0., patch_size=1, skip_stages=0, has_variance=False):
super().__init__()
self.c_in = c_in
self.channels = channels
self.unet_cond_dim = unet_cond_dim
self.patch_size = patch_size
self.has_variance = has_variance
self.timestep_embed = layers.FourierFeatures(1, feats_in)
if mapping_cond_dim > 0:
self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False)
self.mapping = MappingNet(feats_in, feats_in)
self.proj_in = nn.Conv2d((c_in + unet_cond_dim) * self.patch_size ** 2, channels[max(0, skip_stages - 1)], 1)
self.proj_out = nn.Conv2d(channels[max(0, skip_stages - 1)], c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1)
nn.init.zeros_(self.proj_out.weight)
nn.init.zeros_(self.proj_out.bias)
if cross_cond_dim == 0:
cross_attn_depths = [False] * len(self_attn_depths)
d_blocks, u_blocks = [], []
for i in range(len(depths)):
my_c_in = channels[max(0, i - 1)]
d_blocks.append(DBlock(depths[i], feats_in, my_c_in, channels[i], channels[i], downsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate))
for i in range(len(depths)):
my_c_in = channels[i] * 2 if i < len(depths) - 1 else channels[i]
my_c_out = channels[max(0, i - 1)]
u_blocks.append(UBlock(depths[i], feats_in, my_c_in, channels[i], my_c_out, upsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate))
self.u_net = layers.UNet(d_blocks, reversed(u_blocks), skip_stages=skip_stages)
def forward(self, input, sigma, mapping_cond=None, unet_cond=None, cross_cond=None, cross_cond_padding=None, return_variance=False):
c_noise = sigma.log() / 4
timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2))
mapping_cond_embed = torch.zeros_like(timestep_embed) if mapping_cond is None else self.mapping_cond(mapping_cond)
mapping_out = self.mapping(timestep_embed + mapping_cond_embed)
cond = {'cond': mapping_out}
if unet_cond is not None:
input = torch.cat([input, unet_cond], dim=1)
if cross_cond is not None:
cond['cross'] = cross_cond
cond['cross_padding'] = cross_cond_padding
if self.patch_size > 1:
input = F.pixel_unshuffle(input, self.patch_size)
input = self.proj_in(input)
input = self.u_net(input, cond)
input = self.proj_out(input)
if self.has_variance:
input, logvar = input[:, :-1], input[:, -1].flatten(1).mean(1)
if self.patch_size > 1:
input = F.pixel_shuffle(input, self.patch_size)
if self.has_variance and return_variance:
return input, logvar
return input
def set_skip_stages(self, skip_stages):
self.proj_in = nn.Conv2d(self.proj_in.in_channels, self.channels[max(0, skip_stages - 1)], 1)
self.proj_out = nn.Conv2d(self.channels[max(0, skip_stages - 1)], self.proj_out.out_channels, 1)
nn.init.zeros_(self.proj_out.weight)
nn.init.zeros_(self.proj_out.bias)
self.u_net.skip_stages = skip_stages
for i, block in enumerate(self.u_net.d_blocks):
block.set_downsample(i > skip_stages)
for i, block in enumerate(reversed(self.u_net.u_blocks)):
block.set_upsample(i > skip_stages)
return self
def set_patch_size(self, patch_size):
self.patch_size = patch_size
self.proj_in = nn.Conv2d((self.c_in + self.unet_cond_dim) * self.patch_size ** 2, self.channels[max(0, self.u_net.skip_stages - 1)], 1)
self.proj_out = nn.Conv2d(self.channels[max(0, self.u_net.skip_stages - 1)], self.c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1)
nn.init.zeros_(self.proj_out.weight)
nn.init.zeros_(self.proj_out.bias)

View File

@@ -0,0 +1,607 @@
import math
from scipy import integrate
import torch
from torch import nn
from torchdiffeq import odeint
import torchsde
from tqdm.auto import trange, tqdm
from . import utils
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
"""Constructs the noise schedule of Karras et al. (2022)."""
ramp = torch.linspace(0, 1, n, device=device)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return append_zero(sigmas).to(device)
def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
"""Constructs an exponential noise schedule."""
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp()
return append_zero(sigmas)
def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
"""Constructs an polynomial in log sigma noise schedule."""
ramp = torch.linspace(1, 0, n, device=device) ** rho
sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))
return append_zero(sigmas)
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
"""Constructs a continuous VP noise schedule."""
t = torch.linspace(1, eps_s, n, device=device)
sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
return append_zero(sigmas)
def to_d(x, sigma, denoised):
"""Converts a denoiser output to a Karras ODE derivative."""
return (x - denoised) / utils.append_dims(sigma, x.ndim)
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
"""Calculates the noise level (sigma_down) to step down to and the amount
of noise to add (sigma_up) when doing an ancestral sampling step."""
if not eta:
return sigma_to, 0.
sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
return sigma_down, sigma_up
def default_noise_sampler(x):
return lambda sigma, sigma_next: torch.randn_like(x)
class BatchedBrownianTree:
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
def __init__(self, x, t0, t1, seed=None, **kwargs):
t0, t1, self.sign = self.sort(t0, t1)
w0 = kwargs.get('w0', torch.zeros_like(x))
if seed is None:
seed = torch.randint(0, 2 ** 63 - 1, []).item()
self.batched = True
try:
assert len(seed) == x.shape[0]
w0 = w0[0]
except TypeError:
seed = [seed]
self.batched = False
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
@staticmethod
def sort(a, b):
return (a, b, 1) if a < b else (b, a, -1)
def __call__(self, t0, t1):
t0, t1, sign = self.sort(t0, t1)
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
return w if self.batched else w[0]
class BrownianTreeNoiseSampler:
"""A noise sampler backed by a torchsde.BrownianTree.
Args:
x (Tensor): The tensor whose shape, device and dtype to use to generate
random samples.
sigma_min (float): The low end of the valid interval.
sigma_max (float): The high end of the valid interval.
seed (int or List[int]): The random seed. If a list of seeds is
supplied instead of a single integer, then the noise sampler will
use one BrownianTree per batch item, each with its own seed.
transform (callable): A function that maps sigma to the sampler's
internal timestep.
"""
def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
self.transform = transform
t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
self.tree = BatchedBrownianTree(x, t0, t1, seed)
def __call__(self, sigma, sigma_next):
t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = x + d * dt
return x
@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], denoised)
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
dt = sigmas[i + 1] - sigma_hat
if sigmas[i + 1] == 0:
# Euler method
x = x + d * dt
else:
# Heun's method
x_2 = x + d * dt
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
d_prime = (d + d_2) / 2
x = x + d_prime * dt
return x
@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
if sigmas[i + 1] == 0:
# Euler method
dt = sigmas[i + 1] - sigma_hat
x = x + d * dt
else:
# DPM-Solver-2
sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
dt_1 = sigma_mid - sigma_hat
dt_2 = sigmas[i + 1] - sigma_hat
x_2 = x + d * dt_1
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
return x
@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], denoised)
if sigma_down == 0:
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
else:
# DPM-Solver-2
sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
dt_1 = sigma_mid - sigmas[i]
dt_2 = sigma_down - sigmas[i]
x_2 = x + d * dt_1
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
def linear_multistep_coeff(order, t, i, j):
if order - 1 > i:
raise ValueError(f'Order {order} too high for step {i}')
def fn(tau):
prod = 1.
for k in range(order):
if j == k:
continue
prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
return prod
return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigmas_cpu = sigmas.detach().cpu().numpy()
ds = []
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
d = to_d(x, sigmas[i], denoised)
ds.append(d)
if len(ds) > order:
ds.pop(0)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
cur_order = min(i + 1, order)
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
return x
@torch.no_grad()
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
v = torch.randint_like(x, 2) * 2 - 1
fevals = 0
def ode_fn(sigma, x):
nonlocal fevals
with torch.enable_grad():
x = x[0].detach().requires_grad_()
denoised = model(x, sigma * s_in, **extra_args)
d = to_d(x, sigma, denoised)
fevals += 1
grad = torch.autograd.grad((d * v).sum(), x)[0]
d_ll = (v * grad).flatten(1).sum(1)
return d.detach(), d_ll
x_min = x, x.new_zeros([x.shape[0]])
t = x.new_tensor([sigma_min, sigma_max])
sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
latent, delta_ll = sol[0][-1], sol[1][-1]
ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
return ll_prior + delta_ll, {'fevals': fevals}
class PIDStepSizeController:
"""A PID controller for ODE adaptive step size control."""
def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
self.h = h
self.b1 = (pcoeff + icoeff + dcoeff) / order
self.b2 = -(pcoeff + 2 * dcoeff) / order
self.b3 = dcoeff / order
self.accept_safety = accept_safety
self.eps = eps
self.errs = []
def limiter(self, x):
return 1 + math.atan(x - 1)
def propose_step(self, error):
inv_error = 1 / (float(error) + self.eps)
if not self.errs:
self.errs = [inv_error, inv_error, inv_error]
self.errs[0] = inv_error
factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3
factor = self.limiter(factor)
accept = factor >= self.accept_safety
if accept:
self.errs[2] = self.errs[1]
self.errs[1] = self.errs[0]
self.h *= factor
return accept
class DPMSolver(nn.Module):
"""DPM-Solver. See https://arxiv.org/abs/2206.00927."""
def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
super().__init__()
self.model = model
self.extra_args = {} if extra_args is None else extra_args
self.eps_callback = eps_callback
self.info_callback = info_callback
def t(self, sigma):
return -sigma.log()
def sigma(self, t):
return t.neg().exp()
def eps(self, eps_cache, key, x, t, *args, **kwargs):
if key in eps_cache:
return eps_cache[key], eps_cache
sigma = self.sigma(t) * x.new_ones([x.shape[0]])
eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t)
if self.eps_callback is not None:
self.eps_callback()
return eps, {key: eps, **eps_cache}
def dpm_solver_1_step(self, x, t, t_next, eps_cache=None):
eps_cache = {} if eps_cache is None else eps_cache
h = t_next - t
eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
x_1 = x - self.sigma(t_next) * h.expm1() * eps
return x_1, eps_cache
def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None):
eps_cache = {} if eps_cache is None else eps_cache
h = t_next - t
eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
s1 = t + r1 * h
u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps)
return x_2, eps_cache
def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
eps_cache = {} if eps_cache is None else eps_cache
h = t_next - t
eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
s1 = t + r1 * h
s2 = t + r2 * h
u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps)
eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2)
x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps)
return x_3, eps_cache
def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
if not t_end > t_start and eta:
raise ValueError('eta must be 0 for reverse sampling')
m = math.floor(nfe / 3) + 1
ts = torch.linspace(t_start, t_end, m + 1, device=x.device)
if nfe % 3 == 0:
orders = [3] * (m - 2) + [2, 1]
else:
orders = [3] * (m - 1) + [nfe % 3]
for i in range(len(orders)):
eps_cache = {}
t, t_next = ts[i], ts[i + 1]
if eta:
sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
t_next_ = torch.minimum(t_end, self.t(sd))
su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
else:
t_next_, su = t_next, 0.
eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
denoised = x - self.sigma(t) * eps
if self.info_callback is not None:
self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})
if orders[i] == 1:
x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache)
elif orders[i] == 2:
x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache)
else:
x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)
x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))
return x
def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
if order not in {2, 3}:
raise ValueError('order should be 2 or 3')
forward = t_end > t_start
if not forward and eta:
raise ValueError('eta must be 0 for reverse sampling')
h_init = abs(h_init) * (1 if forward else -1)
atol = torch.tensor(atol)
rtol = torch.tensor(rtol)
s = t_start
x_prev = x
accept = True
pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}
while s < t_end - 1e-5 if forward else s > t_end + 1e-5:
eps_cache = {}
t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h)
if eta:
sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
t_ = torch.minimum(t_end, self.t(sd))
su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
else:
t_, su = t, 0.
eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
denoised = x - self.sigma(s) * eps
if order == 2:
x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
else:
x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
accept = pid.propose_step(error)
if accept:
x_prev = x_low
x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
s = t
info['n_accept'] += 1
else:
info['n_reject'] += 1
info['nfe'] += order
info['steps'] += 1
if self.info_callback is not None:
self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})
return x, info
@torch.no_grad()
def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
"""DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927."""
if sigma_min <= 0 or sigma_max <= 0:
raise ValueError('sigma_min and sigma_max must not be 0')
with tqdm(total=n, disable=disable) as pbar:
dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
if callback is not None:
dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler)
@torch.no_grad()
def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
"""DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927."""
if sigma_min <= 0 or sigma_max <= 0:
raise ValueError('sigma_min and sigma_max must not be 0')
with tqdm(disable=disable) as pbar:
dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
if callback is not None:
dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
if return_info:
return x, info
return x
@torch.no_grad()
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigma_down == 0:
# Euler method
d = to_d(x, sigmas[i], denoised)
dt = sigma_down - sigmas[i]
x = x + d * dt
else:
# DPM-Solver++(2S)
t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
r = 1 / 2
h = t_next - t
s = t + r * h
x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
# Noise addition
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
"""DPM-Solver++ (stochastic)."""
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Euler method
d = to_d(x, sigmas[i], denoised)
dt = sigmas[i + 1] - sigmas[i]
x = x + d * dt
else:
# DPM-Solver++
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
s = t + h * r
fac = 1 / (2 * r)
# Step 1
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
s_ = t_fn(sd)
x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
# Step 2
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
t_next_ = t_fn(sd)
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
return x
@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""DPM-Solver++(2M)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
old_denoised = None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
if old_denoised is None or sigmas[i + 1] == 0:
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
else:
h_last = t - t_fn(sigmas[i - 1])
r = h_last / h
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
old_denoised = denoised
return x

332
comfy/k_diffusion/utils.py Normal file
View File

@@ -0,0 +1,332 @@
from contextlib import contextmanager
import hashlib
import math
from pathlib import Path
import shutil
import urllib
import warnings
from PIL import Image
import torch
from torch import nn, optim
from torch.utils import data
from torchvision.transforms import functional as TF
def from_pil_image(x):
"""Converts from a PIL image to a tensor."""
x = TF.to_tensor(x)
if x.ndim == 2:
x = x[..., None]
return x * 2 - 1
def to_pil_image(x):
"""Converts from a tensor to a PIL image."""
if x.ndim == 4:
assert x.shape[0] == 1
x = x[0]
if x.shape[0] == 1:
x = x[0]
return TF.to_pil_image((x.clamp(-1, 1) + 1) / 2)
def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
"""Apply passed in transforms for HuggingFace Datasets."""
images = [transform(image.convert(mode)) for image in examples[image_key]]
return {image_key: images}
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
expanded = x[(...,) + (None,) * dims_to_append]
# MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
# https://github.com/pytorch/pytorch/issues/84364
return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
def n_params(module):
"""Returns the number of trainable parameters in a module."""
return sum(p.numel() for p in module.parameters())
def download_file(path, url, digest=None):
"""Downloads a file if it does not exist, optionally checking its SHA-256 hash."""
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
if not path.exists():
with urllib.request.urlopen(url) as response, open(path, 'wb') as f:
shutil.copyfileobj(response, f)
if digest is not None:
file_digest = hashlib.sha256(open(path, 'rb').read()).hexdigest()
if digest != file_digest:
raise OSError(f'hash of {path} (url: {url}) failed to validate')
return path
@contextmanager
def train_mode(model, mode=True):
"""A context manager that places a model into training mode and restores
the previous mode on exit."""
modes = [module.training for module in model.modules()]
try:
yield model.train(mode)
finally:
for i, module in enumerate(model.modules()):
module.training = modes[i]
def eval_mode(model):
"""A context manager that places a model into evaluation mode and restores
the previous mode on exit."""
return train_mode(model, False)
@torch.no_grad()
def ema_update(model, averaged_model, decay):
"""Incorporates updated model parameters into an exponential moving averaged
version of a model. It should be called after each optimizer step."""
model_params = dict(model.named_parameters())
averaged_params = dict(averaged_model.named_parameters())
assert model_params.keys() == averaged_params.keys()
for name, param in model_params.items():
averaged_params[name].mul_(decay).add_(param, alpha=1 - decay)
model_buffers = dict(model.named_buffers())
averaged_buffers = dict(averaged_model.named_buffers())
assert model_buffers.keys() == averaged_buffers.keys()
for name, buf in model_buffers.items():
averaged_buffers[name].copy_(buf)
class EMAWarmup:
"""Implements an EMA warmup using an inverse decay schedule.
If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
good values for models you plan to train for a million or more steps (reaches decay
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 1.
min_value (float): The minimum EMA decay rate. Default: 0.
max_value (float): The maximum EMA decay rate. Default: 1.
start_at (int): The epoch to start averaging at. Default: 0.
last_epoch (int): The index of last epoch. Default: 0.
"""
def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0,
last_epoch=0):
self.inv_gamma = inv_gamma
self.power = power
self.min_value = min_value
self.max_value = max_value
self.start_at = start_at
self.last_epoch = last_epoch
def state_dict(self):
"""Returns the state of the class as a :class:`dict`."""
return dict(self.__dict__.items())
def load_state_dict(self, state_dict):
"""Loads the class's state.
Args:
state_dict (dict): scaler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
self.__dict__.update(state_dict)
def get_value(self):
"""Gets the current EMA decay rate."""
epoch = max(0, self.last_epoch - self.start_at)
value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value))
def step(self):
"""Updates the step count."""
self.last_epoch += 1
class InverseLR(optim.lr_scheduler._LRScheduler):
"""Implements an inverse decay learning rate schedule with an optional exponential
warmup. When last_epoch=-1, sets initial lr as lr.
inv_gamma is the number of steps/epochs required for the learning rate to decay to
(1 / 2)**power of its original value.
Args:
optimizer (Optimizer): Wrapped optimizer.
inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
power (float): Exponential factor of learning rate decay. Default: 1.
warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
Default: 0.
min_lr (float): The minimum learning rate. Default: 0.
last_epoch (int): The index of last epoch. Default: -1.
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
"""
def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,
last_epoch=-1, verbose=False):
self.inv_gamma = inv_gamma
self.power = power
if not 0. <= warmup < 1:
raise ValueError('Invalid value for warmup')
self.warmup = warmup
self.min_lr = min_lr
super().__init__(optimizer, last_epoch, verbose)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn("To get the last learning rate computed by the scheduler, "
"please use `get_last_lr()`.")
return self._get_closed_form_lr()
def _get_closed_form_lr(self):
warmup = 1 - self.warmup ** (self.last_epoch + 1)
lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
return [warmup * max(self.min_lr, base_lr * lr_mult)
for base_lr in self.base_lrs]
class ExponentialLR(optim.lr_scheduler._LRScheduler):
"""Implements an exponential learning rate schedule with an optional exponential
warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate
continuously by decay (default 0.5) every num_steps steps.
Args:
optimizer (Optimizer): Wrapped optimizer.
num_steps (float): The number of steps to decay the learning rate by decay in.
decay (float): The factor by which to decay the learning rate every num_steps
steps. Default: 0.5.
warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
Default: 0.
min_lr (float): The minimum learning rate. Default: 0.
last_epoch (int): The index of last epoch. Default: -1.
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
"""
def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,
last_epoch=-1, verbose=False):
self.num_steps = num_steps
self.decay = decay
if not 0. <= warmup < 1:
raise ValueError('Invalid value for warmup')
self.warmup = warmup
self.min_lr = min_lr
super().__init__(optimizer, last_epoch, verbose)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn("To get the last learning rate computed by the scheduler, "
"please use `get_last_lr()`.")
return self._get_closed_form_lr()
def _get_closed_form_lr(self):
warmup = 1 - self.warmup ** (self.last_epoch + 1)
lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch
return [warmup * max(self.min_lr, base_lr * lr_mult)
for base_lr in self.base_lrs]
def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
"""Draws samples from an lognormal distribution."""
return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp()
def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
"""Draws samples from an optionally truncated log-logistic distribution."""
min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64)
max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64)
min_cdf = min_value.log().sub(loc).div(scale).sigmoid()
max_cdf = max_value.log().sub(loc).div(scale).sigmoid()
u = torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf
return u.logit().mul(scale).add(loc).exp().to(dtype)
def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
"""Draws samples from an log-uniform distribution."""
min_value = math.log(min_value)
max_value = math.log(max_value)
return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp()
def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
"""Draws samples from a truncated v-diffusion training timestep distribution."""
min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi
max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi
u = torch.rand(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf
return torch.tan(u * math.pi / 2) * sigma_data
def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32):
"""Draws samples from a split lognormal distribution."""
n = torch.randn(shape, device=device, dtype=dtype).abs()
u = torch.rand(shape, device=device, dtype=dtype)
n_left = n * -scale_1 + loc
n_right = n * scale_2 + loc
ratio = scale_1 / (scale_1 + scale_2)
return torch.where(u < ratio, n_left, n_right).exp()
class FolderOfImages(data.Dataset):
"""Recursively finds all images in a directory. It does not support
classes/targets."""
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}
def __init__(self, root, transform=None):
super().__init__()
self.root = Path(root)
self.transform = nn.Identity() if transform is None else transform
self.paths = sorted(path for path in self.root.rglob('*') if path.suffix.lower() in self.IMG_EXTENSIONS)
def __repr__(self):
return f'FolderOfImages(root="{self.root}", len: {len(self)})'
def __len__(self):
return len(self.paths)
def __getitem__(self, key):
path = self.paths[key]
with open(path, 'rb') as f:
image = Image.open(f).convert('RGB')
image = self.transform(image)
return image,
class CSVLogger:
def __init__(self, filename, columns):
self.filename = Path(filename)
self.columns = columns
if self.filename.exists():
self.file = open(self.filename, 'a')
else:
self.file = open(self.filename, 'w')
self.write(*self.columns)
def write(self, *args):
print(*args, sep=',', file=self.file, flush=True)
@contextmanager
def tf32_mode(cudnn=None, matmul=None):
"""A context manager that sets whether TF32 is allowed on cuDNN or matmul."""
cudnn_old = torch.backends.cudnn.allow_tf32
matmul_old = torch.backends.cuda.matmul.allow_tf32
try:
if cudnn is not None:
torch.backends.cudnn.allow_tf32 = cudnn
if matmul is not None:
torch.backends.cuda.matmul.allow_tf32 = matmul
yield
finally:
if cudnn is not None:
torch.backends.cudnn.allow_tf32 = cudnn_old
if matmul is not None:
torch.backends.cuda.matmul.allow_tf32 = matmul_old