Implement Self-Attention Guidance (#2201)
* First SAG test * need to put extra options on the model instead of patcher * no errors and results seem not-broken * Use @ashen-uncensored formula, which works better!!! * Fix a crash when using weird resolutions. Remove an unnecessary UNet call * Improve comments, optimize memory in blur routine * SAG works with sampler_cfg_function
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from .k_diffusion import sampling as k_diffusion_sampling
|
||||
from .extra_samplers import uni_pc
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import enum
|
||||
from comfy import model_management
|
||||
import math
|
||||
@@ -60,10 +61,10 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
||||
for t in range(rr):
|
||||
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
|
||||
|
||||
conditionning = {}
|
||||
conditioning = {}
|
||||
model_conds = conds["model_conds"]
|
||||
for c in model_conds:
|
||||
conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
|
||||
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
|
||||
|
||||
control = None
|
||||
if 'control' in conds:
|
||||
@@ -82,7 +83,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
||||
|
||||
patches['middle_patch'] = [gligen_patch]
|
||||
|
||||
return (input_x, mult, conditionning, area, control, patches)
|
||||
return (input_x, mult, conditioning, area, control, patches)
|
||||
|
||||
def cond_equal_size(c1, c2):
|
||||
if c1 is c2:
|
||||
@@ -246,15 +247,71 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
||||
return out_cond, out_uncond
|
||||
|
||||
|
||||
if math.isclose(cond_scale, 1.0):
|
||||
# if we're doing SAG, we still need to do uncond guidance, even though the cond and uncond will cancel out.
|
||||
if math.isclose(cond_scale, 1.0) and "sag" not in model_options:
|
||||
uncond = None
|
||||
|
||||
cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
|
||||
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
|
||||
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
|
||||
if "sampler_cfg_function" in model_options:
|
||||
args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
|
||||
return x - model_options["sampler_cfg_function"](args)
|
||||
else:
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
|
||||
cfg_result = x - model_options["sampler_cfg_function"](args)
|
||||
|
||||
if "sag" in model_options:
|
||||
assert uncond is not None, "SAG requires uncond guidance"
|
||||
sag_scale = model_options["sag_scale"]
|
||||
sag_sigma = model_options["sag_sigma"]
|
||||
sag_threshold = model_options.get("sag_threshold", 1.0)
|
||||
|
||||
# these methods are added by the sag patcher
|
||||
uncond_attn = model.get_attn_scores()
|
||||
mid_shape = model.get_mid_block_shape()
|
||||
# create the adversarially blurred image
|
||||
degraded = create_blur_map(uncond_pred, uncond_attn, mid_shape, sag_sigma, sag_threshold)
|
||||
degraded_noised = degraded + x - uncond_pred
|
||||
# call into the UNet
|
||||
(sag, _) = calc_cond_uncond_batch(model, uncond, None, degraded_noised, timestep, model_options)
|
||||
cfg_result += (degraded - sag) * sag_scale
|
||||
return cfg_result
|
||||
|
||||
def create_blur_map(x0, attn, mid_shape, sigma=3.0, threshold=1.0):
|
||||
# reshape and GAP the attention map
|
||||
_, hw1, hw2 = attn.shape
|
||||
b, _, lh, lw = x0.shape
|
||||
attn = attn.reshape(b, -1, hw1, hw2)
|
||||
# Global Average Pool
|
||||
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
|
||||
# Reshape
|
||||
mask = (
|
||||
mask.reshape(b, *mid_shape)
|
||||
.unsqueeze(1)
|
||||
.type(attn.dtype)
|
||||
)
|
||||
# Upsample
|
||||
mask = F.interpolate(mask, (lh, lw))
|
||||
|
||||
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
|
||||
blurred = blurred * mask + x0 * (1 - mask)
|
||||
return blurred
|
||||
|
||||
def gaussian_blur_2d(img, kernel_size, sigma):
|
||||
ksize_half = (kernel_size - 1) * 0.5
|
||||
|
||||
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
|
||||
|
||||
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
|
||||
|
||||
x_kernel = pdf / pdf.sum()
|
||||
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
|
||||
|
||||
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
|
||||
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
|
||||
|
||||
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
|
||||
|
||||
img = F.pad(img, padding, mode="reflect")
|
||||
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
|
||||
return img
|
||||
|
||||
class CFGNoisePredictor(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
|
||||
Reference in New Issue
Block a user