IP2P model loading support.

This is the code to load the model and inference it with only a text
prompt. This commit does not contain the nodes to properly use it with an
image input.

This supports both the original SD1 instructpix2pix model and the
diffusers SDXL one.
This commit is contained in:
comfyanonymous
2024-03-31 01:25:16 -04:00
parent 96b4c757cf
commit 575acb69e4
4 changed files with 84 additions and 6 deletions

View File

@@ -473,6 +473,40 @@ class SD_X4Upscaler(BaseModel):
out['y'] = comfy.conds.CONDRegular(noise_level)
return out
class IP2P:
def extra_conds(self, **kwargs):
out = {}
image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"]
if image is None:
image = torch.zeros_like(noise)
if image.shape[1:] != noise.shape[1:]:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = utils.resize_to_batch_size(image, noise.shape[0])
out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_ip2p_image_in(image))
adm = self.encode_adm(**kwargs)
if adm is not None:
out['y'] = comfy.conds.CONDRegular(adm)
return out
class SD15_instructpix2pix(IP2P, BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device)
self.process_ip2p_image_in = lambda image: image
class SDXL_instructpix2pix(IP2P, SDXL):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device)
# self.process_ip2p_image_in = lambda image: comfy.latent_formats.SDXL().process_in(image)
self.process_ip2p_image_in = lambda image: image
class StableCascade_C(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
super().__init__(model_config, model_type, device=device, unet_model=StageC)