FP8 bwd training (#13121)

This commit is contained in:
Kohaku-Blueleaf
2026-03-25 08:39:04 +08:00
committed by GitHub
parent a0a64c679f
commit 5ebb0c2e0b
3 changed files with 59 additions and 16 deletions
+9
View File
@@ -1030,6 +1030,11 @@ class TrainLoraNode(io.ComfyNode):
default="bf16",
tooltip="The dtype to use for lora.",
),
io.Boolean.Input(
"quantized_backward",
default=False,
tooltip="When using training_dtype 'none' and training on quantized model, doing backward with quantized matmul when enabled.",
),
io.Combo.Input(
"algorithm",
options=list(adapter_maps.keys()),
@@ -1097,6 +1102,7 @@ class TrainLoraNode(io.ComfyNode):
seed,
training_dtype,
lora_dtype,
quantized_backward,
algorithm,
gradient_checkpointing,
checkpoint_depth,
@@ -1117,6 +1123,7 @@ class TrainLoraNode(io.ComfyNode):
seed = seed[0]
training_dtype = training_dtype[0]
lora_dtype = lora_dtype[0]
quantized_backward = quantized_backward[0]
algorithm = algorithm[0]
gradient_checkpointing = gradient_checkpointing[0]
offloading = offloading[0]
@@ -1125,6 +1132,8 @@ class TrainLoraNode(io.ComfyNode):
bucket_mode = bucket_mode[0]
bypass_mode = bypass_mode[0]
comfy.model_management.training_fp8_bwd = quantized_backward
# Process latents based on mode
if bucket_mode:
latents = _process_latents_bucket_mode(latents)