Add manual cast to controlnet.

This commit is contained in:
comfyanonymous
2023-12-12 03:32:23 -05:00
parent 3152023fbc
commit 32b7e7e769
2 changed files with 47 additions and 37 deletions

View File

@@ -141,24 +141,24 @@ class ControlNet(nn.Module):
)
]
)
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations)])
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
self.input_hint_block = TimestepEmbedSequential(
operations.conv_nd(dims, hint_channels, 16, 3, padding=1),
operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 16, 16, 3, padding=1),
operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2),
operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 32, 32, 3, padding=1),
operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2),
operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 96, 96, 3, padding=1),
operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2),
operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(),
zero_module(operations.conv_nd(dims, 256, model_channels, 3, padding=1))
operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
)
self._feature_size = model_channels
@@ -206,7 +206,7 @@ class ControlNet(nn.Module):
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
@@ -234,7 +234,7 @@ class ControlNet(nn.Module):
)
ch = out_ch
input_block_chans.append(ch)
self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
ds *= 2
self._feature_size += ch
@@ -276,11 +276,11 @@ class ControlNet(nn.Module):
operations=operations
)]
self.middle_block = TimestepEmbedSequential(*mid_block)
self.middle_block_out = self.make_zero_conv(ch, operations=operations)
self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
self._feature_size += ch
def make_zero_conv(self, channels, operations=None):
return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0)))
def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)