Add a LoraLoader node to apply LoRAs to the model and CLIP.
The models are patched in place before being used and unpatched afterwards. I think this is better than monkey patching since it might make it easier to use faster non-PyTorch UNet inference in the future.
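For context, the snippet below is a minimal sketch of what "patched in place / unpatched afterwards" means here. The ToyLoraPatcher class, its parameter bookkeeping, and the weight arithmetic are illustrative assumptions only; the real logic sits behind model.patch_model() and model.unpatch_model() in comfy.sd.

import torch

# Illustrative sketch only (not the comfy.sd implementation): add scaled LoRA
# deltas directly to the model's own weights, use the model, then restore the
# saved originals so the base checkpoint is left untouched.
class ToyLoraPatcher:
    def __init__(self, model):
        self.model = model          # a torch.nn.Module
        self.patches = []           # (param_name, delta_tensor, strength)
        self.backup = {}            # original weights, restored on unpatch

    def add_patch(self, name, delta, strength):
        self.patches.append((name, delta, strength))

    def patch_model(self):
        # Modify the existing parameters in place.
        params = dict(self.model.named_parameters())
        for name, delta, strength in self.patches:
            w = params[name]
            self.backup[name] = w.detach().clone()
            w.data += strength * delta.to(device=w.device, dtype=w.dtype)
        return self.model           # the patched model is what actually gets used

    def unpatch_model(self):
        # Put the original weights back.
        params = dict(self.model.named_parameters())
        for name, original in self.backup.items():
            params[name].data.copy_(original)
        self.backup = {}

Because the weights themselves are rewritten rather than forward methods being wrapped, the patched model stays an ordinary module, which is what should make it easier to hand it off to a faster non-PyTorch UNet backend later.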
nodes.py
@@ -130,6 +130,27 @@ class CheckpointLoader:
         embedding_directory = os.path.join(self.models_dir, "embeddings")
         return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=embedding_directory)
 
+class LoraLoader:
+    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
+    lora_dir = os.path.join(models_dir, "loras")
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "clip": ("CLIP", ),
+                              "lora_name": (filter_files_extensions(os.listdir(s.lora_dir), supported_pt_extensions), ),
+                              "strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                              "strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                              }}
+    RETURN_TYPES = ("MODEL", "CLIP")
+    FUNCTION = "load_lora"
+
+    CATEGORY = "loaders"
+
+    def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
+        lora_path = os.path.join(self.lora_dir, lora_name)
+        model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
+        return (model_lora, clip_lora)
+
 class VAELoader:
     models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
     vae_dir = os.path.join(models_dir, "vae")
@@ -268,35 +289,43 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
     else:
         noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu")
 
-    model = model.to(device)
-    noise = noise.to(device)
-    latent_image = latent_image.to(device)
-
-    positive_copy = []
-    negative_copy = []
-
-    for p in positive:
-        t = p[0]
-        if t.shape[0] < noise.shape[0]:
-            t = torch.cat([t] * noise.shape[0])
-        t = t.to(device)
-        positive_copy += [[t] + p[1:]]
-    for n in negative:
-        t = n[0]
-        if t.shape[0] < noise.shape[0]:
-            t = torch.cat([t] * noise.shape[0])
-        t = t.to(device)
-        negative_copy += [[t] + n[1:]]
-
-    if sampler_name in comfy.samplers.KSampler.SAMPLERS:
-        sampler = comfy.samplers.KSampler(model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
-    else:
-        #other samplers
-        pass
-
-    samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise)
-    samples = samples.cpu()
-    model = model.cpu()
+    try:
+        real_model = model.patch_model()
+        real_model.to(device)
+        noise = noise.to(device)
+        latent_image = latent_image.to(device)
+
+        positive_copy = []
+        negative_copy = []
+
+        for p in positive:
+            t = p[0]
+            if t.shape[0] < noise.shape[0]:
+                t = torch.cat([t] * noise.shape[0])
+            t = t.to(device)
+            positive_copy += [[t] + p[1:]]
+        for n in negative:
+            t = n[0]
+            if t.shape[0] < noise.shape[0]:
+                t = torch.cat([t] * noise.shape[0])
+            t = t.to(device)
+            negative_copy += [[t] + n[1:]]
+
+        if sampler_name in comfy.samplers.KSampler.SAMPLERS:
+            sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
+        else:
+            #other samplers
+            pass
+
+        samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise)
+        samples = samples.cpu()
+        real_model.cpu()
+        model.unpatch_model()
+    except Exception as e:
+        real_model.cpu()
+        model.unpatch_model()
+        raise e
+
     return (samples, )
 
 class KSampler:
@@ -452,6 +481,7 @@ NODE_CLASS_MAPPINGS = {
     "LatentComposite": LatentComposite,
     "LatentRotate": LatentRotate,
     "LatentFlip": LatentFlip,
+    "LoraLoader": LoraLoader,
 }
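As a rough usage sketch of the new node outside the graph executor: the model and clip inputs are assumed to come from a CheckpointLoader upstream (the load_checkpoint call and both filenames below are placeholder assumptions, not part of this commit), while the load_lora signature matches the class added above.

# Hypothetical direct invocation of the node classes; normally the graph executor calls them.
checkpoint_loader = CheckpointLoader()
model, clip, vae = checkpoint_loader.load_checkpoint("v1-inference.yaml", "model.ckpt")  # placeholder names

lora_loader = LoraLoader()
model_lora, clip_lora = lora_loader.load_lora(
    model, clip,
    lora_name="example_lora.safetensors",  # must exist under models/loras
    strength_model=1.0,
    strength_clip=1.0,
)
# model_lora / clip_lora then feed the downstream sampling and text-encode nodes as usual.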