Cleanups to the last PR. (#12646)

This commit is contained in:
comfyanonymous
2026-02-25 22:30:31 -08:00
committed by GitHub
parent a4522017c5
commit 8a4d85c708
2 changed files with 22 additions and 39 deletions
+2 -38
View File
@@ -65,42 +65,6 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
class _CONDGuideEntries(comfy.conds.CONDConstant):
"""CONDConstant subclass that safely compares guide_attention_entries.
guide_attention_entries may contain ``pixel_mask`` tensors. The default
``CONDConstant.can_concat`` uses ``!=`` which triggers a ``ValueError``
on tensors. This subclass performs a structural comparison instead.
"""
def can_concat(self, other):
if not isinstance(other, _CONDGuideEntries):
return False
a, b = self.cond, other.cond
if len(a) != len(b):
return False
for ea, eb in zip(a, b):
if ea["pre_filter_count"] != eb["pre_filter_count"]:
return False
if ea["strength"] != eb["strength"]:
return False
if ea.get("latent_shape") != eb.get("latent_shape"):
return False
a_has = ea.get("pixel_mask") is not None
b_has = eb.get("pixel_mask") is not None
if a_has != b_has:
return False
if a_has:
pm_a, pm_b = ea["pixel_mask"], eb["pixel_mask"]
if pm_a is not pm_b:
if (pm_a.shape != pm_b.shape
or pm_a.device != pm_b.device
or pm_a.dtype != pm_b.dtype
or not torch.equal(pm_a, pm_b)):
return False
return True
class ModelType(Enum):
EPS = 1
V_PREDICTION = 2
@@ -1012,7 +976,7 @@ class LTXV(BaseModel):
guide_attention_entries = kwargs.get("guide_attention_entries", None)
if guide_attention_entries is not None:
out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries)
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
return out
@@ -1068,7 +1032,7 @@ class LTXAV(BaseModel):
guide_attention_entries = kwargs.get("guide_attention_entries", None)
if guide_attention_entries is not None:
out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries)
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
return out