Reduce Peak WAN inference VRAM usage (#9898)

* flux: Do the xq and xk ropes one at a time

This was doing independendent interleaved tensor math on the q and k
tensors, leading to the holding of more than the minimum intermediates
in VRAM. On a bad day, it would VRAM OOM on xk intermediates.

Do everything q and then everything k, so torch can garbage collect
all of qs intermediates before k allocates its intermediates.

This reduces peak VRAM usage for some WAN2.2 inferences (at least).

* wan: Optimize qkv intermediates on attention

As commented. The former logic computed independent pieces of QKV in
parallel which help more inference intermediates in VRAM spiking
VRAM usage. Fully roping Q and garbage collecting the intermediates
before touching K reduces the peak inference VRAM usage.
This commit is contained in:
rattus128
2025-09-17 09:21:14 +10:00
committed by GitHub
parent a39ac59c3e
commit e42682b24e
2 changed files with 18 additions and 15 deletions

View File

@@ -35,11 +35,10 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.to(dtype=torch.float32, device=pos.device)
def apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
return x_out.reshape(*x.shape).type_as(x)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)