Support flux 2 klein kv cache model: Use the FluxKVCache node. (#12905)
This commit is contained in:
@@ -6,6 +6,7 @@ import comfy.model_management
|
||||
import torch
|
||||
import math
|
||||
import nodes
|
||||
import comfy.ldm.flux.math
|
||||
|
||||
class CLIPTextEncodeFlux(io.ComfyNode):
|
||||
@classmethod
|
||||
@@ -231,6 +232,68 @@ class Flux2Scheduler(io.ComfyNode):
|
||||
sigmas = get_schedule(steps, round(seq_len))
|
||||
return io.NodeOutput(sigmas)
|
||||
|
||||
class KV_Attn_Input:
    """Attention input patch that caches reference-image keys/values per block.

    The first call for a given block stores the trailing ``ref_toks`` entries
    (along dim 2) of ``k`` and ``v``. Later calls for the same block append
    those stored tensors to the incoming ``k``/``v`` instead.
    """

    def __init__(self):
        # Maps "<block_type>_<block_index>" -> (cached_k, cached_v).
        self.cache = {}
        # True when the most recent call populated the cache, False when it
        # consumed it. Not read inside this class.
        self.set_cache = False

    def __call__(self, q, k, v, extra_options, **kwargs):
        """Return the (possibly extended) q/k/v for one attention call.

        Args:
            q, k, v: attention tensors; k and v are sliced/concatenated on
                dim 2 (assumed to be the token axis -- TODO confirm).
            extra_options: mapping providing "reference_image_num_tokens"
                (optional list of ints) plus "block_type" and "block_index".
            **kwargs: accepted and ignored.

        Returns:
            ``{}`` when there are no reference tokens, otherwise a dict with
            the keys "q", "k" and "v".
        """
        ref_toks = sum(extra_options.get("reference_image_num_tokens", []))
        # A non-positive total must bail out: `k[:, :, -0:]` would select the
        # whole tensor and cache every token instead of none.
        if ref_toks <= 0:
            return {}

        cache_key = f"{extra_options['block_type']}_{extra_options['block_index']}"
        cached = self.cache.get(cache_key)
        if cached is not None:
            cached_k, cached_v = cached
            self.set_cache = False
            return {
                "q": q,
                "k": torch.cat((k, cached_k), dim=2),
                "v": torch.cat((v, cached_v), dim=2),
            }

        # First visit for this block: remember the trailing reference entries
        # and pass the tensors through untouched.
        self.cache[cache_key] = (k[:, :, -ref_toks:], v[:, :, -ref_toks:])
        self.set_cache = True
        return {"q": q, "k": k, "v": v}

    def cleanup(self):
        """Discard every cached key/value pair."""
        self.cache = {}
|
||||
|
||||
|
||||
class FluxKVCache(io.ComfyNode):
    """Node that returns a clone of a model patched with a KV_Attn_Input cache."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node: one model input, one patched model output."""
        model_input = io.Model.Input("model", tooltip="The model to use KV Cache on.")
        model_output = io.Model.Output(tooltip="The patched model with KV Cache enabled.")
        return io.Schema(
            node_id="FluxKVCache",
            display_name="Flux KV Cache",
            description="Enables KV Cache optimization for reference images on Flux family models.",
            category="",
            is_experimental=True,
            inputs=[model_input],
            outputs=[model_output],
        )

    @classmethod
    def execute(cls, model: io.Model.Type) -> io.NodeOutput:
        """Clone ``model`` and install the KV cache patches on the clone.

        Installs an attn1 patch, a post-input patch and an object patch that
        sets ``default_ref_method`` to "index_timestep_zero".
        """
        patched = model.clone()
        kv_patch = KV_Attn_Input()

        def model_input_patch(inputs):
            # Nothing cached yet: leave the inputs untouched.
            # `kv_patch.cache` is read on every call because cleanup()
            # rebinds it to a fresh dict.
            if len(kv_patch.cache) == 0:
                return inputs

            token_counts = inputs["transformer_options"].get("reference_image_num_tokens", [])
            trailing = sum(token_counts)
            if trailing > 0:
                # Drop the last `trailing` entries along dim 1 of "img".
                inputs["img"] = inputs["img"][:, :-trailing]
            return inputs

        patched.set_model_attn1_patch(kv_patch)
        patched.set_model_post_input_patch(model_input_patch)

        # The attribute sits under `params` when the diffusion model has one,
        # otherwise directly on the diffusion model.
        if hasattr(model.model.diffusion_model, "params"):
            ref_method_key = "diffusion_model.params.default_ref_method"
        else:
            ref_method_key = "diffusion_model.default_ref_method"
        patched.add_object_patch(ref_method_key, "index_timestep_zero")

        return io.NodeOutput(patched)
|
||||
|
||||
class FluxExtension(ComfyExtension):
|
||||
@override
|
||||
@@ -243,6 +306,7 @@ class FluxExtension(ComfyExtension):
|
||||
FluxKontextMultiReferenceLatentMethod,
|
||||
EmptyFlux2LatentImage,
|
||||
Flux2Scheduler,
|
||||
FluxKVCache,
|
||||
]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user