Integrate RAM cache with model RAM management (#13173)

This commit is contained in:
rattus
2026-03-27 18:34:16 -07:00
committed by GitHub
parent 3696c5bad6
commit b353a7c863
9 changed files with 61 additions and 43 deletions
+20 -26
View File
@@ -1,6 +1,5 @@
import asyncio
import bisect
import gc
import itertools
import psutil
import time
@@ -475,6 +474,10 @@ class LRUCache(BasicCache):
self._mark_used(node_id)
return await self._set_immediate(node_id, value)
def set_local(self, node_id, value):
self._mark_used(node_id)
BasicCache.set_local(self, node_id, value)
async def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes
await super()._ensure_subcache(node_id, children_ids)
@@ -489,15 +492,10 @@ class LRUCache(BasicCache):
return self
#Iterating the cache for usage analysis might be expensive, so if we trigger make sure
#to take a chunk out to give breathing space on high-node / low-ram-per-node flows.
#Small baseline weight used when a cache entry has no measurable CPU tensors.
#Keeps unknown-sized entries in eviction scoring without dominating tensor-backed entries.
RAM_CACHE_HYSTERESIS = 1.1
#This is kinda in GB but not really. It needs to be non-zero for the below heuristic
#and as long as Multi GB models dwarf this it will approximate OOM scoring OK
RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
RAM_CACHE_DEFAULT_RAM_USAGE = 0.05
#Exponential bias towards evicting older workflows so garbage will be taken out
#in constantly changing setups.
@@ -521,19 +519,17 @@ class RAMPressureCache(LRUCache):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return await super().get(node_id)
def poll(self, ram_headroom):
def _ram_gb():
return psutil.virtual_memory().available / (1024**3)
def set_local(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
super().set_local(node_id, value)
if _ram_gb() > ram_headroom:
return
gc.collect()
if _ram_gb() > ram_headroom:
def ram_release(self, target):
if psutil.virtual_memory().available >= target:
return
clean_list = []
for key, (outputs, _), in self.cache.items():
for key, cache_entry in self.cache.items():
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
@@ -542,22 +538,20 @@ class RAMPressureCache(LRUCache):
if outputs is None:
return
for output in outputs:
if isinstance(output, list):
if isinstance(output, (list, tuple)):
scan_list_for_ram_usage(output)
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
#score Tensors at a 50% discount for RAM usage as they are likely to
#be high value intermediates
ram_usage += (output.numel() * output.element_size()) * 0.5
elif hasattr(output, "get_ram_usage"):
ram_usage += output.get_ram_usage()
scan_list_for_ram_usage(outputs)
ram_usage += output.numel() * output.element_size()
scan_list_for_ram_usage(cache_entry.outputs)
oom_score *= ram_usage
#In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
while psutil.virtual_memory().available < target and clean_list:
_, _, key = clean_list.pop()
del self.cache[key]
gc.collect()
self.used_generation.pop(key, None)
self.timestamps.pop(key, None)
self.children.pop(key, None)