Fix LoRA Trainer bugs with FP8 models. (#9854)

* Fix adapter weight init

* Fix fp8 model training

* Avoid inference tensor
This commit is contained in:
Kohaku-Blueleaf
2025-09-21 09:24:48 +08:00
committed by GitHub
parent 9ed3c5cc09
commit 7be2b49b6b
6 changed files with 34 additions and 15 deletions

View File

@@ -38,6 +38,23 @@ def make_batch_extra_option_dict(d, indicies, full_size=None):
return new_dict
def process_cond_list(d, prefix=""):
if hasattr(d, "__iter__") and not hasattr(d, "items"):
for index, item in enumerate(d):
process_cond_list(item, f"{prefix}.{index}")
return d
elif hasattr(d, "items"):
for k, v in list(d.items()):
if isinstance(v, dict):
process_cond_list(v, f"{prefix}.{k}")
elif isinstance(v, torch.Tensor):
d[k] = v.clone()
elif isinstance(v, (list, tuple)):
for index, item in enumerate(v):
process_cond_list(item, f"{prefix}.{k}.{index}")
return d
class TrainSampler(comfy.samplers.Sampler):
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
self.loss_fn = loss_fn
@@ -50,6 +67,7 @@ class TrainSampler(comfy.samplers.Sampler):
self.training_dtype = training_dtype
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
model_wrap.conds = process_cond_list(model_wrap.conds)
cond = model_wrap.conds["positive"]
dataset_size = sigmas.size(0)
torch.cuda.empty_cache()