Refactor comfy.ops

comfy.ops -> comfy.ops.disable_weight_init

This should make it clearer what the ops actually do.

Some unused code has also been removed.
comfyanonymous
2023-12-11 23:27:13 -05:00
parent b0aab1e4ea
commit 77755ab8db
10 changed files with 94 additions and 170 deletions
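The diffs below all follow the same pattern: call sites stop referencing `comfy.ops.Conv2d` / `comfy.ops.Linear` directly and instead go through a module-level alias `ops = comfy.ops.disable_weight_init`. As a rough sketch of why the new name is more descriptive, assuming these ops simply skip PyTorch's default weight initialization (only the names visible in the diff come from the commit; the rest is illustrative, not the actual comfy/ops.py):

```python
import torch.nn as nn

# Illustrative sketch: a container class whose layer subclasses skip
# PyTorch's default weight initialization, which is wasted work when the
# weights are immediately replaced by values loaded from a checkpoint.
class disable_weight_init:
    class Linear(nn.Linear):
        def reset_parameters(self):
            return None  # leave the torch.empty() weights untouched

    class Conv2d(nn.Conv2d):
        def reset_parameters(self):
            return None

# Old call sites:  self.conv = comfy.ops.Conv2d(ch, ch, kernel_size=3, padding=1)
# New call sites:
ops = disable_weight_init   # in the diffs: ops = comfy.ops.disable_weight_init
conv = ops.Conv2d(4, 8, kernel_size=3, stride=1, padding=1)
lin = ops.Linear(16, 32)
```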

View File

@@ -8,6 +8,7 @@ from typing import Optional, Any
from comfy import model_management
import comfy.ops
+ ops = comfy.ops.disable_weight_init
if model_management.xformers_enabled_vae():
import xformers
@@ -48,7 +49,7 @@ class Upsample(nn.Module):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
- self.conv = comfy.ops.Conv2d(in_channels,
+ self.conv = ops.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=1,
@@ -78,7 +79,7 @@ class Downsample(nn.Module):
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
- self.conv = comfy.ops.Conv2d(in_channels,
+ self.conv = ops.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=2,
@@ -105,30 +106,30 @@ class ResnetBlock(nn.Module):
self.swish = torch.nn.SiLU(inplace=True)
self.norm1 = Normalize(in_channels)
- self.conv1 = comfy.ops.Conv2d(in_channels,
+ self.conv1 = ops.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if temb_channels > 0:
- self.temb_proj = comfy.ops.Linear(temb_channels,
+ self.temb_proj = ops.Linear(temb_channels,
out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout, inplace=True)
- self.conv2 = comfy.ops.Conv2d(out_channels,
+ self.conv2 = ops.Conv2d(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
- self.conv_shortcut = comfy.ops.Conv2d(in_channels,
+ self.conv_shortcut = ops.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
- self.nin_shortcut = comfy.ops.Conv2d(in_channels,
+ self.nin_shortcut = ops.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
@@ -245,22 +246,22 @@ class AttnBlock(nn.Module):
self.in_channels = in_channels
self.norm = Normalize(in_channels)
- self.q = comfy.ops.Conv2d(in_channels,
+ self.q = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
- self.k = comfy.ops.Conv2d(in_channels,
+ self.k = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
- self.v = comfy.ops.Conv2d(in_channels,
+ self.v = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
- self.proj_out = comfy.ops.Conv2d(in_channels,
+ self.proj_out = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
@@ -312,14 +313,14 @@ class Model(nn.Module):
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList([
- comfy.ops.Linear(self.ch,
+ ops.Linear(self.ch,
self.temb_ch),
- comfy.ops.Linear(self.temb_ch,
+ ops.Linear(self.temb_ch,
self.temb_ch),
])
# downsampling
- self.conv_in = comfy.ops.Conv2d(in_channels,
+ self.conv_in = ops.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
@@ -388,7 +389,7 @@ class Model(nn.Module):
# end
self.norm_out = Normalize(block_in)
- self.conv_out = comfy.ops.Conv2d(block_in,
+ self.conv_out = ops.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,
@@ -461,7 +462,7 @@ class Encoder(nn.Module):
self.in_channels = in_channels
# downsampling
- self.conv_in = comfy.ops.Conv2d(in_channels,
+ self.conv_in = ops.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
@@ -506,7 +507,7 @@ class Encoder(nn.Module):
# end
self.norm_out = Normalize(block_in)
- self.conv_out = comfy.ops.Conv2d(block_in,
+ self.conv_out = ops.Conv2d(block_in,
2*z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
@@ -541,7 +542,7 @@ class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
- conv_out_op=comfy.ops.Conv2d,
+ conv_out_op=ops.Conv2d,
resnet_op=ResnetBlock,
attn_op=AttnBlock,
**ignorekwargs):
@@ -565,7 +566,7 @@ class Decoder(nn.Module):
self.z_shape, np.prod(self.z_shape)))
# z to block_in
- self.conv_in = comfy.ops.Conv2d(z_channels,
+ self.conv_in = ops.Conv2d(z_channels,
block_in,
kernel_size=3,
stride=1,

View File

@@ -12,13 +12,13 @@ from .util import (
checkpoint,
avg_pool_nd,
zero_module,
normalization,
timestep_embedding,
AlphaBlender,
)
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from comfy.ldm.util import exists
import comfy.ops
+ ops = comfy.ops.disable_weight_init
class TimestepBlock(nn.Module):
"""
@@ -70,7 +70,7 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
"""
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
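Worth noting in this file: the new `ops` namespace is threaded through as a default argument (`operations=ops`) rather than referenced globally inside each block, so callers can inject a different set of layer classes. A minimal sketch of that injection pattern, with hypothetical names (only `operations`/`ops` come from the diff):

```python
import torch.nn as nn

class default_ops:
    # Plain PyTorch layers with their usual weight initialization.
    Linear = nn.Linear
    Conv2d = nn.Conv2d

class Block(nn.Module):
    # Mirrors the signature style in the diff: the ops namespace arrives
    # as a keyword argument with a default, e.g. operations=ops above.
    def __init__(self, dim, operations=default_ops):
        super().__init__()
        self.proj = operations.Linear(dim, dim)

block = Block(64)  # uses nn.Linear
# A caller could instead pass any namespace exposing the same layer classes,
# e.g. comfy.ops.disable_weight_init in ComfyUI itself.
```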
@@ -106,7 +106,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
"""
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -159,7 +159,7 @@ class ResBlock(TimestepBlock):
skip_t_emb=False,
dtype=None,
device=None,
- operations=comfy.ops
+ operations=ops
):
super().__init__()
self.channels = channels
@@ -284,7 +284,7 @@ class VideoResBlock(ResBlock):
down: bool = False,
dtype=None,
device=None,
- operations=comfy.ops
+ operations=ops
):
super().__init__(
channels,
@@ -434,7 +434,7 @@ class UNetModel(nn.Module):
disable_temporal_crossattention=False,
max_ddpm_temb_period=10000,
device=None,
- operations=comfy.ops,
+ operations=ops,
):
super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
@@ -581,7 +581,7 @@ class UNetModel(nn.Module):
up=False,
dtype=None,
device=None,
- operations=comfy.ops
+ operations=ops
):
if self.use_temporal_resblocks:
return VideoResBlock(

View File

@@ -16,7 +16,6 @@ import numpy as np
from einops import repeat, rearrange
from comfy.ldm.util import instantiate_from_config
- import comfy.ops
class AlphaBlender(nn.Module):
strategies = ["learned", "fixed", "learned_with_images"]
@@ -273,46 +272,6 @@ def mean_flat(tensor):
return tensor.mean(dim=list(range(1, len(tensor.shape))))
- def normalization(channels, dtype=None):
- """
- Make a standard normalization layer.
- :param channels: number of input channels.
- :return: an nn.Module for normalization.
- """
- return GroupNorm32(32, channels, dtype=dtype)
- # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
- class SiLU(nn.Module):
- def forward(self, x):
- return x * torch.sigmoid(x)
- class GroupNorm32(nn.GroupNorm):
- def forward(self, x):
- return super().forward(x.float()).type(x.dtype)
- def conv_nd(dims, *args, **kwargs):
- """
- Create a 1D, 2D, or 3D convolution module.
- """
- if dims == 1:
- return nn.Conv1d(*args, **kwargs)
- elif dims == 2:
- return comfy.ops.Conv2d(*args, **kwargs)
- elif dims == 3:
- return nn.Conv3d(*args, **kwargs)
- raise ValueError(f"unsupported dimensions: {dims}")
- def linear(*args, **kwargs):
- """
- Create a linear module.
- """
- return comfy.ops.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.