Reduce Peak WAN inference VRAM usage (#9898)
* flux: Do the xq and xk ropes one at a time This was doing independendent interleaved tensor math on the q and k tensors, leading to the holding of more than the minimum intermediates in VRAM. On a bad day, it would VRAM OOM on xk intermediates. Do everything q and then everything k, so torch can garbage collect all of qs intermediates before k allocates its intermediates. This reduces peak VRAM usage for some WAN2.2 inferences (at least). * wan: Optimize qkv intermediates on attention As commented. The former logic computed independent pieces of QKV in parallel which help more inference intermediates in VRAM spiking VRAM usage. Fully roping Q and garbage collecting the intermediates before touching K reduces the peak inference VRAM usage.
This commit is contained in:
@@ -8,7 +8,7 @@ from einops import rearrange
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@@ -60,20 +60,24 @@ class WanSelfAttention(nn.Module):
|
||||
"""
|
||||
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
||||
|
||||
# query, key, value function
|
||||
def qkv_fn(x):
|
||||
def qkv_fn_q(x):
|
||||
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
||||
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
||||
v = self.v(x).view(b, s, n * d)
|
||||
return q, k, v
|
||||
return apply_rope1(q, freqs)
|
||||
|
||||
q, k, v = qkv_fn(x)
|
||||
q, k = apply_rope(q, k, freqs)
|
||||
def qkv_fn_k(x):
|
||||
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
||||
return apply_rope1(k, freqs)
|
||||
|
||||
#These two are VRAM hogs, so we want to do all of q computation and
|
||||
#have pytorch garbage collect the intermediates on the sub function
|
||||
#return before we touch k
|
||||
q = qkv_fn_q(x)
|
||||
k = qkv_fn_k(x)
|
||||
|
||||
x = optimized_attention(
|
||||
q.view(b, s, n * d),
|
||||
k.view(b, s, n * d),
|
||||
v,
|
||||
self.v(x).view(b, s, n * d),
|
||||
heads=self.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user