This commit is contained in:
comfyanonymous
2025-11-25 07:50:19 -08:00
committed by GitHub
parent 015a0599d0
commit 6b573ae0cb
12 changed files with 506 additions and 68 deletions

View File

@@ -2,7 +2,10 @@ import node_helpers
import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import comfy.model_management
import torch
import math
import nodes
class CLIPTextEncodeFlux(io.ComfyNode):
@classmethod
@@ -30,6 +33,27 @@ class CLIPTextEncodeFlux(io.ComfyNode):
encode = execute # TODO: remove
class EmptyFlux2LatentImage(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="EmptyFlux2LatentImage",
display_name="Empty Flux 2 Latent",
category="latent",
inputs=[
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 128, height // 16, width // 16], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent})
class FluxGuidance(io.ComfyNode):
@classmethod
@@ -154,6 +178,58 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
append = execute # TODO: remove
def generalized_time_snr_shift(t, mu: float, sigma: float):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
a1, b1 = 8.73809524e-05, 1.89833333
a2, b2 = 0.00016927, 0.45666666
if image_seq_len > 4300:
mu = a2 * image_seq_len + b2
return float(mu)
m_200 = a2 * image_seq_len + b2
m_10 = a1 * image_seq_len + b1
a = (m_200 - m_10) / 190.0
b = m_200 - 200.0 * a
mu = a * num_steps + b
return float(mu)
def get_schedule(num_steps: int, image_seq_len: int) -> list[float]:
mu = compute_empirical_mu(image_seq_len, num_steps)
timesteps = torch.linspace(1, 0, num_steps + 1)
timesteps = generalized_time_snr_shift(timesteps, mu, 1.0)
return timesteps
class Flux2Scheduler(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Flux2Scheduler",
category="sampling/custom_sampling/schedulers",
inputs=[
io.Int.Input("steps", default=20, min=1, max=4096),
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1),
io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[
io.Sigmas.Output(),
],
)
@classmethod
def execute(cls, steps, width, height) -> io.NodeOutput:
seq_len = (width * height / (16 * 16))
sigmas = get_schedule(steps, round(seq_len))
return io.NodeOutput(sigmas)
class FluxExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
@@ -163,6 +239,8 @@ class FluxExtension(ComfyExtension):
FluxDisableGuidance,
FluxKontextImageScale,
FluxKontextMultiReferenceLatentMethod,
EmptyFlux2LatentImage,
Flux2Scheduler,
]