dynamic_vram: Training fixes (#12442)

This commit is contained in:
rattus
2026-02-13 12:29:37 -08:00
committed by GitHub
parent e03fe8b591
commit 8902907d7a
2 changed files with 14 additions and 1 deletions
+4
View File
@@ -1561,6 +1561,8 @@ class ModelPatcherDynamic(ModelPatcher):
allocated_size += weight_size
vbar.set_watermark_limit(allocated_size)
move_weight_functions(m, device_to)
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
self.model.device = device_to
@@ -1601,6 +1603,8 @@ class ModelPatcherDynamic(ModelPatcher):
if unpatch_weights:
self.partially_unload_ram(1e32)
self.partially_unload(None, 1e32)
for m in self.model.modules():
move_weight_functions(m, device_to)
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above