Better per-model memory usage estimations.

comfyanonymous
2024-08-02 18:08:21 -04:00
parent 3a9ee995cf
commit ea03c9dcd2
3 changed files with 17 additions and 25 deletions

@@ -31,6 +31,7 @@ class SD15(supported_models_base.BASE):
     }
     latent_format = latent_formats.SD15
+    memory_usage_factor = 1.0
     def process_clip_state_dict(self, state_dict):
         k = list(state_dict.keys())
@@ -77,6 +78,7 @@ class SD20(supported_models_base.BASE):
     }
     latent_format = latent_formats.SD15
+    memory_usage_factor = 1.0
     def model_type(self, state_dict, prefix=""):
         if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
@@ -140,6 +142,7 @@ class SDXLRefiner(supported_models_base.BASE):
     }
     latent_format = latent_formats.SDXL
+    memory_usage_factor = 1.0
     def get_model(self, state_dict, prefix="", device=None):
         return model_base.SDXLRefiner(self, device=device)
@@ -178,6 +181,8 @@ class SDXL(supported_models_base.BASE):
     latent_format = latent_formats.SDXL
+    memory_usage_factor = 0.7
     def model_type(self, state_dict, prefix=""):
         if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
             self.latent_format = latent_formats.SDXL_Playground_2_5()
@@ -505,6 +510,9 @@ class SD3(supported_models_base.BASE):
     unet_extra_config = {}
     latent_format = latent_formats.SD3
+    memory_usage_factor = 1.2
     text_encoder_key_prefix = ["text_encoders."]
     def get_model(self, state_dict, prefix="", device=None):
@@ -631,6 +639,9 @@ class Flux(supported_models_base.BASE):
     unet_extra_config = {}
     latent_format = latent_formats.Flux
+    memory_usage_factor = 2.6
     supported_inference_dtypes = [torch.bfloat16, torch.float32]
     vae_key_prefix = ["vae."]
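
The new memory_usage_factor class attribute gives each model family a relative weight that a memory estimator can multiply into its per-sample estimate, so heavier architectures such as Flux (2.6) reserve proportionally more VRAM than, say, SDXL (0.7). Below is a minimal, self-contained sketch of how such a factor could be folded into an estimate; the estimate_inference_memory helper, the baseline constant, and the dtype table are illustrative assumptions for this example, not the actual formulas in the changed model_base.py / model_management.py code.

import math

# Hypothetical lookup mirroring the per-family values added in this commit.
MEMORY_USAGE_FACTORS = {
    "SD15": 1.0,
    "SD20": 1.0,
    "SDXLRefiner": 1.0,
    "SDXL": 0.7,
    "SD3": 1.2,
    "Flux": 2.6,
}

DTYPE_SIZES = {"float32": 4, "bfloat16": 2, "float16": 2}  # bytes per element


def estimate_inference_memory(model_name, latent_shape, dtype="float16",
                              base_bytes_per_element=1024):
    """Illustrative estimate only: scale a shape- and dtype-based baseline
    by the model family's memory_usage_factor. The baseline constant is a
    placeholder, not the coefficient ComfyUI uses internally."""
    factor = MEMORY_USAGE_FACTORS[model_name]
    batch, _channels, height, width = latent_shape
    area = batch * height * width
    return area * DTYPE_SIZES[dtype] * base_bytes_per_element * factor


if __name__ == "__main__":
    # 1024x1024 image at the usual 8x latent downscale -> 128x128 latent.
    shape = (1, 4, 128, 128)
    for name in ("SDXL", "Flux"):
        est = estimate_inference_memory(name, shape)
        print(f"{name}: ~{est / (1024 ** 3):.2f} GiB estimated")

Because the factor is a plain class attribute, tuning the estimate for a new model family only requires setting one number on its supported_models entry rather than changing the estimator itself.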