Fix LoRA Trainer bugs with FP8 models. (#9854)
* Fix adapter weight init
* Fix fp8 model training
* Avoid inference tensor
@@ -38,6 +38,23 @@ def make_batch_extra_option_dict(d, indicies, full_size=None):
     return new_dict
 
 
+def process_cond_list(d, prefix=""):
+    if hasattr(d, "__iter__") and not hasattr(d, "items"):
+        for index, item in enumerate(d):
+            process_cond_list(item, f"{prefix}.{index}")
+        return d
+    elif hasattr(d, "items"):
+        for k, v in list(d.items()):
+            if isinstance(v, dict):
+                process_cond_list(v, f"{prefix}.{k}")
+            elif isinstance(v, torch.Tensor):
+                d[k] = v.clone()
+            elif isinstance(v, (list, tuple)):
+                for index, item in enumerate(v):
+                    process_cond_list(item, f"{prefix}.{k}.{index}")
+        return d
+
+
 class TrainSampler(comfy.samplers.Sampler):
     def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
         self.loss_fn = loss_fn
@@ -50,6 +67,7 @@ class TrainSampler(comfy.samplers.Sampler):
         self.training_dtype = training_dtype
 
     def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
+        model_wrap.conds = process_cond_list(model_wrap.conds)
        cond = model_wrap.conds["positive"]
        dataset_size = sigmas.size(0)
        torch.cuda.empty_cache()
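For context, here is a minimal, hypothetical sketch (not part of the commit) of the "inference tensor" problem the new `process_cond_list` works around: tensors created under `torch.inference_mode()` cannot be saved for autograd, so reusing cached conditioning tensors during LoRA training raises a RuntimeError. Cloning them outside inference mode, as the added `d[k] = v.clone()` does, produces regular tensors that can participate in the training graph. The tensor shapes and names below are illustrative, not taken from the trainer.

import torch

# Hypothetical reproduction of the failure the commit avoids.
with torch.inference_mode():
    cond = torch.randn(1, 77, 768)  # stand-in for a cached text conditioning (shape is illustrative)

weight = torch.randn(768, 768, requires_grad=True)

try:
    # Autograd must save the inputs of matmul for backward; saving an
    # inference tensor is not allowed, so this raises a RuntimeError.
    loss = (cond @ weight).sum()
    loss.backward()
except RuntimeError as e:
    print("inference tensor error:", e)

# Cloning outside inference_mode yields a normal tensor, so training proceeds.
cond = cond.clone()
loss = (cond @ weight).sum()
loss.backward()

In the trainer, `process_cond_list` recursively walks the nested conds structure and applies this clone to every tensor, which is why `sample()` now runs it on `model_wrap.conds` before training starts.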