Support loading unet files in diffusers format.

This commit is contained in:
comfyanonymous
2023-07-05 17:34:45 -04:00
parent e57cba4c61
commit af7a49916b
9 changed files with 123 additions and 15 deletions

View File

@@ -53,9 +53,9 @@ class SD20(supported_models_base.BASE):
latent_format = latent_formats.SD15
def v_prediction(self, state_dict):
def v_prediction(self, state_dict, prefix=""):
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
out = state_dict[k]
if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
return True
@@ -109,7 +109,7 @@ class SDXLRefiner(supported_models_base.BASE):
latent_format = latent_formats.SDXL
def get_model(self, state_dict):
def get_model(self, state_dict, prefix=""):
return model_base.SDXLRefiner(self)
def process_clip_state_dict(self, state_dict):
@@ -144,7 +144,7 @@ class SDXL(supported_models_base.BASE):
latent_format = latent_formats.SDXL
def get_model(self, state_dict):
def get_model(self, state_dict, prefix=""):
return model_base.SDXL(self)
def process_clip_state_dict(self, state_dict):