Stable Cascade Stage C.

This commit is contained in:
comfyanonymous
2024-02-16 10:55:08 -05:00
parent 5e06baf112
commit f83109f09b
11 changed files with 619 additions and 31 deletions

View File

@@ -450,15 +450,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
clip_target = None
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
class WeightsLoader(torch.nn.Module):
pass
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
model_config.set_manual_cast(manual_cast_dtype)
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.")
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
@@ -507,16 +507,15 @@ def load_unet_state_dict(sd): #load unet in diffusers format
parameters = comfy.utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
if "input_blocks.0.0.weight" in sd: #ldm
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade
model_config = model_detection.model_config_from_unet(sd, "")
if model_config is None:
return None
new_sd = sd
else: #diffusers
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
model_config = model_detection.model_config_from_diffusers_unet(sd)
if model_config is None:
return None
@@ -528,8 +527,11 @@ def load_unet_state_dict(sd): #load unet in diffusers format
new_sd[diffusers_keys[k]] = sd.pop(k)
else:
print(diffusers_keys[k], k)
offload_device = model_management.unet_offload_device()
model_config.set_manual_cast(manual_cast_dtype)
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")