Cleanup HunyuanDit controlnets.

Use the: ControlNetApply SD3 and HunyuanDiT node.
This commit is contained in:
comfyanonymous
2024-08-09 02:35:19 -04:00
parent 06eb9fb426
commit a475ec2300
4 changed files with 60 additions and 194 deletions

View File

@@ -16,28 +16,11 @@ from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from .poolers import AttentionPool
import comfy.latent_formats
from .models import HunYuanDiTBlock
from .models import HunYuanDiTBlock, calc_rope
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module
def calc_rope(x, patch_size, head_size):
th = (x.shape[2] + (patch_size // 2)) // patch_size
tw = (x.shape[3] + (patch_size // 2)) // patch_size
base_size = 512 // 8 // patch_size
start, stop = get_fill_resize_and_crop((th, tw), base_size)
sub_args = [start, stop, (th, tw)]
# head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
rope = get_2d_rotary_pos_embed(head_size, *sub_args)
return rope
class HunYuanControlNet(nn.Module):
"""
HunYuanDiT: Diffusion model with a Transformer backbone.
@@ -213,35 +196,32 @@ class HunYuanControlNet(nn.Module):
)
# Input zero linear for the first block
self.before_proj = zero_module(
nn.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
)
self.before_proj = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
# Output zero linear for the every block
self.after_proj_list = nn.ModuleList(
[
zero_module(
nn.Linear(
operations.Linear(
self.hidden_size, self.hidden_size, dtype=dtype, device=device
)
)
for _ in range(len(self.blocks))
]
)
def forward(
self,
x: torch.Tensor,
t: torch.Tensor = None,
condition=None,
encoder_hidden_states: Optional[torch.Tensor] = None,
x,
hint,
timesteps,
context,#encoder_hidden_states=None,
text_embedding_mask=None,
encoder_hidden_states_t5=None,
text_embedding_mask_t5=None,
image_meta_size=None,
style=None,
control_weight=1.0,
transformer_options=None,
return_dict=False,
**kwarg,
):
"""
@@ -270,10 +250,11 @@ class HunYuanControlNet(nn.Module):
return_dict: bool
Whether to return a dictionary.
"""
condition = hint
if condition.shape[0] == 1:
condition = torch.repeat_interleave(condition, x.shape[0], dim=0)
text_states = encoder_hidden_states # 2,77,1024
text_states = context # 2,77,1024
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
text_states_mask = text_embedding_mask.bool() # 2,77
text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
@@ -304,7 +285,7 @@ class HunYuanControlNet(nn.Module):
) # (cos_cis_img, sin_cis_img)
# ========================= Build time and image embedding =========================
t = self.t_embedder(t, dtype=self.dtype)
t = self.t_embedder(timesteps, dtype=self.dtype)
x = self.x_embedder(x)
# ========================= Concatenate all extra vectors =========================
@@ -337,12 +318,4 @@ class HunYuanControlNet(nn.Module):
x = block(x, c, text_states, freqs_cis_img)
controls.append(self.after_proj_list[layer](x)) # zero linear for output
control_weights = [1.0 * (control_weight ** float(19 - i)) for i in range(19)]
assert len(control_weights) == len(
controls
), "control_weights and controls should have the same length"
controls = [
control * weight for control, weight in zip(controls, control_weights)
]
return {"output": controls}