Add ControlNet support.
This commit is contained in:
@@ -1320,12 +1320,12 @@ class DiffusionWrapper(torch.nn.Module):
|
||||
self.conditioning_key = conditioning_key
|
||||
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
|
||||
|
||||
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
|
||||
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None):
|
||||
if self.conditioning_key is None:
|
||||
out = self.diffusion_model(x, t)
|
||||
out = self.diffusion_model(x, t, control=control)
|
||||
elif self.conditioning_key == 'concat':
|
||||
xc = torch.cat([x] + c_concat, dim=1)
|
||||
out = self.diffusion_model(xc, t)
|
||||
out = self.diffusion_model(xc, t, control=control)
|
||||
elif self.conditioning_key == 'crossattn':
|
||||
if not self.sequential_cross_attn:
|
||||
cc = torch.cat(c_crossattn, 1)
|
||||
@@ -1335,25 +1335,25 @@ class DiffusionWrapper(torch.nn.Module):
|
||||
# TorchScript changes the names of the arguments:
|
||||
# if cc is passed by keyword (context=cc), the scripted model will produce
|
||||
# an error: RuntimeError: forward() is missing value for argument 'argument_3'.
|
||||
out = self.scripted_diffusion_model(x, t, cc)
|
||||
out = self.scripted_diffusion_model(x, t, cc, control=control)
|
||||
else:
|
||||
out = self.diffusion_model(x, t, context=cc)
|
||||
out = self.diffusion_model(x, t, context=cc, control=control)
|
||||
elif self.conditioning_key == 'hybrid':
|
||||
xc = torch.cat([x] + c_concat, dim=1)
|
||||
cc = torch.cat(c_crossattn, 1)
|
||||
out = self.diffusion_model(xc, t, context=cc)
|
||||
out = self.diffusion_model(xc, t, context=cc, control=control)
|
||||
elif self.conditioning_key == 'hybrid-adm':
|
||||
assert c_adm is not None
|
||||
xc = torch.cat([x] + c_concat, dim=1)
|
||||
cc = torch.cat(c_crossattn, 1)
|
||||
out = self.diffusion_model(xc, t, context=cc, y=c_adm)
|
||||
out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control)
|
||||
elif self.conditioning_key == 'crossattn-adm':
|
||||
assert c_adm is not None
|
||||
cc = torch.cat(c_crossattn, 1)
|
||||
out = self.diffusion_model(x, t, context=cc, y=c_adm)
|
||||
out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control)
|
||||
elif self.conditioning_key == 'adm':
|
||||
cc = c_crossattn[0]
|
||||
out = self.diffusion_model(x, t, y=cc)
|
||||
out = self.diffusion_model(x, t, y=cc, control=control)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -753,7 +753,7 @@ class UNetModel(nn.Module):
|
||||
self.middle_block.apply(convert_module_to_f32)
|
||||
self.output_blocks.apply(convert_module_to_f32)
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
|
||||
def forward(self, x, timesteps=None, context=None, y=None, control=None, **kwargs):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
@@ -778,8 +778,14 @@ class UNetModel(nn.Module):
|
||||
h = module(h, emb, context)
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, context)
|
||||
if control is not None:
|
||||
h += control.pop()
|
||||
|
||||
for module in self.output_blocks:
|
||||
h = th.cat([h, hs.pop()], dim=1)
|
||||
hsp = hs.pop()
|
||||
if control is not None:
|
||||
hsp += control.pop()
|
||||
h = th.cat([h, hsp], dim=1)
|
||||
h = module(h, emb, context)
|
||||
h = h.type(x.dtype)
|
||||
if self.predict_codebook_ids:
|
||||
|
||||
Reference in New Issue
Block a user