Enable Runtime Selection of Attention Functions (#9639)
* Looking into a @wrap_attn decorator to look for 'optimized_attention_override' entry in transformer_options * Created logging code for this branch so that it can be used to track down all the code paths where transformer_options would need to be added * Fix memory usage issue with inspect * Made WAN attention receive transformer_options, test node added to wan to test out attention override later * Added **kwargs to all attention functions so transformer_options could potentially be passed through * Make sure wrap_attn doesn't make itself recurse infinitely, attempt to load SageAttention and FlashAttention if not enabled so that they can be marked as available or not, create registry for available attention * Turn off attention logging for now, make AttentionOverrideTestNode have a dropdown with available attention (this is a test node only) * Make flux work with optimized_attention_override * Add logs to verify optimized_attention_override is passed all the way into attention function * Make Qwen work with optimized_attention_override * Made hidream work with optimized_attention_override * Made wan patches_replace work with optimized_attention_override * Made SD3 work with optimized_attention_override * Made HunyuanVideo work with optimized_attention_override * Made Mochi work with optimized_attention_override * Made LTX work with optimized_attention_override * Made StableAudio work with optimized_attention_override * Made optimized_attention_override work with ACE Step * Made Hunyuan3D work with optimized_attention_override * Make CosmosPredict2 work with optimized_attention_override * Made CosmosVideo work with optimized_attention_override * Made Omnigen 2 work with optimized_attention_override * Made StableCascade work with optimized_attention_override * Made AuraFlow work with optimized_attention_override * Made Lumina work with optimized_attention_override * Made Chroma work with optimized_attention_override * Made SVD work with optimized_attention_override * Fix WanI2VCrossAttention so that it expects to receive transformer_options * Fixed Wan2.1 Fun Camera transformer_options passthrough * Fixed WAN 2.1 VACE transformer_options passthrough * Add optimized to get_attention_function * Disable attention logs for now * Remove attention logging code * Remove _register_core_attention_functions, as we wouldn't want someone to call that, just in case * Satisfy ruff * Remove AttentionOverrideTest node, that's something to cook up for later
This commit is contained in:
@@ -314,6 +314,7 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
output_length: int = 0,
|
||||
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||
transformer_options={},
|
||||
):
|
||||
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
|
||||
temb = self.t_block(embedded_timestep)
|
||||
@@ -339,6 +340,7 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
|
||||
temb=temb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
output = self.final_layer(hidden_states, embedded_timestep, output_length)
|
||||
@@ -393,6 +395,7 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
|
||||
output_length = hidden_states.shape[-1]
|
||||
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
output = self.decode(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
@@ -402,6 +405,7 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
output_length=output_length,
|
||||
block_controlnet_hidden_states=block_controlnet_hidden_states,
|
||||
controlnet_scale=controlnet_scale,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user