Support for Control Loras.
Control loras are controlnets where some of the weights are stored in "lora" format: an up and a down low rank matrix that, when multiplied together and added to the unet weight, give the controlnet weight. This allows a much smaller memory footprint, depending on the rank of the matrices. These controlnets are used just like regular ones.
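
For illustration, a minimal sketch of the weight reconstruction described above (hypothetical shapes and variable names, not the actual loading code):

    import torch

    # hypothetical sizes: a rank-r decomposition of an [out, in] weight
    out_features, in_features, r = 1280, 1280, 128

    w_unet = torch.randn(out_features, in_features)  # base unet weight
    up = torch.randn(out_features, r)                # low rank "up" matrix stored in the file
    down = torch.randn(r, in_features)               # low rank "down" matrix stored in the file

    # controlnet weight = unet weight + up @ down
    w_controlnet = w_unet + up @ down

    # storage cost: out*r + r*in values instead of out*in,
    # which is much smaller when r << min(out, in)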
@@ -6,8 +6,6 @@ import torch as th
 import torch.nn as nn
 
 from ..ldm.modules.diffusionmodules.util import (
-    conv_nd,
-    linear,
     zero_module,
     timestep_embedding,
 )
@@ -15,7 +13,7 @@ from ..ldm.modules.diffusionmodules.util import (
 from ..ldm.modules.attention import SpatialTransformer
 from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
 from ..ldm.util import exists
-
+import comfy.ops
 
 class ControlledUnetModel(UNetModel):
     #implemented in the ldm unet
@@ -55,6 +53,8 @@ class ControlNet(nn.Module):
         use_linear_in_transformer=False,
         adm_in_channels=None,
         transformer_depth_middle=None,
+        device=None,
+        operations=comfy.ops,
     ):
         super().__init__()
         assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
@@ -117,9 +117,9 @@ class ControlNet(nn.Module):
 
         time_embed_dim = model_channels * 4
         self.time_embed = nn.Sequential(
-            linear(model_channels, time_embed_dim),
+            operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
             nn.SiLU(),
-            linear(time_embed_dim, time_embed_dim),
+            operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
         )
 
         if self.num_classes is not None:
@@ -132,9 +132,9 @@ class ControlNet(nn.Module):
                 assert adm_in_channels is not None
                 self.label_emb = nn.Sequential(
                     nn.Sequential(
-                        linear(adm_in_channels, time_embed_dim),
+                        operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
                         nn.SiLU(),
-                        linear(time_embed_dim, time_embed_dim),
+                        operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
                     )
                 )
             else:
@@ -143,28 +143,28 @@ class ControlNet(nn.Module):
         self.input_blocks = nn.ModuleList(
             [
                 TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                    operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
                 )
             ]
         )
-        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
+        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations)])
 
         self.input_hint_block = TimestepEmbedSequential(
-            conv_nd(dims, hint_channels, 16, 3, padding=1),
+            operations.conv_nd(dims, hint_channels, 16, 3, padding=1),
             nn.SiLU(),
-            conv_nd(dims, 16, 16, 3, padding=1),
+            operations.conv_nd(dims, 16, 16, 3, padding=1),
             nn.SiLU(),
-            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+            operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2),
             nn.SiLU(),
-            conv_nd(dims, 32, 32, 3, padding=1),
+            operations.conv_nd(dims, 32, 32, 3, padding=1),
             nn.SiLU(),
-            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
+            operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2),
             nn.SiLU(),
-            conv_nd(dims, 96, 96, 3, padding=1),
+            operations.conv_nd(dims, 96, 96, 3, padding=1),
             nn.SiLU(),
-            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
+            operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2),
             nn.SiLU(),
-            zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
+            zero_module(operations.conv_nd(dims, 256, model_channels, 3, padding=1))
        )
 
         self._feature_size = model_channels
@@ -182,6 +182,7 @@ class ControlNet(nn.Module):
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
+                        operations=operations
                     )
                 ]
                 ch = mult * model_channels
@@ -204,11 +205,11 @@ class ControlNet(nn.Module):
                         SpatialTransformer(
                             ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint
+                            use_checkpoint=use_checkpoint, operations=operations
                         )
                     )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))
-                self.zero_convs.append(self.make_zero_conv(ch))
+                self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
                 self._feature_size += ch
                 input_block_chans.append(ch)
             if level != len(channel_mult) - 1:
@@ -224,16 +225,17 @@ class ControlNet(nn.Module):
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             down=True,
+                            operations=operations
                         )
                         if resblock_updown
                         else Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch
+                            ch, conv_resample, dims=dims, out_channels=out_ch, operations=operations
                         )
                     )
                 )
                 ch = out_ch
                 input_block_chans.append(ch)
-                self.zero_convs.append(self.make_zero_conv(ch))
+                self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
                 ds *= 2
                 self._feature_size += ch
 
@@ -253,11 +255,12 @@ class ControlNet(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
+                operations=operations
             ),
             SpatialTransformer( # always uses a self-attn
                 ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                 disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                use_checkpoint=use_checkpoint
+                use_checkpoint=use_checkpoint, operations=operations
             ),
             ResBlock(
                 ch,
@@ -266,13 +269,14 @@ class ControlNet(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
+                operations=operations
            ),
         )
-        self.middle_block_out = self.make_zero_conv(ch)
+        self.middle_block_out = self.make_zero_conv(ch, operations=operations)
         self._feature_size += ch
 
-    def make_zero_conv(self, channels):
-        return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
+    def make_zero_conv(self, channels, operations=None):
+        return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0)))
 
     def forward(self, x, hint, timesteps, context, y=None, **kwargs):
         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
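
For context on the refactor above: `operations` acts as a module-like factory for layers (defaulting to `comfy.ops`), and the diff threads it through `make_zero_conv`, `ResBlock`, `SpatialTransformer`, and the hint block so that every weight-bearing layer in the controlnet is created through this single seam. A rough sketch of the shape such a factory takes, assuming nothing beyond what the diff uses (this is not the actual comfy.ops source):

    import torch.nn as nn

    # A minimal stand-in with the same surface as the `operations`
    # object used above: a Linear class and a conv_nd factory.
    class Linear(nn.Linear):
        pass  # a real implementation might skip weight init or patch weights

    def conv_nd(dims, *args, **kwargs):
        # pick a 1d/2d/3d convolution based on `dims`
        if dims == 1:
            return nn.Conv1d(*args, **kwargs)
        elif dims == 2:
            return nn.Conv2d(*args, **kwargs)
        elif dims == 3:
            return nn.Conv3d(*args, **kwargs)
        raise ValueError(f"unsupported dimensions: {dims}")

Because the default keeps plain torch layers, regular controlnets behave exactly as before, while a control lora loader can pass an `operations` object whose layers reconstruct their weights from the unet weight plus the low rank delta.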