Add a way to pass options to the transformers blocks.

comfyanonymous
2023-03-31 13:04:39 -04:00
parent 04b42bad87
commit 61ec3c9d5d
5 changed files with 33 additions and 29 deletions


@@ -504,10 +504,10 @@ class BasicTransformerBlock(nn.Module):
         self.norm3 = nn.LayerNorm(dim)
         self.checkpoint = checkpoint
 
-    def forward(self, x, context=None):
-        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+    def forward(self, x, context=None, transformer_options={}):
+        return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
 
-    def _forward(self, x, context=None):
+    def _forward(self, x, context=None, transformer_options={}):
         x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
         x = self.attn2(self.norm2(x), context=context) + x
         x = self.ff(self.norm3(x)) + x
@@ -557,7 +557,7 @@ class SpatialTransformer(nn.Module):
         self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
         self.use_linear = use_linear
 
-    def forward(self, x, context=None):
+    def forward(self, x, context=None, transformer_options={}):
         # note: if no context is given, cross-attention defaults to self-attention
         if not isinstance(context, list):
             context = [context]
@@ -570,7 +570,7 @@ class SpatialTransformer(nn.Module):
         if self.use_linear:
             x = self.proj_in(x)
         for i, block in enumerate(self.transformer_blocks):
-            x = block(x, context=context[i])
+            x = block(x, context=context[i], transformer_options=transformer_options)
         if self.use_linear:
             x = self.proj_out(x)
         x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
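The hunks above simply thread an extra transformer_options dict from SpatialTransformer.forward down into every BasicTransformerBlock it owns. Below is a minimal, self-contained sketch of the same threading pattern; OptionedBlock and OptionStack are illustrative stand-ins rather than ComfyUI's classes, and the "scale" key is a made-up example option.

import torch
import torch.nn as nn

class OptionedBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    # Mirrors the diff's signature, including the mutable default dict.
    def forward(self, x, context=None, transformer_options={}):
        scale = transformer_options.get("scale", 1.0)  # read an optional setting
        return x + scale * self.proj(x)

class OptionStack(nn.Module):
    def __init__(self, dim, depth):
        super().__init__()
        self.blocks = nn.ModuleList(OptionedBlock(dim) for _ in range(depth))

    def forward(self, x, context=None, transformer_options={}):
        # Hand the same dict to every block, as SpatialTransformer now does.
        for block in self.blocks:
            x = block(x, context=context, transformer_options=transformer_options)
        return x

x = torch.randn(2, 16, 32)
print(OptionStack(32, 3)(x, transformer_options={"scale": 0.5}).shape)  # torch.Size([2, 16, 32])

Note that a {} default is a single shared object in Python, so a caller that wants per-call isolation should pass its own dict, as the UNet-level call sites below do.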


@@ -76,12 +76,12 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     support it as an extra input.
     """
 
-    def forward(self, x, emb, context=None):
+    def forward(self, x, emb, context=None, transformer_options={}):
         for layer in self:
             if isinstance(layer, TimestepBlock):
                 x = layer(x, emb)
             elif isinstance(layer, SpatialTransformer):
-                x = layer(x, context)
+                x = layer(x, context, transformer_options)
             else:
                 x = layer(x)
         return x
@@ -753,7 +753,7 @@ class UNetModel(nn.Module):
         self.middle_block.apply(convert_module_to_f32)
         self.output_blocks.apply(convert_module_to_f32)
 
-    def forward(self, x, timesteps=None, context=None, y=None, control=None, **kwargs):
+    def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
         """
         Apply the model to an input batch.
         :param x: an [N x C x ...] Tensor of inputs.
@@ -762,6 +762,7 @@ class UNetModel(nn.Module):
         :param y: an [N] Tensor of labels, if class-conditional.
         :return: an [N x C x ...] Tensor of outputs.
         """
+        transformer_options["original_shape"] = list(x.shape)
         assert (y is not None) == (
             self.num_classes is not None
         ), "must specify y if and only if the model is class-conditional"
@@ -775,13 +776,13 @@ class UNetModel(nn.Module):
         h = x.type(self.dtype)
         for id, module in enumerate(self.input_blocks):
-            h = module(h, emb, context)
+            h = module(h, emb, context, transformer_options)
             if control is not None and 'input' in control and len(control['input']) > 0:
                 ctrl = control['input'].pop()
                 if ctrl is not None:
                     h += ctrl
             hs.append(h)
 
-        h = self.middle_block(h, emb, context)
+        h = self.middle_block(h, emb, context, transformer_options)
         if control is not None and 'middle' in control and len(control['middle']) > 0:
             h += control['middle'].pop()
@@ -793,7 +794,7 @@ class UNetModel(nn.Module):
                     hsp += ctrl
                 h = th.cat([h, hsp], dim=1)
                 del hsp
-            h = module(h, emb, context)
+            h = module(h, emb, context, transformer_options)
         h = h.type(x.dtype)
         if self.predict_codebook_ids:
             return self.id_predictor(h)
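At the UNet level the new keyword argument is forwarded to every input, middle, and output block, and the model records the latent's shape under the "original_shape" key before the first block runs. A tiny stand-in sketch of that calling convention follows; TinyUNet is hypothetical, only mimics the patched signature, and performs no denoising.

import torch
import torch.nn as nn

class TinyUNet(nn.Module):
    # Hypothetical stand-in: same keyword interface as the patched UNetModel.forward,
    # but it only demonstrates how the options dict is populated.
    def forward(self, x, timesteps=None, context=None, y=None, control=None,
                transformer_options={}, **kwargs):
        # As in the diff: stash the input shape so any transformer block that
        # receives the same dict can read it later.
        transformer_options["original_shape"] = list(x.shape)
        return x

opts = {}                                    # caller-owned options dict
latent = torch.randn(1, 4, 64, 64)           # illustrative latent shape
_ = TinyUNet()(latent, timesteps=torch.tensor([10]), transformer_options=opts)
print(opts["original_shape"])                # [1, 4, 64, 64]

Since the dict is passed by reference, blocks deep in the network can both read options set by the caller and see values the UNet itself added, such as original_shape.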