feat: per-guide attention strength control in self-attention (#12518)

Implements per-guide attention attenuation via log-space additive bias
in self-attention. Each guide reference tracks its own strength and
optional spatial mask in conditioning metadata (guide_attention_entries).
This commit is contained in:
Tavi Halperin
2026-02-26 08:25:23 +02:00
committed by GitHub
parent 907e5dcbbf
commit a4522017c5
4 changed files with 352 additions and 12 deletions
+45 -2
View File
@@ -134,6 +134,36 @@ class LTXVImgToVideoInplace(io.ComfyNode):
generate = execute # TODO: remove
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0):
"""Append a guide_attention_entry to both positive and negative conditioning.
Each entry tracks one guide reference for per-reference attention control.
Entries are derived independently from each conditioning to avoid cross-contamination.
"""
new_entry = {
"pre_filter_count": pre_filter_count,
"strength": strength,
"pixel_mask": None,
"latent_shape": latent_shape,
}
results = []
for cond in (positive, negative):
# Read existing entries from this specific conditioning
existing = []
for t in cond:
found = t[1].get("guide_attention_entries", None)
if found is not None:
existing = found
break
# Shallow copy and append (no deepcopy needed — entries contain
# only scalars and None for pixel_mask at this call site).
entries = [*existing, new_entry]
results.append(node_helpers.conditioning_set_values(
cond, {"guide_attention_entries": entries}
))
return results[0], results[1]
def conditioning_get_any_value(conditioning, key, default=None):
for t in conditioning:
if key in t[1]:
@@ -324,6 +354,13 @@ class LTXVAddGuide(io.ComfyNode):
scale_factors,
)
# Track this guide for per-reference attention control.
pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
guide_latent_shape = list(t.shape[2:]) # [F, H, W]
positive, negative = _append_guide_attention_entry(
positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
)
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
generate = execute # TODO: remove
@@ -359,8 +396,14 @@ class LTXVCropGuides(io.ComfyNode):
latent_image = latent_image[:, :, :-num_keyframes]
noise_mask = noise_mask[:, :, :-num_keyframes]
positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})
positive = node_helpers.conditioning_set_values(positive, {
"keyframe_idxs": None,
"guide_attention_entries": None,
})
negative = node_helpers.conditioning_set_values(negative, {
"keyframe_idxs": None,
"guide_attention_entries": None,
})
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})