Basic Hunyuan Video model support.

This commit is contained in:
comfyanonymous
2024-12-16 19:35:40 -05:00
parent 19ee5d9d8b
commit bda1482a27
18 changed files with 413646 additions and 76 deletions

View File

@@ -114,7 +114,7 @@ class Modulation(nn.Module):
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None, operations=None):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -141,6 +141,7 @@ class DoubleStreamBlock(nn.Module):
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
img_mod1, img_mod2 = self.img_mod(vec)
@@ -160,13 +161,22 @@ class DoubleStreamBlock(nn.Module):
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask)
if self.flipped_img_txt:
# run actual attention
attn = attention(torch.cat((img_q, txt_q), dim=2),
torch.cat((img_k, txt_k), dim=2),
torch.cat((img_v, txt_v), dim=2),
pe=pe, mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
# run actual attention
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)

View File

@@ -0,0 +1,330 @@
#Based on Flux code because of weird hunyuan video code license.
import torch
import comfy.ldm.flux.layers
import comfy.ldm.modules.diffusionmodules.mmdit
from comfy.ldm.modules.attention import optimized_attention
from dataclasses import dataclass
from einops import repeat
from torch import Tensor, nn
from comfy.ldm.flux.layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding
)
import comfy.ldm.common_dit
@dataclass
class HunyuanVideoParams:
in_channels: int
out_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list
theta: int
patch_size: list
qkv_bias: bool
guidance_embed: bool
class SelfAttentionRef(nn.Module):
def __init__(self, dim: int, qkv_bias: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
class TokenRefinerBlock(nn.Module):
def __init__(
self,
hidden_size,
heads,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.heads = heads
mlp_hidden_dim = hidden_size * 4
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device),
)
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
self.self_attn = SelfAttentionRef(hidden_size, True, dtype=dtype, device=device, operations=operations)
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
self.mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, x, c, mask):
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
qkv = self.self_attn.qkv(norm_x)
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
return x
class IndividualTokenRefiner(nn.Module):
def __init__(
self,
hidden_size,
heads,
num_blocks,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.blocks = nn.ModuleList(
[
TokenRefinerBlock(
hidden_size=hidden_size,
heads=heads,
dtype=dtype,
device=device,
operations=operations
)
for _ in range(num_blocks)
]
)
def forward(self, x, c, mask):
m = None
if mask is not None:
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
m = m + m.transpose(2, 3)
for block in self.blocks:
x = block(x, c, m)
return x
class TokenRefiner(nn.Module):
def __init__(
self,
text_dim,
hidden_size,
heads,
num_blocks,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.input_embedder = operations.Linear(text_dim, hidden_size, bias=True, dtype=dtype, device=device)
self.t_embedder = MLPEmbedder(256, hidden_size, dtype=dtype, device=device, operations=operations)
self.c_embedder = MLPEmbedder(text_dim, hidden_size, dtype=dtype, device=device, operations=operations)
self.individual_token_refiner = IndividualTokenRefiner(hidden_size, heads, num_blocks, dtype=dtype, device=device, operations=operations)
def forward(
self,
x,
timesteps,
mask,
):
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
# m = mask.float().unsqueeze(-1)
# c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
c = x.sum(dim=1) / x.shape[1]
c = t + self.c_embedder(c.to(x.dtype))
x = self.input_embedder(x)
x = self.individual_token_refiner(x, c, mask)
return x
class HunyuanVideo(nn.Module):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
params = HunyuanVideoParams(**kwargs)
self.params = params
self.patch_size = params.patch_size
self.in_channels = params.in_channels
self.out_channels = params.out_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
)
self.txt_in = TokenRefiner(params.context_in_dim, self.hidden_size, self.num_heads, 2, dtype=dtype, device=device, operations=operations)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
flipped_img_txt=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)
if final_layer:
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
txt_mask: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control=None,
transformer_options={},
) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
initial_shape = list(img.shape)
# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
if txt_mask is not None and not torch.is_floating_point(txt_mask):
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
txt = self.txt_in(txt, timesteps, txt_mask)
ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids)
img_len = img.shape[1]
if txt_mask is not None:
attn_mask_len = img_len + txt.shape[1]
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
attn_mask[:, 0, img_len:] = txt_mask
else:
attn_mask = None
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
return out
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img += add
img = torch.cat((img, txt), 1)
for i, block in enumerate(self.single_blocks):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
return out
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, : img_len] += add
img = img[:, : img_len]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
shape = initial_shape[-3:]
for i in range(len(shape)):
shape[i] = shape[i] // self.patch_size[i]
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
img = img.reshape(initial_shape)
return img
def forward(self, x, timestep, context, y, guidance, attention_mask=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
return out

View File

@@ -162,12 +162,19 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
},
**kwargs,
)
self.quant_conv = comfy.ops.disable_weight_init.Conv2d(
if ddconfig.get("conv3d", False):
conv_op = comfy.ops.disable_weight_init.Conv3d
else:
conv_op = comfy.ops.disable_weight_init.Conv2d
self.quant_conv = conv_op(
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
(1 + ddconfig["double_z"]) * embed_dim,
1,
)
self.post_quant_conv = comfy.ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
def get_autoencoder_params(self) -> list:

View File

@@ -43,51 +43,100 @@ def Normalize(in_channels, num_groups=32):
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class VideoConv3d(nn.Module):
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
super().__init__()
self.padding_mode = padding_mode
if padding != 0:
padding = (padding, padding, padding, padding, kernel_size - 1, 0)
else:
kwargs["padding"] = padding
self.padding = padding
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
if self.padding != 0:
x = torch.nn.functional.pad(x, self.padding, mode=self.padding_mode)
return self.conv(x)
def interpolate_up(x, scale_factor):
try:
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
except: #operation not implemented for bf16
orig_shape = list(x.shape)
out_shape = orig_shape[:2]
for i in range(len(orig_shape) - 2):
out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
split = 8
l = out.shape[1] // split
for i in range(0, out.shape[1], l):
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
return out
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):
super().__init__()
self.with_conv = with_conv
self.scale_factor = scale_factor
if self.with_conv:
self.conv = ops.Conv2d(in_channels,
self.conv = conv_op(in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
try:
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
except: #operation not implemented for bf16
b, c, h, w = x.shape
out = torch.empty((b, c, h*2, w*2), dtype=x.dtype, layout=x.layout, device=x.device)
split = 8
l = out.shape[1] // split
for i in range(0, out.shape[1], l):
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype)
del x
x = out
scale_factor = self.scale_factor
if not isinstance(scale_factor, tuple):
scale_factor = (scale_factor,) * (x.ndim - 2)
if x.ndim == 5 and scale_factor[0] > 1.0:
t = x.shape[2]
if t > 1:
a, b = x.split((1, t - 1), dim=2)
del x
b = interpolate_up(b, scale_factor)
else:
a = x
a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
if t > 1:
x = torch.cat((a, b), dim=2)
else:
x = a
else:
x = interpolate_up(x, self.scale_factor)
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
def __init__(self, in_channels, with_conv, stride=2, conv_op=ops.Conv2d):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = ops.Conv2d(in_channels,
self.conv = conv_op(in_channels,
in_channels,
kernel_size=3,
stride=2,
stride=stride,
padding=0)
def forward(self, x):
if self.with_conv:
pad = (0,1,0,1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
if x.ndim == 4:
pad = (0, 1, 0, 1)
mode = "constant"
x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
elif x.ndim == 5:
pad = (1, 1, 1, 1, 2, 0)
mode = "replicate"
x = torch.nn.functional.pad(x, pad, mode=mode)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
@@ -96,7 +145,7 @@ class Downsample(nn.Module):
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
dropout, temb_channels=512):
dropout, temb_channels=512, conv_op=ops.Conv2d):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
@@ -105,7 +154,7 @@ class ResnetBlock(nn.Module):
self.swish = torch.nn.SiLU(inplace=True)
self.norm1 = Normalize(in_channels)
self.conv1 = ops.Conv2d(in_channels,
self.conv1 = conv_op(in_channels,
out_channels,
kernel_size=3,
stride=1,
@@ -115,20 +164,20 @@ class ResnetBlock(nn.Module):
out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout, inplace=True)
self.conv2 = ops.Conv2d(out_channels,
self.conv2 = conv_op(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = ops.Conv2d(in_channels,
self.conv_shortcut = conv_op(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = ops.Conv2d(in_channels,
self.nin_shortcut = conv_op(in_channels,
out_channels,
kernel_size=1,
stride=1,
@@ -194,21 +243,25 @@ def slice_attention(q, k, v):
def normal_attention(q, k, v):
# compute attention
b,c,h,w = q.shape
orig_shape = q.shape
b = orig_shape[0]
c = orig_shape[1]
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
v = v.reshape(b,c,h*w)
q = q.reshape(b, c, -1)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, -1) # b,c,hw
v = v.reshape(b, c, -1)
r1 = slice_attention(q, k, v)
h_ = r1.reshape(b,c,h,w)
h_ = r1.reshape(orig_shape)
del r1
return h_
def xformers_attention(q, k, v):
# compute attention
B, C, H, W = q.shape
orig_shape = q.shape
B = orig_shape[0]
C = orig_shape[1]
q, k, v = map(
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
(q, k, v),
@@ -216,14 +269,16 @@ def xformers_attention(q, k, v):
try:
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
out = out.transpose(1, 2).reshape(B, C, H, W)
out = out.transpose(1, 2).reshape(orig_shape)
except NotImplementedError:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
return out
def pytorch_attention(q, k, v):
# compute attention
B, C, H, W = q.shape
orig_shape = q.shape
B = orig_shape[0]
C = orig_shape[1]
q, k, v = map(
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
@@ -231,35 +286,35 @@ def pytorch_attention(q, k, v):
try:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(B, C, H, W)
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
return out
class AttnBlock(nn.Module):
def __init__(self, in_channels):
def __init__(self, in_channels, conv_op=ops.Conv2d):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = ops.Conv2d(in_channels,
self.q = conv_op(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = ops.Conv2d(in_channels,
self.k = conv_op(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = ops.Conv2d(in_channels,
self.v = conv_op(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = ops.Conv2d(in_channels,
self.proj_out = conv_op(in_channels,
in_channels,
kernel_size=1,
stride=1,
@@ -289,8 +344,8 @@ class AttnBlock(nn.Module):
return x+h_
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
return AttnBlock(in_channels)
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None, conv_op=ops.Conv2d):
return AttnBlock(in_channels, conv_op=conv_op)
class Model(nn.Module):
@@ -449,6 +504,7 @@ class Encoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
conv3d=False, time_compress=None,
**ignore_kwargs):
super().__init__()
if use_linear_attn: attn_type = "linear"
@@ -459,8 +515,15 @@ class Encoder(nn.Module):
self.resolution = resolution
self.in_channels = in_channels
if conv3d:
conv_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
mid_attn_conv_op = ops.Conv2d
# downsampling
self.conv_in = ops.Conv2d(in_channels,
self.conv_in = conv_op(in_channels,
self.ch,
kernel_size=3,
stride=1,
@@ -479,15 +542,20 @@ class Encoder(nn.Module):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
dropout=dropout,
conv_op=conv_op))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
attn.append(make_attn(block_in, attn_type=attn_type, conv_op=conv_op))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
down.downsample = Downsample(block_in, resamp_with_conv)
stride = 2
if time_compress is not None:
if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
stride = (1, 2, 2)
down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
curr_res = curr_res // 2
self.down.append(down)
@@ -496,16 +564,18 @@ class Encoder(nn.Module):
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
dropout=dropout,
conv_op=conv_op)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type, conv_op=mid_attn_conv_op)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
dropout=dropout,
conv_op=conv_op)
# end
self.norm_out = Normalize(block_in)
self.conv_out = ops.Conv2d(block_in,
self.conv_out = conv_op(block_in,
2*z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
@@ -543,6 +613,8 @@ class Decoder(nn.Module):
conv_out_op=ops.Conv2d,
resnet_op=ResnetBlock,
attn_op=AttnBlock,
conv3d=False,
time_compress=None,
**ignorekwargs):
super().__init__()
self.ch = ch
@@ -554,6 +626,14 @@ class Decoder(nn.Module):
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
if conv3d:
conv_op = VideoConv3d
conv_out_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
mid_attn_conv_op = ops.Conv2d
# compute block_in and curr_res at lowest res
block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1)
@@ -562,7 +642,7 @@ class Decoder(nn.Module):
self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = ops.Conv2d(z_channels,
self.conv_in = conv_op(z_channels,
block_in,
kernel_size=3,
stride=1,
@@ -573,12 +653,14 @@ class Decoder(nn.Module):
self.mid.block_1 = resnet_op(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = attn_op(block_in)
dropout=dropout,
conv_op=conv_op)
self.mid.attn_1 = attn_op(block_in, conv_op=mid_attn_conv_op)
self.mid.block_2 = resnet_op(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
dropout=dropout,
conv_op=conv_op)
# upsampling
self.up = nn.ModuleList()
@@ -590,15 +672,21 @@ class Decoder(nn.Module):
block.append(resnet_op(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
dropout=dropout,
conv_op=conv_op))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(attn_op(block_in))
attn.append(attn_op(block_in, conv_op=conv_op))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
scale_factor = 2.0
if time_compress is not None:
if i_level > math.log2(time_compress):
scale_factor = (1.0, 2.0, 2.0)
up.upsample = Upsample(block_in, resamp_with_conv, conv_op=conv_op, scale_factor=scale_factor)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order

View File

@@ -194,6 +194,7 @@ def make_time_attn(
attn_kwargs=None,
alpha: float = 0,
merge_strategy: str = "learned",
conv_op=ops.Conv2d,
):
return partialclass(
AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy