Implement hunyuan image refiner model. (#9817)

Author: comfyanonymous
Committed by: GitHub
Date: 2025-09-11 21:43:20 -07:00
Parent: 18de0b2830
Commit: 33bd9ed9cb
9 changed files with 367 additions and 12 deletions

@@ -145,7 +145,7 @@ class Downsample(nn.Module):
 class ResnetBlock(nn.Module):
     def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
-                 dropout=0.0, temb_channels=512, conv_op=ops.Conv2d):
+                 dropout=0.0, temb_channels=512, conv_op=ops.Conv2d, norm_op=Normalize):
         super().__init__()
         self.in_channels = in_channels
         out_channels = in_channels if out_channels is None else out_channels
@@ -153,7 +153,7 @@ class ResnetBlock(nn.Module):
         self.use_conv_shortcut = conv_shortcut
         self.swish = torch.nn.SiLU(inplace=True)
-        self.norm1 = Normalize(in_channels)
+        self.norm1 = norm_op(in_channels)
         self.conv1 = conv_op(in_channels,
                              out_channels,
                              kernel_size=3,
@@ -162,7 +162,7 @@ class ResnetBlock(nn.Module):
         if temb_channels > 0:
             self.temb_proj = ops.Linear(temb_channels,
                                         out_channels)
-        self.norm2 = Normalize(out_channels)
+        self.norm2 = norm_op(out_channels)
         self.dropout = torch.nn.Dropout(dropout, inplace=True)
         self.conv2 = conv_op(out_channels,
                              out_channels,
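
The two ResnetBlock hunks above thread a new norm_op argument through the constructor in place of the hardcoded Normalize calls. norm_op is a factory: any callable that takes a channel count and returns a normalization module, matching the call shape of Normalize. A minimal sketch of that contract follows; the GroupNorm parameters mirror what Normalize uses elsewhere in this file, while the InstanceNorm2d alternative is purely illustrative and not part of this commit.

import torch.nn as nn

# Default factory: same call shape as the file's Normalize helper,
# which builds a GroupNorm (32 groups, eps=1e-6, affine).
def Normalize(in_channels, num_groups=32):
    return nn.GroupNorm(num_groups=num_groups, num_channels=in_channels,
                        eps=1e-6, affine=True)

# Any callable with the signature (in_channels) -> nn.Module can be passed
# as norm_op. InstanceNorm2d here is a hypothetical stand-in for whatever
# normalization the Hunyuan refiner VAE actually supplies.
def instance_norm_op(in_channels):
    return nn.InstanceNorm2d(in_channels, affine=True)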
@@ -305,11 +305,11 @@ def vae_attention():
     return normal_attention

 class AttnBlock(nn.Module):
-    def __init__(self, in_channels, conv_op=ops.Conv2d):
+    def __init__(self, in_channels, conv_op=ops.Conv2d, norm_op=Normalize):
         super().__init__()
         self.in_channels = in_channels
-        self.norm = Normalize(in_channels)
+        self.norm = norm_op(in_channels)
         self.q = conv_op(in_channels,
                          in_channels,
                          kernel_size=1,
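
With both constructors parameterized, a caller can inject a custom normalization when it builds the VAE blocks, which is presumably how the Hunyuan image refiner variant diverges from the default GroupNorm. A hedged usage sketch, assuming these classes live in comfy.ldm.modules.diffusionmodules.model (as the hunk context suggests) and reusing the hypothetical instance_norm_op factory from above:

import torch
import torch.nn as nn
# Assumed module path, inferred from the hunk headers in this commit.
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock

def instance_norm_op(in_channels):  # hypothetical factory, not from the commit
    return nn.InstanceNorm2d(in_channels, affine=True)

# Omitting norm_op keeps the previous behavior, since it defaults to Normalize.
res = ResnetBlock(in_channels=128, out_channels=128, temb_channels=0,
                  conv_op=nn.Conv2d, norm_op=instance_norm_op)
attn = AttnBlock(128, conv_op=nn.Conv2d, norm_op=instance_norm_op)

x = torch.randn(1, 128, 32, 32)
h = res(x, None)   # temb is None because temb_channels=0
h = attn(h)
print(h.shape)     # torch.Size([1, 128, 32, 32])

Because norm_op defaults to Normalize in both constructors, every existing call site is unaffected, which is why the rest of the diff needs no changes at those call sites.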