Support loading unet files in diffusers format.
This commit is contained in:
@@ -53,9 +53,9 @@ class SD20(supported_models_base.BASE):
|
||||
|
||||
latent_format = latent_formats.SD15
|
||||
|
||||
def v_prediction(self, state_dict):
|
||||
def v_prediction(self, state_dict, prefix=""):
|
||||
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
|
||||
k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
|
||||
k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
|
||||
out = state_dict[k]
|
||||
if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
|
||||
return True
|
||||
@@ -109,7 +109,7 @@ class SDXLRefiner(supported_models_base.BASE):
|
||||
|
||||
latent_format = latent_formats.SDXL
|
||||
|
||||
def get_model(self, state_dict):
|
||||
def get_model(self, state_dict, prefix=""):
|
||||
return model_base.SDXLRefiner(self)
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
@@ -144,7 +144,7 @@ class SDXL(supported_models_base.BASE):
|
||||
|
||||
latent_format = latent_formats.SDXL
|
||||
|
||||
def get_model(self, state_dict):
|
||||
def get_model(self, state_dict, prefix=""):
|
||||
return model_base.SDXL(self)
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
|
||||
Reference in New Issue
Block a user