feat: SUPIR model support (CORE-17) (#13250)
This commit is contained in:
@@ -34,6 +34,16 @@ class TimestepBlock(nn.Module):
|
||||
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
|
||||
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
|
||||
for layer in ts:
|
||||
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
|
||||
found_patched = False
|
||||
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
|
||||
if isinstance(layer, class_type):
|
||||
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
|
||||
found_patched = True
|
||||
break
|
||||
if found_patched:
|
||||
continue
|
||||
|
||||
if isinstance(layer, VideoResBlock):
|
||||
x = layer(x, emb, num_video_frames, image_only_indicator)
|
||||
elif isinstance(layer, TimestepBlock):
|
||||
@@ -49,15 +59,6 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
|
||||
elif isinstance(layer, Upsample):
|
||||
x = layer(x, output_shape=output_shape)
|
||||
else:
|
||||
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
|
||||
found_patched = False
|
||||
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
|
||||
if isinstance(layer, class_type):
|
||||
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
|
||||
found_patched = True
|
||||
break
|
||||
if found_patched:
|
||||
continue
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
@@ -894,6 +895,12 @@ class UNetModel(nn.Module):
|
||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||
h = apply_control(h, control, 'middle')
|
||||
|
||||
if "middle_block_after_patch" in transformer_patches:
|
||||
patch = transformer_patches["middle_block_after_patch"]
|
||||
for p in patch:
|
||||
out = p({"h": h, "x": x, "emb": emb, "context": context, "y": y,
|
||||
"timesteps": timesteps, "transformer_options": transformer_options})
|
||||
h = out["h"]
|
||||
|
||||
for id, module in enumerate(self.output_blocks):
|
||||
transformer_options["block"] = ("output", id)
|
||||
@@ -905,8 +912,9 @@ class UNetModel(nn.Module):
|
||||
for p in patch:
|
||||
h, hsp = p(h, hsp, transformer_options)
|
||||
|
||||
h = th.cat([h, hsp], dim=1)
|
||||
del hsp
|
||||
if hsp is not None:
|
||||
h = th.cat([h, hsp], dim=1)
|
||||
del hsp
|
||||
if len(hs) > 0:
|
||||
output_shape = hs[-1].shape
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user