Enable Runtime Selection of Attention Functions (#9639)
* Looking into a @wrap_attn decorator to look for 'optimized_attention_override' entry in transformer_options * Created logging code for this branch so that it can be used to track down all the code paths where transformer_options would need to be added * Fix memory usage issue with inspect * Made WAN attention receive transformer_options, test node added to wan to test out attention override later * Added **kwargs to all attention functions so transformer_options could potentially be passed through * Make sure wrap_attn doesn't make itself recurse infinitely, attempt to load SageAttention and FlashAttention if not enabled so that they can be marked as available or not, create registry for available attention * Turn off attention logging for now, make AttentionOverrideTestNode have a dropdown with available attention (this is a test node only) * Make flux work with optimized_attention_override * Add logs to verify optimized_attention_override is passed all the way into attention function * Make Qwen work with optimized_attention_override * Made hidream work with optimized_attention_override * Made wan patches_replace work with optimized_attention_override * Made SD3 work with optimized_attention_override * Made HunyuanVideo work with optimized_attention_override * Made Mochi work with optimized_attention_override * Made LTX work with optimized_attention_override * Made StableAudio work with optimized_attention_override * Made optimized_attention_override work with ACE Step * Made Hunyuan3D work with optimized_attention_override * Make CosmosPredict2 work with optimized_attention_override * Made CosmosVideo work with optimized_attention_override * Made Omnigen 2 work with optimized_attention_override * Made StableCascade work with optimized_attention_override * Made AuraFlow work with optimized_attention_override * Made Lumina work with optimized_attention_override * Made Chroma work with optimized_attention_override * Made SVD work with optimized_attention_override * Fix WanI2VCrossAttention so that it expects to receive transformer_options * Fixed Wan2.1 Fun Camera transformer_options passthrough * Fixed WAN 2.1 VACE transformer_options passthrough * Add optimized to get_attention_function * Disable attention logs for now * Remove attention logging code * Remove _register_core_attention_functions, as we wouldn't want someone to call that, just in case * Satisfy ruff * Remove AttentionOverrideTest node, that's something to cook up for later
This commit is contained in:
@@ -109,6 +109,7 @@ class AsymmetricAttention(nn.Module):
|
||||
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
|
||||
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
|
||||
crop_y,
|
||||
transformer_options={},
|
||||
**rope_rotation,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
rope_cos = rope_rotation.get("rope_cos")
|
||||
@@ -143,7 +144,7 @@ class AsymmetricAttention(nn.Module):
|
||||
|
||||
xy = optimized_attention(q,
|
||||
k,
|
||||
v, self.num_heads, skip_reshape=True)
|
||||
v, self.num_heads, skip_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
|
||||
x = self.proj_x(x)
|
||||
@@ -224,6 +225,7 @@ class AsymmetricJointBlock(nn.Module):
|
||||
x: torch.Tensor,
|
||||
c: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
transformer_options={},
|
||||
**attn_kwargs,
|
||||
):
|
||||
"""Forward pass of a block.
|
||||
@@ -256,6 +258,7 @@ class AsymmetricJointBlock(nn.Module):
|
||||
y,
|
||||
scale_x=scale_msa_x,
|
||||
scale_y=scale_msa_y,
|
||||
transformer_options=transformer_options,
|
||||
**attn_kwargs,
|
||||
)
|
||||
|
||||
@@ -524,10 +527,11 @@ class AsymmDiTJoint(nn.Module):
|
||||
args["txt"],
|
||||
rope_cos=args["rope_cos"],
|
||||
rope_sin=args["rope_sin"],
|
||||
crop_y=args["num_tokens"]
|
||||
crop_y=args["num_tokens"],
|
||||
transformer_options=args["transformer_options"]
|
||||
)
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
y_feat = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
@@ -538,6 +542,7 @@ class AsymmDiTJoint(nn.Module):
|
||||
rope_cos=rope_cos,
|
||||
rope_sin=rope_sin,
|
||||
crop_y=num_tokens,
|
||||
transformer_options=transformer_options,
|
||||
) # (B, M, D), (B, L, D)
|
||||
del y_feat # Final layers don't use dense text features.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user