Load the SD3 T5xxl model in the same dtype stored in the checkpoint.
This commit is contained in:
@@ -511,17 +511,20 @@ class SD3(supported_models_base.BASE):
|
||||
clip_l = False
|
||||
clip_g = False
|
||||
t5 = False
|
||||
dtype_t5 = None
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
||||
clip_l = True
|
||||
if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
||||
clip_g = True
|
||||
if "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref) in state_dict:
|
||||
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
|
||||
if t5_key in state_dict:
|
||||
t5 = True
|
||||
dtype_t5 = state_dict[t5_key].dtype
|
||||
|
||||
class SD3ClipModel(sd3_clip.SD3ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, device=device, dtype=dtype)
|
||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
|
||||
|
||||
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, SD3ClipModel)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user