Support base SDXL and SDXL refiner models.

Large refactor of the model detection and loading code.
This commit is contained in:
comfyanonymous
2023-06-22 13:03:50 -04:00
parent 9fccf4aa03
commit f87ec10a97
16 changed files with 754 additions and 289 deletions

View File

@@ -34,8 +34,10 @@ class ControlNet(nn.Module):
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
use_checkpoint=False,
use_fp16=False,
use_bf16=False,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
@@ -51,6 +53,8 @@ class ControlNet(nn.Module):
num_attention_blocks=None,
disable_middle_self_attn=False,
use_linear_in_transformer=False,
adm_in_channels=None,
transformer_depth_middle=None,
):
super().__init__()
if use_spatial_transformer:
@@ -75,6 +79,10 @@ class ControlNet(nn.Module):
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
if isinstance(transformer_depth, int):
transformer_depth = len(channel_mult) * [transformer_depth]
if transformer_depth_middle is None:
transformer_depth_middle = transformer_depth[-1]
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
@@ -97,8 +105,10 @@ class ControlNet(nn.Module):
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32
self.dtype = th.bfloat16 if use_bf16 else self.dtype
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
@@ -111,6 +121,24 @@ class ControlNet(nn.Module):
linear(time_embed_dim, time_embed_dim),
)
if self.num_classes is not None:
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
elif self.num_classes == "sequential":
assert adm_in_channels is not None
self.label_emb = nn.Sequential(
nn.Sequential(
linear(adm_in_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
)
else:
raise ValueError()
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
@@ -179,7 +207,7 @@ class ControlNet(nn.Module):
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
)
@@ -238,7 +266,7 @@ class ControlNet(nn.Module):
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
),
@@ -257,7 +285,7 @@ class ControlNet(nn.Module):
def make_zero_conv(self, channels):
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
def forward(self, x, hint, timesteps, context, **kwargs):
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
@@ -265,6 +293,14 @@ class ControlNet(nn.Module):
outs = []
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
if guided_hint is not None: