Cleanup HunyuanDiT controlnets.

Use the ControlNetApply SD3 and HunyuanDiT node.
@@ -16,28 +16,11 @@ from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
-from .poolers import AttentionPool
-
-import comfy.latent_formats
-from .models import HunYuanDiTBlock
+from .models import HunYuanDiTBlock, calc_rope
 
-from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
 
 
-def zero_module(module):
-    for p in module.parameters():
-        nn.init.zeros_(p)
-    return module
 
 
-def calc_rope(x, patch_size, head_size):
-    th = (x.shape[2] + (patch_size // 2)) // patch_size
-    tw = (x.shape[3] + (patch_size // 2)) // patch_size
-    base_size = 512 // 8 // patch_size
-    start, stop = get_fill_resize_and_crop((th, tw), base_size)
-    sub_args = [start, stop, (th, tw)]
-    # head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
-    rope = get_2d_rotary_pos_embed(head_size, *sub_args)
-    return rope
 
 
 class HunYuanControlNet(nn.Module):
     """
     HunYuanDiT: Diffusion model with a Transformer backbone.
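Note: the local zero_module and calc_rope helpers are deleted because calc_rope already lives in .models, hence the widened import above. As a rough usage sketch of what the shared helper computes, with assumed DiT-g/2 sizes (hidden_size 1408, 16 heads); the shapes and values here are illustrative assumptions, not part of this diff:

import torch
# from comfy.ldm.hydit.models import calc_rope  # assumed import path

latent = torch.zeros(2, 4, 128, 128)  # (B, C, H, W), e.g. a 1024x1024 image / 8
patch_size = 2
head_size = 1408 // 16  # hidden_size // num_heads for DiT-g/2 (assumed config)

# token grid that calc_rope derives from the latent, per the deleted body above
th = (latent.shape[2] + patch_size // 2) // patch_size  # 64
tw = (latent.shape[3] + patch_size // 2) // patch_size  # 64

# rope = calc_rope(latent, patch_size, head_size)
# -> (cos_cis_img, sin_cis_img) rotary tables covering th * tw tokens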
@@ -213,35 +196,32 @@ class HunYuanControlNet(nn.Module):
         )
 
         # Input zero linear for the first block
-        self.before_proj = zero_module(
-            nn.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
-        )
+        self.before_proj = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
+
         # Output zero linear for the every block
         self.after_proj_list = nn.ModuleList(
             [
-                zero_module(
-                    nn.Linear(
-                        self.hidden_size, self.hidden_size, dtype=dtype, device=device
-                    )
-                )
+                operations.Linear(
+                    self.hidden_size, self.hidden_size, dtype=dtype, device=device
+                )
                 for _ in range(len(self.blocks))
             ]
         )
 
     def forward(
         self,
-        x: torch.Tensor,
-        t: torch.Tensor = None,
-        condition=None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
+        x,
+        hint,
+        timesteps,
+        context,#encoder_hidden_states=None,
         text_embedding_mask=None,
         encoder_hidden_states_t5=None,
         text_embedding_mask_t5=None,
         image_meta_size=None,
         style=None,
         control_weight=1.0,
         transformer_options=None,
         return_dict=False,
         **kwarg,
     ):
         """
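Note: the zero_module(nn.Linear(...)) wrappers become plain operations.Linear layers. Zero initialization only matters when training a ControlNet from scratch; at inference the weights are loaded from the checkpoint anyway, and ComfyUI's operations classes skip the init cost and cast weights on the fly. A minimal sketch of that idea (a simplified stand-in, not ComfyUI's actual implementation):

import torch
import torch.nn as nn

class CastLinear(nn.Linear):
    # Simplified stand-in for an operations.Linear-style layer: skip weight
    # init (the checkpoint load overwrites it) and cast the weights to the
    # input's dtype/device at call time.
    def reset_parameters(self):
        return None  # leave weights uninitialized; checkpoint fills them later

    def forward(self, x):
        weight = self.weight.to(dtype=x.dtype, device=x.device)
        bias = None if self.bias is None else self.bias.to(dtype=x.dtype, device=x.device)
        return torch.nn.functional.linear(x, weight, bias)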
@@ -270,10 +250,11 @@ class HunYuanControlNet(nn.Module):
         return_dict: bool
             Whether to return a dictionary.
         """
+        condition = hint
         if condition.shape[0] == 1:
             condition = torch.repeat_interleave(condition, x.shape[0], dim=0)
 
-        text_states = encoder_hidden_states  # 2,77,1024
+        text_states = context  # 2,77,1024
         text_states_t5 = encoder_hidden_states_t5  # 2,256,2048
         text_states_mask = text_embedding_mask.bool()  # 2,77
         text_states_t5_mask = text_embedding_mask_t5.bool()  # 2,256
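Note: condition = hint adapts the renamed argument to the existing body, and the repeat_interleave below it broadcasts a single control hint across the latent batch (e.g. one hint shared by positive and negative conditioning). A quick sketch with assumed shapes:

import torch

x = torch.zeros(2, 4, 128, 128)          # latent batch of 2 (assumed shapes)
condition = torch.zeros(1, 4, 128, 128)  # a single encoded control hint
if condition.shape[0] == 1:
    condition = torch.repeat_interleave(condition, x.shape[0], dim=0)
assert condition.shape == (2, 4, 128, 128)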
@@ -304,7 +285,7 @@ class HunYuanControlNet(nn.Module):
         ) # (cos_cis_img, sin_cis_img)
 
         # ========================= Build time and image embedding =========================
-        t = self.t_embedder(t, dtype=self.dtype)
+        t = self.t_embedder(timesteps, dtype=self.dtype)
         x = self.x_embedder(x)
 
         # ========================= Concatenate all extra vectors =========================
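Note: only the argument name changes here (t -> timesteps, matching the new signature); t_embedder still builds a sinusoidal embedding and projects it through an MLP. For reference, a sketch of the standard ADM/DiT-style sinusoidal embedding that helpers like the imported timestep_embedding compute (a generic reimplementation, not the exact ComfyUI code):

import math
import torch

def sinusoidal_timestep_embedding(timesteps, dim, max_period=10000):
    # Map scalar timesteps to a (B, dim) vector of cos/sin features at
    # geometrically spaced frequencies, as in ADM/DiT.
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = sinusoidal_timestep_embedding(torch.tensor([0, 999]), 256)  # (2, 256)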
@@ -337,12 +318,4 @@ class HunYuanControlNet(nn.Module):
             x = block(x, c, text_states, freqs_cis_img)
             controls.append(self.after_proj_list[layer](x))  # zero linear for output
 
-        control_weights = [1.0 * (control_weight ** float(19 - i)) for i in range(19)]
-        assert len(control_weights) == len(
-            controls
-        ), "control_weights and controls should have the same length"
-        controls = [
-            control * weight for control, weight in zip(controls, control_weights)
-        ]
-
         return {"output": controls}
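Note: the deleted block baked a fixed exponential strength schedule over the 19 ControlNet blocks into the model's forward pass; strength is now left to the generic apply path (the ControlNetApply SD3 and HunyuanDiT node from the commit message). What the removed schedule did, with an example value:

import torch

control_weight = 0.9  # example strength; at 1.0 every weight is exactly 1.0
control_weights = [control_weight ** float(19 - i) for i in range(19)]
# shallowest block: 0.9 ** 19 ~= 0.135; deepest block: 0.9 ** 1 = 0.9,
# i.e. later blocks contributed a much stronger control signal.

controls = [torch.ones(1) for _ in range(19)]  # stand-in block outputs
controls = [c * w for c, w in zip(controls, control_weights)]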