Lower vram usage for flux 2 text encoder. (#10887)
This commit is contained in:
@@ -138,7 +138,7 @@ class Flux2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
return tokens
|
||||
|
||||
class Mistral3_24BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
def __init__(self, device="cpu", layer=[10, 20, 30], layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
textmodel_json_config = {}
|
||||
num_layers = model_options.get("num_layers", None)
|
||||
if num_layers is not None:
|
||||
@@ -154,7 +154,7 @@ class Flux2TEModel(sd1_clip.SD1ClipModel):
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
|
||||
|
||||
out = torch.stack((out[:, 10], out[:, 20], out[:, 30]), dim=1)
|
||||
out = torch.stack((out[:, 0], out[:, 1], out[:, 2]), dim=1)
|
||||
out = out.movedim(1, 2)
|
||||
out = out.reshape(out.shape[0], out.shape[1], -1)
|
||||
return out, pooled, extra
|
||||
|
||||
Reference in New Issue
Block a user