Better per model memory usage estimations.
This commit is contained in:
@@ -31,6 +31,7 @@ class SD15(supported_models_base.BASE):
|
||||
}
|
||||
|
||||
latent_format = latent_formats.SD15
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
k = list(state_dict.keys())
|
||||
@@ -77,6 +78,7 @@ class SD20(supported_models_base.BASE):
|
||||
}
|
||||
|
||||
latent_format = latent_formats.SD15
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
def model_type(self, state_dict, prefix=""):
|
||||
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
|
||||
@@ -140,6 +142,7 @@ class SDXLRefiner(supported_models_base.BASE):
|
||||
}
|
||||
|
||||
latent_format = latent_formats.SDXL
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.SDXLRefiner(self, device=device)
|
||||
@@ -178,6 +181,8 @@ class SDXL(supported_models_base.BASE):
|
||||
|
||||
latent_format = latent_formats.SDXL
|
||||
|
||||
memory_usage_factor = 0.7
|
||||
|
||||
def model_type(self, state_dict, prefix=""):
|
||||
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
|
||||
self.latent_format = latent_formats.SDXL_Playground_2_5()
|
||||
@@ -505,6 +510,9 @@ class SD3(supported_models_base.BASE):
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.SD3
|
||||
|
||||
memory_usage_factor = 1.2
|
||||
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
@@ -631,6 +639,9 @@ class Flux(supported_models_base.BASE):
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux
|
||||
|
||||
memory_usage_factor = 2.6
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
|
||||
Reference in New Issue
Block a user