Add manual cast to controlnet.
This commit is contained in:
@@ -141,24 +141,24 @@ class ControlNet(nn.Module):
|
||||
)
|
||||
]
|
||||
)
|
||||
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations)])
|
||||
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
|
||||
|
||||
self.input_hint_block = TimestepEmbedSequential(
|
||||
operations.conv_nd(dims, hint_channels, 16, 3, padding=1),
|
||||
operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, 16, 16, 3, padding=1),
|
||||
operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2),
|
||||
operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, 32, 32, 3, padding=1),
|
||||
operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2),
|
||||
operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, 96, 96, 3, padding=1),
|
||||
operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2),
|
||||
operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
zero_module(operations.conv_nd(dims, 256, model_channels, 3, padding=1))
|
||||
operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
|
||||
)
|
||||
|
||||
self._feature_size = model_channels
|
||||
@@ -206,7 +206,7 @@ class ControlNet(nn.Module):
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
|
||||
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
@@ -234,7 +234,7 @@ class ControlNet(nn.Module):
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
|
||||
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
@@ -276,11 +276,11 @@ class ControlNet(nn.Module):
|
||||
operations=operations
|
||||
)]
|
||||
self.middle_block = TimestepEmbedSequential(*mid_block)
|
||||
self.middle_block_out = self.make_zero_conv(ch, operations=operations)
|
||||
self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
|
||||
self._feature_size += ch
|
||||
|
||||
def make_zero_conv(self, channels, operations=None):
|
||||
return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0)))
|
||||
def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
|
||||
return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||
|
||||
Reference in New Issue
Block a user