Reduce Peak WAN inference VRAM usage - part II (#10062)

* flux: math: Use _addcmul to avoid expensive VRAM intermediate

The rope process can be the VRAM peak and this intermediate
for the addition result before releasing the original can OOM.
addcmul_ it.

* wan: Delete the self attention before cross attention

This saves VRAM when the cross attention and FFN are in play as the
VRAM peak.
This commit is contained in:
rattus128
2025-09-28 08:14:16 +10:00
committed by GitHub
parent 160698eb41
commit 653ceab414
2 changed files with 5 additions and 1 deletions

View File

@@ -237,6 +237,7 @@ class WanAttentionBlock(nn.Module):
freqs, transformer_options=transformer_options)
x = torch.addcmul(x, y, repeat_e(e[2], x))
del y
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)