Fix VRAM leak in tiler fallback in video VAEs (#13073)

* sd: soft_empty_cache on tiler fallback

This doesnt cost a lot and creates the expected VRAM reduction in
resource monitors when you fallback to tiler.

* wan: vae: Don't recursion in local fns (move run_up)

Moved Decoder3d’s recursive run_up out of forward into a class
method to avoid nested closure self-reference cycles. This avoids
cyclic garbage that delays garbage of tensors which in turn delays
VRAM release before tiled fallback.

* ltx: vae: Don't recursion in local fns (move run_up)

Mov the recursive run_up out of forward into a class
method to avoid nested closure self-reference cycles. This avoids
cyclic garbage that delays garbage of tensors which in turn delays
VRAM release before tiled fallback.
This commit is contained in:
rattus
2026-03-19 19:30:27 -07:00
committed by GitHub
parent 8458ae2686
commit 82b868a45a
3 changed files with 88 additions and 84 deletions
+38 -36
View File
@@ -360,6 +360,43 @@ class Decoder3d(nn.Module):
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, output_channels, 3, padding=1))
def run_up(self, layer_idx, x_ref, feat_cache, feat_idx, out_chunks):
x = x_ref[0]
x_ref[0] = None
if layer_idx >= len(self.upsamples):
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, feat_cache[feat_idx[0]])
feat_cache[feat_idx[0]] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
out_chunks.append(x)
return
layer = self.upsamples[layer_idx]
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1:
for frame_idx in range(x.shape[2]):
self.run_up(
layer_idx,
[x[:, :, frame_idx:frame_idx + 1, :, :]],
feat_cache,
feat_idx.copy(),
out_chunks,
)
del x
return
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
next_x_ref = [x]
del x
self.run_up(layer_idx + 1, next_x_ref, feat_cache, feat_idx, out_chunks)
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
@@ -380,42 +417,7 @@ class Decoder3d(nn.Module):
out_chunks = []
def run_up(layer_idx, x_ref, feat_idx):
x = x_ref[0]
x_ref[0] = None
if layer_idx >= len(self.upsamples):
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, feat_cache[feat_idx[0]])
feat_cache[feat_idx[0]] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
out_chunks.append(x)
return
layer = self.upsamples[layer_idx]
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1:
for frame_idx in range(x.shape[2]):
run_up(
layer_idx,
[x[:, :, frame_idx:frame_idx + 1, :, :]],
feat_idx.copy(),
)
del x
return
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
next_x_ref = [x]
del x
run_up(layer_idx + 1, next_x_ref, feat_idx)
run_up(0, [x], feat_idx)
self.run_up(0, [x], feat_cache, feat_idx, out_chunks)
return out_chunks