Lower vram usage for flux 2 text encoder. (#10887)

This commit is contained in:
comfyanonymous
2025-11-25 11:58:39 -08:00
committed by GitHub
parent 18b79acba9
commit d196a905bb
3 changed files with 15 additions and 8 deletions

View File

@@ -138,7 +138,7 @@ class Flux2Tokenizer(sd1_clip.SD1Tokenizer):
return tokens
class Mistral3_24BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
def __init__(self, device="cpu", layer=[10, 20, 30], layer_idx=None, dtype=None, attention_mask=True, model_options={}):
textmodel_json_config = {}
num_layers = model_options.get("num_layers", None)
if num_layers is not None:
@@ -154,7 +154,7 @@ class Flux2TEModel(sd1_clip.SD1ClipModel):
def encode_token_weights(self, token_weight_pairs):
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
out = torch.stack((out[:, 10], out[:, 20], out[:, 30]), dim=1)
out = torch.stack((out[:, 0], out[:, 1], out[:, 2]), dim=1)
out = out.movedim(1, 2)
out = out.reshape(out.shape[0], out.shape[1], -1)
return out, pooled, extra