Enable Runtime Selection of Attention Functions (#9639)

* Looking into a @wrap_attn decorator to look for 'optimized_attention_override' entry in transformer_options

* Created logging code for this branch so that it can be used to track down all the code paths where transformer_options would need to be added

* Fix memory usage issue with inspect

* Made WAN attention receive transformer_options, test node added to wan to test out attention override later

* Added **kwargs to all attention functions so transformer_options could potentially be passed through

* Make sure wrap_attn doesn't make itself recurse infinitely, attempt to load SageAttention and FlashAttention if not enabled so that they can be marked as available or not, create registry for available attention

* Turn off attention logging for now, make AttentionOverrideTestNode have a dropdown with available attention (this is a test node only)

* Make flux work with optimized_attention_override

* Add logs to verify optimized_attention_override is passed all the way into attention function

* Make Qwen work with optimized_attention_override

* Made hidream work with optimized_attention_override

* Made wan patches_replace work with optimized_attention_override

* Made SD3 work with optimized_attention_override

* Made HunyuanVideo work with optimized_attention_override

* Made Mochi work with optimized_attention_override

* Made LTX work with optimized_attention_override

* Made StableAudio work with optimized_attention_override

* Made optimized_attention_override work with ACE Step

* Made Hunyuan3D work with optimized_attention_override

* Make CosmosPredict2 work with optimized_attention_override

* Made CosmosVideo work with optimized_attention_override

* Made Omnigen 2 work with optimized_attention_override

* Made StableCascade work with optimized_attention_override

* Made AuraFlow work with optimized_attention_override

* Made Lumina work with optimized_attention_override

* Made Chroma work with optimized_attention_override

* Made SVD work with optimized_attention_override

* Fix WanI2VCrossAttention so that it expects to receive transformer_options

* Fixed Wan2.1 Fun Camera transformer_options passthrough

* Fixed WAN 2.1 VACE transformer_options passthrough

* Add optimized to get_attention_function

* Disable attention logs for now

* Remove attention logging code

* Remove _register_core_attention_functions, as we wouldn't want someone to call that, just in case

* Satisfy ruff

* Remove AttentionOverrideTest node, that's something to cook up for later
This commit is contained in:
Jedrzej Kosinski
2025-09-12 15:07:38 -07:00
committed by GitHub
parent b149e2e1e3
commit d7f40442f9
26 changed files with 316 additions and 179 deletions

View File

@@ -32,12 +32,12 @@ class OptimizedAttention(nn.Module):
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
def forward(self, q, k, v):
def forward(self, q, k, v, transformer_options={}):
q = self.to_q(q)
k = self.to_k(k)
v = self.to_v(v)
out = optimized_attention(q, k, v, self.heads)
out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)
return self.out_proj(out)
@@ -47,13 +47,13 @@ class Attention2D(nn.Module):
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
def forward(self, x, kv, self_attn=False):
def forward(self, x, kv, self_attn=False, transformer_options={}):
orig_shape = x.shape
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
if self_attn:
kv = torch.cat([x, kv], dim=1)
# x = self.attn(x, kv, kv, need_weights=False)[0]
x = self.attn(x, kv, kv)
x = self.attn(x, kv, kv, transformer_options=transformer_options)
x = x.permute(0, 2, 1).view(*orig_shape)
return x
@@ -114,9 +114,9 @@ class AttnBlock(nn.Module):
operations.Linear(c_cond, c, dtype=dtype, device=device)
)
def forward(self, x, kv):
def forward(self, x, kv, transformer_options={}):
kv = self.kv_mapper(kv)
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options)
return x