feat(api-nodes): network client v2: async ops, cancellation, downloads, refactor (#10390)

* feat(api-nodes): implement new API client for V3 nodes

* feat(api-nodes): implement new API client for V3 nodes

* feat(api-nodes): implement new API client for V3 nodes

* converted WAN nodes to use new client; polishing

* fix(auth): do not leak authentication for absolute URLs

* convert BFL API nodes to use new API client; remove deprecated BFL nodes

* converted Google Veo nodes

* fix(Veo3.1 model): take the "generate_audio" parameter into account
Author: Alexander Piskun
Date: 2025-10-24 08:37:16 +03:00
Committed by: GitHub
Parent: 24188b3141
Commit: 388b306a2b
29 changed files with 2935 additions and 2298 deletions

__init__.py

@@ -0,0 +1,87 @@
from ._helpers import get_fs_object_size
from .client import (
ApiEndpoint,
poll_op,
poll_op_raw,
sync_op,
sync_op_raw,
)
from .conversions import (
audio_bytes_to_audio_input,
audio_input_to_mp3,
audio_to_base64_string,
bytesio_to_image_tensor,
downscale_image_tensor,
image_tensor_pair_to_batch,
pil_to_bytesio,
tensor_to_base64_string,
tensor_to_bytesio,
tensor_to_pil,
trim_video,
)
from .download_helpers import (
download_url_to_bytesio,
download_url_to_image_tensor,
download_url_to_video_output,
)
from .upload_helpers import (
upload_audio_to_comfyapi,
upload_file_to_comfyapi,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
)
from .validation_utils import (
get_number_of_images,
validate_aspect_ratio_closeness,
validate_audio_duration,
validate_container_format_is_mp4,
validate_image_aspect_ratio,
validate_image_aspect_ratio_range,
validate_image_dimensions,
validate_string,
validate_video_dimensions,
validate_video_duration,
)
__all__ = [
# API client
"ApiEndpoint",
"poll_op",
"poll_op_raw",
"sync_op",
"sync_op_raw",
# Upload helpers
"upload_audio_to_comfyapi",
"upload_file_to_comfyapi",
"upload_images_to_comfyapi",
"upload_video_to_comfyapi",
# Download helpers
"download_url_to_bytesio",
"download_url_to_image_tensor",
"download_url_to_video_output",
# Conversions
"audio_bytes_to_audio_input",
"audio_input_to_mp3",
"audio_to_base64_string",
"bytesio_to_image_tensor",
"downscale_image_tensor",
"image_tensor_pair_to_batch",
"pil_to_bytesio",
"tensor_to_base64_string",
"tensor_to_bytesio",
"tensor_to_pil",
"trim_video",
# Validation utilities
"get_number_of_images",
"validate_aspect_ratio_closeness",
"validate_audio_duration",
"validate_container_format_is_mp4",
"validate_image_aspect_ratio",
"validate_image_aspect_ratio_range",
"validate_image_dimensions",
"validate_string",
"validate_video_dimensions",
"validate_video_duration",
# Misc functions
"get_fs_object_size",
]

_helpers.py

@@ -0,0 +1,71 @@
import asyncio
import contextlib
import os
import time
from io import BytesIO
from typing import Callable, Optional, Union
from comfy.cli_args import args
from comfy.model_management import processing_interrupted
from comfy_api.latest import IO
from .common_exceptions import ProcessingInterrupted
def is_processing_interrupted() -> bool:
"""Return True if user/runtime requested interruption."""
return processing_interrupted()
def get_node_id(node_cls: type[IO.ComfyNode]) -> str:
return node_cls.hidden.unique_id
def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
if node_cls.hidden.auth_token_comfy_org:
return {"Authorization": f"Bearer {node_cls.hidden.auth_token_comfy_org}"}
if node_cls.hidden.api_key_comfy_org:
return {"X-API-KEY": node_cls.hidden.api_key_comfy_org}
return {}
def default_base_url() -> str:
return getattr(args, "comfy_api_base", "https://api.comfy.org")
async def sleep_with_interrupt(
seconds: float,
node_cls: Optional[type[IO.ComfyNode]],
label: Optional[str] = None,
start_ts: Optional[float] = None,
estimated_total: Optional[int] = None,
*,
display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None,
):
"""
Sleep in 1s slices while:
- Checking for interruption (raises ProcessingInterrupted).
- Optionally emitting time progress via display_callback (if provided).
"""
end = time.monotonic() + seconds
while True:
if is_processing_interrupted():
raise ProcessingInterrupted("Task cancelled")
now = time.monotonic()
if start_ts is not None and label and display_callback:
with contextlib.suppress(Exception):
display_callback(node_cls, label, int(now - start_ts), estimated_total)
if now >= end:
break
await asyncio.sleep(min(1.0, end - now))
def mimetype_to_extension(mime_type: str) -> str:
"""Converts a MIME type to a file extension."""
return mime_type.split("/")[-1].lower()
def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int:
if isinstance(path_or_object, str):
return os.path.getsize(path_or_object)
return len(path_or_object.getvalue())

client.py

@@ -0,0 +1,941 @@
import asyncio
import contextlib
import json
import logging
import socket
import time
import uuid
from dataclasses import dataclass
from enum import Enum
from io import BytesIO
from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
from urllib.parse import urljoin, urlparse
import aiohttp
from aiohttp.client_exceptions import ClientError, ContentTypeError
from pydantic import BaseModel
from comfy import utils
from comfy_api.latest import IO
from comfy_api_nodes.apis import request_logger
from server import PromptServer
from ._helpers import (
default_base_url,
get_auth_header,
get_node_id,
is_processing_interrupted,
sleep_with_interrupt,
)
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
M = TypeVar("M", bound=BaseModel)
class ApiEndpoint:
def __init__(
self,
path: str,
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
*,
query_params: Optional[dict[str, Any]] = None,
headers: Optional[dict[str, str]] = None,
):
self.path = path
self.method = method
self.query_params = query_params or {}
self.headers = headers or {}
@dataclass
class _RequestConfig:
node_cls: type[IO.ComfyNode]
endpoint: ApiEndpoint
timeout: float
content_type: str
data: Optional[dict[str, Any]]
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]]
multipart_parser: Optional[Callable]
max_retries: int
retry_delay: float
retry_backoff: float
wait_label: str = "Waiting"
monitor_progress: bool = True
estimated_total: Optional[int] = None
final_label_on_success: Optional[str] = "Completed"
progress_origin_ts: Optional[float] = None
@dataclass
class _PollUIState:
started: float
status_label: str = "Queued"
is_queued: bool = True
price: Optional[float] = None
estimated_duration: Optional[int] = None
base_processing_elapsed: float = 0.0 # sum of completed active intervals
active_since: Optional[float] = None # start time of current active interval (None if queued)
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"]
FAILED_STATUSES = ["cancelled", "canceled", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"]
async def sync_op(
cls: type[IO.ComfyNode],
endpoint: ApiEndpoint,
*,
response_model: Type[M],
data: Optional[BaseModel] = None,
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
content_type: str = "application/json",
timeout: float = 3600.0,
multipart_parser: Optional[Callable] = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
wait_label: str = "Waiting for server",
estimated_duration: Optional[int] = None,
final_label_on_success: Optional[str] = "Completed",
progress_origin_ts: Optional[float] = None,
monitor_progress: bool = True,
) -> M:
raw = await sync_op_raw(
cls,
endpoint,
data=data,
files=files,
content_type=content_type,
timeout=timeout,
multipart_parser=multipart_parser,
max_retries=max_retries,
retry_delay=retry_delay,
retry_backoff=retry_backoff,
wait_label=wait_label,
estimated_duration=estimated_duration,
as_binary=False,
final_label_on_success=final_label_on_success,
progress_origin_ts=progress_origin_ts,
monitor_progress=monitor_progress,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
return _validate_or_raise(response_model, raw)
async def poll_op(
cls: type[IO.ComfyNode],
poll_endpoint: ApiEndpoint,
*,
response_model: Type[M],
status_extractor: Callable[[M], Optional[Union[str, int]]],
progress_extractor: Optional[Callable[[M], Optional[int]]] = None,
price_extractor: Optional[Callable[[M], Optional[float]]] = None,
completed_statuses: Optional[list[Union[str, int]]] = None,
failed_statuses: Optional[list[Union[str, int]]] = None,
queued_statuses: Optional[list[Union[str, int]]] = None,
data: Optional[BaseModel] = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0,
estimated_duration: Optional[int] = None,
cancel_endpoint: Optional[ApiEndpoint] = None,
cancel_timeout: float = 10.0,
) -> M:
raw = await poll_op_raw(
cls,
poll_endpoint=poll_endpoint,
status_extractor=_wrap_model_extractor(response_model, status_extractor),
progress_extractor=_wrap_model_extractor(response_model, progress_extractor),
price_extractor=_wrap_model_extractor(response_model, price_extractor),
completed_statuses=completed_statuses,
failed_statuses=failed_statuses,
queued_statuses=queued_statuses,
data=data,
poll_interval=poll_interval,
max_poll_attempts=max_poll_attempts,
timeout_per_poll=timeout_per_poll,
max_retries_per_poll=max_retries_per_poll,
retry_delay_per_poll=retry_delay_per_poll,
retry_backoff_per_poll=retry_backoff_per_poll,
estimated_duration=estimated_duration,
cancel_endpoint=cancel_endpoint,
cancel_timeout=cancel_timeout,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
return _validate_or_raise(response_model, raw)
async def sync_op_raw(
cls: type[IO.ComfyNode],
endpoint: ApiEndpoint,
*,
data: Optional[Union[dict[str, Any], BaseModel]] = None,
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
content_type: str = "application/json",
timeout: float = 3600.0,
multipart_parser: Optional[Callable] = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
wait_label: str = "Waiting for server",
estimated_duration: Optional[int] = None,
as_binary: bool = False,
final_label_on_success: Optional[str] = "Completed",
progress_origin_ts: Optional[float] = None,
monitor_progress: bool = True,
) -> Union[dict[str, Any], bytes]:
"""
Make a single network request.
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
- If as_binary=True: returns bytes.
"""
if isinstance(data, BaseModel):
data = data.model_dump(exclude_none=True)
for k, v in list(data.items()):
if isinstance(v, Enum):
data[k] = v.value
cfg = _RequestConfig(
node_cls=cls,
endpoint=endpoint,
timeout=timeout,
content_type=content_type,
data=data,
files=files,
multipart_parser=multipart_parser,
max_retries=max_retries,
retry_delay=retry_delay,
retry_backoff=retry_backoff,
wait_label=wait_label,
monitor_progress=monitor_progress,
estimated_total=estimated_duration,
final_label_on_success=final_label_on_success,
progress_origin_ts=progress_origin_ts,
)
return await _request_base(cfg, expect_binary=as_binary)
async def poll_op_raw(
cls: type[IO.ComfyNode],
poll_endpoint: ApiEndpoint,
*,
status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]],
progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None,
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
completed_statuses: Optional[list[Union[str, int]]] = None,
failed_statuses: Optional[list[Union[str, int]]] = None,
queued_statuses: Optional[list[Union[str, int]]] = None,
data: Optional[Union[dict[str, Any], BaseModel]] = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0,
estimated_duration: Optional[int] = None,
cancel_endpoint: Optional[ApiEndpoint] = None,
cancel_timeout: float = 10.0,
) -> dict[str, Any]:
"""
Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing,
checks interruption every second, and calls Cancel endpoint (if provided) on interruption.
Uses default complete, failed and queued states assumption.
Returns the final JSON response from the poll endpoint.
"""
completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses)
queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses)
started = time.monotonic()
consumed_attempts = 0 # counts only non-queued polls
progress_bar = utils.ProgressBar(100) if progress_extractor else None
last_progress: Optional[int] = None
state = _PollUIState(started=started, estimated_duration=estimated_duration)
stop_ticker = asyncio.Event()
async def _ticker():
"""Emit a UI update every second while polling is in progress."""
try:
while not stop_ticker.is_set():
if is_processing_interrupted():
break
now = time.monotonic()
proc_elapsed = state.base_processing_elapsed + (
(now - state.active_since) if state.active_since is not None else 0.0
)
_display_time_progress(
cls,
status=state.status_label,
elapsed_seconds=int(now - state.started),
estimated_total=state.estimated_duration,
price=state.price,
is_queued=state.is_queued,
processing_elapsed_seconds=int(proc_elapsed),
)
await asyncio.sleep(1.0)
except Exception as exc:
logging.debug("Polling ticker exited: %s", exc)
ticker_task = asyncio.create_task(_ticker())
try:
while consumed_attempts < max_poll_attempts:
try:
resp_json = await sync_op_raw(
cls,
poll_endpoint,
data=data,
timeout=timeout_per_poll,
max_retries=max_retries_per_poll,
retry_delay=retry_delay_per_poll,
retry_backoff=retry_backoff_per_poll,
wait_label="Checking",
estimated_duration=None,
as_binary=False,
final_label_on_success=None,
monitor_progress=False,
)
if not isinstance(resp_json, dict):
raise Exception("Polling endpoint returned non-JSON response.")
except ProcessingInterrupted:
if cancel_endpoint:
with contextlib.suppress(Exception):
await sync_op_raw(
cls,
cancel_endpoint,
timeout=cancel_timeout,
max_retries=0,
wait_label="Cancelling task",
estimated_duration=None,
as_binary=False,
final_label_on_success=None,
monitor_progress=False,
)
raise
try:
status = _normalize_status_value(status_extractor(resp_json))
except Exception as e:
logging.error("Status extraction failed: %s", e)
status = None
if price_extractor:
new_price = price_extractor(resp_json)
if new_price is not None:
state.price = new_price
if progress_extractor:
new_progress = progress_extractor(resp_json)
if new_progress is not None and last_progress != new_progress:
progress_bar.update_absolute(new_progress, total=100)
last_progress = new_progress
now_ts = time.monotonic()
is_queued = status in queued_states
if is_queued:
if state.active_since is not None: # If we just moved from active -> queued, close the active interval
state.base_processing_elapsed += now_ts - state.active_since
state.active_since = None
else:
if state.active_since is None: # If we just moved from queued -> active, open a new active interval
state.active_since = now_ts
state.is_queued = is_queued
state.status_label = status or ("Queued" if is_queued else "Processing")
if status in completed_states:
if state.active_since is not None:
state.base_processing_elapsed += now_ts - state.active_since
state.active_since = None
stop_ticker.set()
with contextlib.suppress(Exception):
await ticker_task
if progress_bar and last_progress != 100:
progress_bar.update_absolute(100, total=100)
_display_time_progress(
cls,
status=status if status else "Completed",
elapsed_seconds=int(now_ts - started),
estimated_total=estimated_duration,
price=state.price,
is_queued=False,
processing_elapsed_seconds=int(state.base_processing_elapsed),
)
return resp_json
if status in failed_states:
msg = f"Task failed: {json.dumps(resp_json)}"
logging.error(msg)
raise Exception(msg)
try:
await sleep_with_interrupt(poll_interval, cls, None, None, None)
except ProcessingInterrupted:
if cancel_endpoint:
with contextlib.suppress(Exception):
await sync_op_raw(
cls,
cancel_endpoint,
timeout=cancel_timeout,
max_retries=0,
wait_label="Cancelling task",
estimated_duration=None,
as_binary=False,
final_label_on_success=None,
monitor_progress=False,
)
raise
if not is_queued:
consumed_attempts += 1
raise Exception(
f"Polling timed out after {max_poll_attempts} non-queued attempts "
f"(~{int(max_poll_attempts * poll_interval)}s of active polling)."
)
except ProcessingInterrupted:
raise
except (LocalNetworkError, ApiServerError):
raise
except Exception as e:
raise Exception(f"Polling aborted due to error: {e}") from e
finally:
stop_ticker.set()
with contextlib.suppress(Exception):
await ticker_task
def _display_text(
node_cls: type[IO.ComfyNode],
text: Optional[str],
*,
status: Optional[Union[str, int]] = None,
price: Optional[float] = None,
) -> None:
display_lines: list[str] = []
if status:
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
if price is not None:
display_lines.append(f"Price: ${float(price):,.4f}")
if text is not None:
display_lines.append(text)
if display_lines:
PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls))
def _display_time_progress(
node_cls: type[IO.ComfyNode],
status: Optional[Union[str, int]],
elapsed_seconds: int,
estimated_total: Optional[int] = None,
*,
price: Optional[float] = None,
is_queued: Optional[bool] = None,
processing_elapsed_seconds: Optional[int] = None,
) -> None:
if estimated_total is not None and estimated_total > 0 and is_queued is False:
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
remaining = max(0, int(estimated_total) - int(pe))
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
else:
time_line = f"Time elapsed: {int(elapsed_seconds)}s"
_display_text(node_cls, time_line, status=status, price=price)
async def _diagnose_connectivity() -> dict[str, bool]:
"""Best-effort connectivity diagnostics to distinguish local vs. server issues."""
results = {
"internet_accessible": False,
"api_accessible": False,
"is_local_issue": False,
"is_api_issue": False,
}
timeout = aiohttp.ClientTimeout(total=5.0)
async with aiohttp.ClientSession(timeout=timeout) as session:
try:
async with session.get("https://www.google.com") as resp:
results["internet_accessible"] = resp.status < 500
except (ClientError, asyncio.TimeoutError, socket.gaierror):
results["is_local_issue"] = True
return results
parsed = urlparse(default_base_url())
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
with contextlib.suppress(ClientError, asyncio.TimeoutError):
async with session.get(health_url) as resp:
results["api_accessible"] = resp.status < 500
results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"]
return results
def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
"""Normalize (filename, value, content_type)."""
if len(t) == 2:
return t[0], t[1], "application/octet-stream"
if len(t) == 3:
return t[0], t[1], t[2]
raise ValueError("files tuple must be (filename, file[, content_type])")
def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]:
params = dict(endpoint_params or {})
if method.upper() == "GET" and data:
for k, v in data.items():
if v is not None:
params[k] = v
return params
def _friendly_http_message(status: int, body: Any) -> str:
if status == 401:
return "Unauthorized: Please login first to use this node."
if status == 402:
return "Payment Required: Please add credits to your account to use this node."
if status == 409:
return "There is a problem with your account. Please contact support@comfy.org."
if status == 429:
return "Rate Limit Exceeded: Please try again later."
try:
if isinstance(body, dict):
err = body.get("error")
if isinstance(err, dict):
msg = err.get("message")
typ = err.get("type")
if msg and typ:
return f"API Error: {msg} (Type: {typ})"
if msg:
return f"API Error: {msg}"
return f"API Error: {json.dumps(body)}"
else:
txt = str(body)
if len(txt) <= 200:
return f"API Error (raw): {txt}"
return f"API Error (status {status})"
except Exception:
return f"HTTP {status}: Unknown error"
def _generate_operation_id(method: str, path: str, attempt: int) -> str:
slug = path.strip("/").replace("/", "_") or "op"
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"
def _snapshot_request_body_for_logging(
content_type: str,
method: str,
data: Optional[dict[str, Any]],
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]],
) -> Optional[Union[dict[str, Any], str]]:
if method.upper() == "GET":
return None
if content_type == "multipart/form-data":
form_fields = sorted([k for k, v in (data or {}).items() if v is not None])
file_fields: list[dict[str, str]] = []
if files:
file_iter = files if isinstance(files, list) else list(files.items())
for field_name, file_obj in file_iter:
if file_obj is None:
continue
if isinstance(file_obj, tuple):
filename = file_obj[0]
else:
filename = getattr(file_obj, "name", field_name)
file_fields.append({"field": field_name, "filename": str(filename or "")})
return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields}
if content_type == "application/x-www-form-urlencoded":
return data or {}
return data or {}
async def _request_base(cfg: _RequestConfig, expect_binary: bool):
"""Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors."""
url = cfg.endpoint.path
parsed_url = urlparse(url)
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
method = cfg.endpoint.method
params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)
async def _monitor(stop_evt: asyncio.Event, start_ts: float):
"""Every second: update elapsed time and signal interruption."""
try:
while not stop_evt.is_set():
if is_processing_interrupted():
return
if cfg.monitor_progress:
_display_time_progress(
cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total
)
await asyncio.sleep(1.0)
except asyncio.CancelledError:
return # normal shutdown
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
attempt = 0
delay = cfg.retry_delay
operation_succeeded: bool = False
final_elapsed_seconds: Optional[int] = None
while True:
attempt += 1
stop_event = asyncio.Event()
monitor_task: Optional[asyncio.Task] = None
sess: Optional[aiohttp.ClientSession] = None
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
payload_headers = {"Accept": "*/*"}
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
payload_headers.update(get_auth_header(cfg.node_cls))
if cfg.endpoint.headers:
payload_headers.update(cfg.endpoint.headers)
payload_kw: dict[str, Any] = {"headers": payload_headers}
if method == "GET":
payload_headers.pop("Content-Type", None)
request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files)
try:
if cfg.monitor_progress:
monitor_task = asyncio.create_task(_monitor(stop_event, start_time))
timeout = aiohttp.ClientTimeout(total=cfg.timeout)
sess = aiohttp.ClientSession(timeout=timeout)
if cfg.content_type == "multipart/form-data" and method != "GET":
                # aiohttp sets the multipart Content-Type (with boundary) itself; remove any fixed Content-Type
payload_headers.pop("Content-Type", None)
if cfg.multipart_parser and cfg.data:
form = cfg.multipart_parser(cfg.data)
if not isinstance(form, aiohttp.FormData):
raise ValueError("multipart_parser must return aiohttp.FormData")
else:
form = aiohttp.FormData(default_to_multipart=True)
if cfg.data:
for k, v in cfg.data.items():
if v is None:
continue
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
if cfg.files:
file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items()
for field_name, file_obj in file_iter:
if file_obj is None:
continue
if isinstance(file_obj, tuple):
filename, file_value, content_type = _unpack_tuple(file_obj)
else:
filename = getattr(file_obj, "name", field_name)
file_value = file_obj
content_type = "application/octet-stream"
# Attempt to rewind BytesIO for retries
if isinstance(file_value, BytesIO):
with contextlib.suppress(Exception):
file_value.seek(0)
form.add_field(field_name, file_value, filename=filename, content_type=content_type)
payload_kw["data"] = form
elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET":
payload_headers["Content-Type"] = "application/x-www-form-urlencoded"
payload_kw["data"] = cfg.data or {}
elif method != "GET":
payload_headers["Content-Type"] = "application/json"
payload_kw["json"] = cfg.data or {}
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
)
except Exception as _log_e:
logging.debug("[DEBUG] request logging failed: %s", _log_e)
req_coro = sess.request(method, url, params=params, **payload_kw)
req_task = asyncio.create_task(req_coro)
# Race: request vs. monitor (interruption)
tasks = {req_task}
if monitor_task:
tasks.add(monitor_task)
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
if monitor_task and monitor_task in done:
                # Interrupted: cancel the in-flight request and abort
if req_task in pending:
req_task.cancel()
raise ProcessingInterrupted("Task cancelled")
# Otherwise, request finished
resp = await req_task
async with resp:
if resp.status >= 400:
try:
body = await resp.json()
except (ContentTypeError, json.JSONDecodeError):
body = await resp.text()
if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
logging.warning(
"HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
method,
url,
resp.status,
delay,
attempt,
cfg.max_retries,
)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
error_message=_friendly_http_message(resp.status, body),
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
await sleep_with_interrupt(
delay,
cfg.node_cls,
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None,
)
delay *= cfg.retry_backoff
continue
msg = _friendly_http_message(resp.status, body)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
error_message=msg,
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
raise Exception(msg)
if expect_binary:
buff = bytearray()
last_tick = time.monotonic()
async for chunk in resp.content.iter_chunked(64 * 1024):
buff.extend(chunk)
now = time.monotonic()
if now - last_tick >= 1.0:
last_tick = now
if is_processing_interrupted():
raise ProcessingInterrupted("Task cancelled")
if cfg.monitor_progress:
_display_time_progress(
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
)
bytes_payload = bytes(buff)
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=bytes_payload,
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
return bytes_payload
else:
try:
payload = await resp.json()
response_content_to_log: Any = payload
except (ContentTypeError, json.JSONDecodeError):
text = await resp.text()
try:
payload = json.loads(text) if text else {}
except json.JSONDecodeError:
payload = {"_raw": text}
response_content_to_log = payload if isinstance(payload, dict) else text
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=response_content_to_log,
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
return payload
except ProcessingInterrupted:
logging.debug("Polling was interrupted by user")
raise
except (ClientError, asyncio.TimeoutError, socket.gaierror) as e:
if attempt <= cfg.max_retries:
logging.warning(
"Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
method,
url,
delay,
attempt,
cfg.max_retries,
str(e),
)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
except Exception as _log_e:
logging.debug("[DEBUG] request error logging failed: %s", _log_e)
await sleep_with_interrupt(
delay,
cfg.node_cls,
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None,
)
delay *= cfg.retry_backoff
continue
diag = await _diagnose_connectivity()
if diag.get("is_local_issue"):
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"LocalNetworkError: {str(e)}",
)
except Exception as _log_e:
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
raise LocalNetworkError(
"Unable to connect to the API server due to local network issues. "
"Please check your internet connection and try again."
) from e
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"ApiServerError: {str(e)}",
)
except Exception as _log_e:
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
raise ApiServerError(
f"The API server at {default_base_url()} is currently unreachable. "
f"The service may be experiencing issues."
) from e
finally:
stop_event.set()
if monitor_task:
monitor_task.cancel()
with contextlib.suppress(Exception):
await monitor_task
if sess:
with contextlib.suppress(Exception):
await sess.close()
if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success:
_display_time_progress(
cfg.node_cls,
status=cfg.final_label_on_success,
elapsed_seconds=(
final_elapsed_seconds
if final_elapsed_seconds is not None
else int(time.monotonic() - start_time)
),
estimated_total=cfg.estimated_total,
price=None,
is_queued=False,
processing_elapsed_seconds=final_elapsed_seconds,
)
def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
try:
return response_model.model_validate(payload)
except Exception as e:
logging.error(
"Response validation failed for %s: %s",
getattr(response_model, "__name__", response_model),
e,
)
raise Exception(
f"Response validation failed for {getattr(response_model, '__name__', response_model)}: {e}"
) from e
def _wrap_model_extractor(
response_model: Type[M],
extractor: Optional[Callable[[M], Any]],
) -> Optional[Callable[[dict[str, Any]], Any]]:
"""Wrap a typed extractor so it can be used by the dict-based poller.
Validates the dict into `response_model` before invoking `extractor`.
Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
the same response for multiple extractors in a single poll attempt.
"""
if extractor is None:
return None
_cache: dict[int, M] = {}
def _wrapped(d: dict[str, Any]) -> Any:
try:
key = id(d)
model = _cache.get(key)
if model is None:
model = response_model.model_validate(d)
_cache[key] = model
return extractor(model)
except Exception as e:
logging.error("Extractor failed (typed -> dict wrapper): %s", e)
raise
return _wrapped
def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]:
if not values:
return set()
out: set[Union[str, int]] = set()
for v in values:
nv = _normalize_status_value(v)
if nv is not None:
out.add(nv)
return out
def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]:
if isinstance(val, str):
return val.strip().lower()
return val

common_exceptions.py

@@ -0,0 +1,14 @@
class NetworkError(Exception):
"""Base exception for network-related errors with diagnostic information."""
class LocalNetworkError(NetworkError):
"""Exception raised when local network connectivity issues are detected."""
class ApiServerError(NetworkError):
"""Exception raised when the API server is unreachable but internet is working."""
class ProcessingInterrupted(Exception):
"""Operation was interrupted by user/runtime via processing_interrupted()."""

conversions.py

@@ -0,0 +1,407 @@
import base64
import logging
import math
import uuid
from io import BytesIO
from typing import Optional
import av
import numpy as np
import torch
from PIL import Image
from comfy.utils import common_upscale
from comfy_api.latest import Input, InputImpl
from ._helpers import mimetype_to_extension
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
"""Converts image data from BytesIO to a torch.Tensor.
Args:
image_bytesio: BytesIO object containing the image data.
mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").
Returns:
A torch.Tensor representing the image (1, H, W, C).
Raises:
PIL.UnidentifiedImageError: If the image data cannot be identified.
ValueError: If the specified mode is invalid.
"""
image = Image.open(image_bytesio)
image = image.convert(mode)
image_array = np.array(image).astype(np.float32) / 255.0
return torch.from_numpy(image_array).unsqueeze(0)
def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor:
"""
Converts a pair of image tensors to a batch tensor.
If the images are not the same size, the smaller image is resized to
match the larger image.
"""
if image1.shape[1:] != image2.shape[1:]:
image2 = common_upscale(
image2.movedim(-1, 1),
image1.shape[2],
image1.shape[1],
"bilinear",
"center",
).movedim(1, -1)
return torch.cat((image1, image2), dim=0)
def tensor_to_bytesio(
image: torch.Tensor,
name: Optional[str] = None,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> BytesIO:
"""Converts a torch.Tensor image to a named BytesIO object.
Args:
image: Input torch.Tensor image.
name: Optional filename for the BytesIO object.
total_pixels: Maximum total pixels for potential downscaling.
        mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
Returns:
Named BytesIO object containing the image data, with pointer set to the start of buffer.
"""
if not mime_type:
mime_type = "image/png"
pil_image = tensor_to_pil(image, total_pixels=total_pixels)
img_binary = pil_to_bytesio(pil_image, mime_type=mime_type)
img_binary.name = f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
return img_binary
def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
"""Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
if len(image.shape) > 3:
image = image[0]
# TODO: remove alpha if not allowed and present
input_tensor = image.cpu()
input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze()
image_np = (input_tensor.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
return img
def tensor_to_base64_string(
image_tensor: torch.Tensor,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> str:
"""Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.
Args:
image_tensor: Input torch.Tensor image.
total_pixels: Maximum total pixels for potential downscaling.
        mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
Returns:
Base64 encoded string of the image.
"""
pil_image = tensor_to_pil(image_tensor, total_pixels=total_pixels)
img_byte_arr = pil_to_bytesio(pil_image, mime_type=mime_type)
img_bytes = img_byte_arr.getvalue()
# Encode bytes to base64 string
base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
return base64_encoded_string
def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
"""Converts a PIL Image to a BytesIO object."""
if not mime_type:
mime_type = "image/png"
img_byte_arr = BytesIO()
# Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
pil_format = mime_type.split("/")[-1].upper()
if pil_format == "JPG":
pil_format = "JPEG"
img.save(img_byte_arr, format=pil_format)
img_byte_arr.seek(0)
return img_byte_arr
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
"""Downscale input image tensor to roughly the specified total pixels."""
samples = image.movedim(-1, 1)
total = int(total_pixels)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
if scale_by >= 1:
return image
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = common_upscale(samples, width, height, "lanczos", "disabled")
s = s.movedim(1, -1)
return s
def tensor_to_data_uri(
image_tensor: torch.Tensor,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> str:
"""Converts a tensor image to a Data URI string.
Args:
image_tensor: Input torch.Tensor image.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
Returns:
Data URI string (e.g., 'data:image/png;base64,...').
"""
base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
return f"data:{mime_type};base64,{base64_string}"
def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", codec_name: str = "aac") -> str:
"""Converts an audio input to a base64 string."""
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
audio_bytes = audio_bytes_io.getvalue()
return base64.b64encode(audio_bytes).decode("utf-8")
def audio_ndarray_to_bytesio(
audio_data_np: np.ndarray,
sample_rate: int,
container_format: str = "mp4",
codec_name: str = "aac",
) -> BytesIO:
"""
Encodes a numpy array of audio data into a BytesIO object.
"""
audio_bytes_io = BytesIO()
with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
audio_stream = output_container.add_stream(codec_name, rate=sample_rate)
frame = av.AudioFrame.from_ndarray(
audio_data_np,
format="fltp",
layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
)
frame.sample_rate = sample_rate
frame.pts = 0
for packet in audio_stream.encode(frame):
output_container.mux(packet)
# Flush stream
for packet in audio_stream.encode(None):
output_container.mux(packet)
audio_bytes_io.seek(0)
return audio_bytes_io
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
"""
Prepares audio waveform for av library by converting to a contiguous numpy array.
Args:
        waveform: a tensor of shape (batch, channels, samples) derived from a Comfy `AUDIO` type.
Returns:
Contiguous numpy array of the audio waveform. If the audio was batched,
the first item is taken.
"""
    if waveform.ndim != 3:
        raise ValueError("Expected waveform tensor shape (batch, channels, samples)")
    # If batched, keep only the first item (the batch dim is removed below)
    if waveform.shape[0] > 1:
        waveform = waveform[:1]
# Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array
audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
if audio_data_np.dtype != np.float32:
audio_data_np = audio_data_np.astype(np.float32)
return audio_data_np
def audio_input_to_mp3(audio: Input.Audio) -> BytesIO:
waveform = audio["waveform"].cpu()
output_buffer = BytesIO()
output_container = av.open(output_buffer, mode="w", format="mp3")
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
out_stream.bit_rate = 320000
frame = av.AudioFrame.from_ndarray(
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
format="flt",
layout="mono" if waveform.shape[0] == 1 else "stereo",
)
frame.sample_rate = audio["sample_rate"]
frame.pts = 0
output_container.mux(out_stream.encode(frame))
output_container.mux(out_stream.encode(None))
output_container.close()
output_buffer.seek(0)
return output_buffer
def trim_video(video: Input.Video, duration_sec: float) -> Input.Video:
"""
Returns a new VideoInput object trimmed from the beginning to the specified duration,
using av to avoid loading entire video into memory.
Args:
video: Input video to trim
duration_sec: Duration in seconds to keep from the beginning
Returns:
VideoFromFile object that owns the output buffer
"""
output_buffer = BytesIO()
input_container = None
output_container = None
try:
# Get the stream source - this avoids loading entire video into memory
# when the source is already a file path
input_source = video.get_stream_source()
# Open containers
input_container = av.open(input_source, mode="r")
output_container = av.open(output_buffer, mode="w", format="mp4")
# Set up output streams for re-encoding
video_stream = None
audio_stream = None
for stream in input_container.streams:
logging.info("Found stream: type=%s, class=%s", stream.type, type(stream))
if isinstance(stream, av.VideoStream):
# Create output video stream with same parameters
video_stream = output_container.add_stream("h264", rate=stream.average_rate)
video_stream.width = stream.width
video_stream.height = stream.height
video_stream.pix_fmt = "yuv420p"
logging.info("Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate)
elif isinstance(stream, av.AudioStream):
# Create output audio stream with same parameters
audio_stream = output_container.add_stream("aac", rate=stream.sample_rate)
audio_stream.sample_rate = stream.sample_rate
audio_stream.layout = stream.layout
logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels)
# Calculate target frame count that's divisible by 16
fps = input_container.streams.video[0].average_rate
estimated_frames = int(duration_sec * fps)
target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16
if target_frames == 0:
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
frame_count = 0
audio_frame_count = 0
# Decode and re-encode video frames
if video_stream:
for frame in input_container.decode(video=0):
if frame_count >= target_frames:
break
# Re-encode frame
for packet in video_stream.encode(frame):
output_container.mux(packet)
frame_count += 1
# Flush encoder
for packet in video_stream.encode():
output_container.mux(packet)
logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames)
# Decode and re-encode audio frames
if audio_stream:
input_container.seek(0) # Reset to beginning for audio
for frame in input_container.decode(audio=0):
if frame.time >= duration_sec:
break
# Re-encode frame
for packet in audio_stream.encode(frame):
output_container.mux(packet)
audio_frame_count += 1
# Flush encoder
for packet in audio_stream.encode():
output_container.mux(packet)
logging.info("Encoded %s audio frames", audio_frame_count)
# Close containers
output_container.close()
input_container.close()
# Return as VideoFromFile using the buffer
output_buffer.seek(0)
return InputImpl.VideoFromFile(output_buffer)
except Exception as e:
# Clean up on error
if input_container is not None:
input_container.close()
if output_container is not None:
output_container.close()
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
def _f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
if wav.dtype.is_floating_point:
return wav
elif wav.dtype == torch.int16:
return wav.float() / (2**15)
elif wav.dtype == torch.int32:
return wav.float() / (2**31)
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict:
"""
Decode any common audio container from bytes using PyAV and return
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
"""
with av.open(BytesIO(audio_bytes)) as af:
if not af.streams.audio:
raise ValueError("No audio stream found in response.")
stream = af.streams.audio[0]
in_sr = int(stream.codec_context.sample_rate)
out_sr = in_sr
frames: list[torch.Tensor] = []
n_channels = stream.channels or 1
for frame in af.decode(streams=stream.index):
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
buf = torch.from_numpy(arr)
if buf.ndim == 1:
buf = buf.unsqueeze(0) # [T] -> [1, T]
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
elif buf.shape[0] != n_channels:
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
frames.append(buf)
if not frames:
raise ValueError("Decoded zero audio frames.")
wav = torch.cat(frames, dim=1) # [C, T]
wav = _f32_pcm(wav)
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}

download_helpers.py

@@ -0,0 +1,249 @@
import asyncio
import contextlib
import uuid
from io import BytesIO
from pathlib import Path
from typing import IO, Optional, Union
from urllib.parse import urljoin, urlparse
import aiohttp
import torch
from aiohttp.client_exceptions import ClientError, ContentTypeError
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO as COMFY_IO
from comfy_api_nodes.apis import request_logger
from ._helpers import (
default_base_url,
get_auth_header,
is_processing_interrupted,
sleep_with_interrupt,
)
from .client import _diagnose_connectivity
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
from .conversions import bytesio_to_image_tensor
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
async def download_url_to_bytesio(
url: str,
dest: Optional[Union[BytesIO, IO[bytes], str, Path]],
*,
timeout: Optional[float] = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
    cls: Optional[type[COMFY_IO.ComfyNode]] = None,
) -> None:
"""Stream-download a URL to `dest`.
`dest` must be one of:
- a BytesIO (rewound to 0 after write),
- a file-like object opened in binary write mode (must implement .write()),
- a filesystem path (str | pathlib.Path), which will be opened with 'wb'.
    If `url` is relative (e.g., a `/proxy/...` path), `cls` must be provided so the URL can be expanded
    against the API base URL and authentication headers can be applied.
Raises:
ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception (HTTP and other errors)
"""
if not isinstance(dest, (str, Path)) and not hasattr(dest, "write"):
raise ValueError("dest must be a path (str|Path) or a binary-writable object providing .write().")
attempt = 0
delay = retry_delay
headers: dict[str, str] = {}
parsed_url = urlparse(url)
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
if cls is None:
raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.")
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
headers = get_auth_header(cls)
while True:
attempt += 1
op_id = _generate_operation_id("GET", url, attempt)
timeout_cfg = aiohttp.ClientTimeout(total=timeout)
is_path_sink = isinstance(dest, (str, Path))
fhandle = None
session: Optional[aiohttp.ClientSession] = None
stop_evt: Optional[asyncio.Event] = None
monitor_task: Optional[asyncio.Task] = None
req_task: Optional[asyncio.Task] = None
try:
with contextlib.suppress(Exception):
request_logger.log_request_response(operation_id=op_id, request_method="GET", request_url=url)
session = aiohttp.ClientSession(timeout=timeout_cfg)
stop_evt = asyncio.Event()
async def _monitor():
try:
while not stop_evt.is_set():
if is_processing_interrupted():
return
await asyncio.sleep(1.0)
except asyncio.CancelledError:
return
monitor_task = asyncio.create_task(_monitor())
req_task = asyncio.create_task(session.get(url, headers=headers))
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
if monitor_task in done and req_task in pending:
req_task.cancel()
with contextlib.suppress(Exception):
await req_task
raise ProcessingInterrupted("Task cancelled")
try:
resp = await req_task
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
async with resp:
if resp.status >= 400:
with contextlib.suppress(Exception):
try:
body = await resp.json()
except (ContentTypeError, ValueError):
text = await resp.text()
body = text if len(text) <= 4096 else f"[text {len(text)} bytes]"
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
error_message=f"HTTP {resp.status}",
)
if resp.status in _RETRY_STATUS and attempt <= max_retries:
await sleep_with_interrupt(delay, cls, None, None, None)
delay *= retry_backoff
continue
raise Exception(f"Failed to download (HTTP {resp.status}).")
if is_path_sink:
p = Path(str(dest))
with contextlib.suppress(Exception):
p.parent.mkdir(parents=True, exist_ok=True)
fhandle = open(p, "wb")
sink = fhandle
else:
sink = dest # BytesIO or file-like
written = 0
while True:
try:
chunk = await asyncio.wait_for(resp.content.read(1024 * 1024), timeout=1.0)
except asyncio.TimeoutError:
chunk = b""
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
if is_processing_interrupted():
raise ProcessingInterrupted("Task cancelled")
if not chunk:
if resp.content.at_eof():
break
continue
sink.write(chunk)
written += len(chunk)
if isinstance(dest, BytesIO):
with contextlib.suppress(Exception):
dest.seek(0)
with contextlib.suppress(Exception):
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=f"[streamed {written} bytes to dest]",
)
return
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
except (ClientError, asyncio.TimeoutError) as e:
if attempt <= max_retries:
with contextlib.suppress(Exception):
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
request_url=url,
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
await sleep_with_interrupt(delay, cls, None, None, None)
delay *= retry_backoff
continue
diag = await _diagnose_connectivity()
if diag.get("is_local_issue"):
raise LocalNetworkError(
"Unable to connect to the network. Please check your internet connection and try again."
) from e
raise ApiServerError("The remote service appears unreachable at this time.") from e
finally:
if stop_evt is not None:
stop_evt.set()
if monitor_task:
monitor_task.cancel()
with contextlib.suppress(Exception):
await monitor_task
if req_task and not req_task.done():
req_task.cancel()
with contextlib.suppress(Exception):
await req_task
if session:
with contextlib.suppress(Exception):
await session.close()
if fhandle:
with contextlib.suppress(Exception):
fhandle.flush()
fhandle.close()
async def download_url_to_image_tensor(
url: str,
*,
    timeout: Optional[float] = None,
    cls: Optional[type[COMFY_IO.ComfyNode]] = None,
) -> torch.Tensor:
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
result = BytesIO()
await download_url_to_bytesio(url, result, timeout=timeout, cls=cls)
return bytesio_to_image_tensor(result)
async def download_url_to_video_output(
video_url: str,
*,
    timeout: Optional[float] = None,
    cls: Optional[type[COMFY_IO.ComfyNode]] = None,
) -> VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output."""
result = BytesIO()
await download_url_to_bytesio(video_url, result, timeout=timeout, cls=cls)
return VideoFromFile(result)
def _generate_operation_id(method: str, url: str, attempt: int) -> str:
try:
parsed = urlparse(url)
slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download").strip("/").replace("/", "_")
except Exception:
slug = "download"
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"

upload_helpers.py

@@ -0,0 +1,338 @@
import asyncio
import contextlib
import logging
import time
import uuid
from io import BytesIO
from typing import Optional, Union
from urllib.parse import urlparse
import aiohttp
import torch
from pydantic import BaseModel, Field
from comfy_api.latest import IO, Input
from comfy_api.util import VideoCodec, VideoContainer
from comfy_api_nodes.apis import request_logger
from ._helpers import is_processing_interrupted, sleep_with_interrupt
from .client import (
ApiEndpoint,
_diagnose_connectivity,
_display_time_progress,
sync_op,
)
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
from .conversions import (
audio_ndarray_to_bytesio,
audio_tensor_to_contiguous_ndarray,
tensor_to_bytesio,
)
class UploadRequest(BaseModel):
file_name: str = Field(..., description="Filename to upload")
content_type: Optional[str] = Field(
None,
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
)
class UploadResponse(BaseModel):
download_url: str = Field(..., description="URL to GET uploaded file")
upload_url: str = Field(..., description="URL to PUT file to upload")
async def upload_images_to_comfyapi(
cls: type[IO.ComfyNode],
image: torch.Tensor,
*,
max_images: int = 8,
mime_type: Optional[str] = None,
wait_label: Optional[str] = "Uploading",
) -> list[str]:
"""
Uploads images to ComfyUI API and returns download URLs.
To upload multiple images, stack them in the batch dimension first.
"""
    # For a batch, upload up to max_images images (stacked in the batch dimension)
download_urls: list[str] = []
is_batch = len(image.shape) > 3
batch_len = image.shape[0] if is_batch else 1
for idx in range(min(batch_len, max_images)):
tensor = image[idx] if is_batch else image
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, wait_label)
download_urls.append(url)
return download_urls
async def upload_audio_to_comfyapi(
cls: type[IO.ComfyNode],
audio: Input.Audio,
*,
container_format: str = "mp4",
codec_name: str = "aac",
mime_type: str = "audio/mp4",
filename: str = "uploaded_audio.mp4",
) -> str:
"""
Uploads a single audio input to ComfyUI API and returns its download URL.
Encodes the raw waveform into the specified format before uploading.
"""
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
return await upload_file_to_comfyapi(cls, audio_bytes_io, filename, mime_type)
async def upload_video_to_comfyapi(
cls: type[IO.ComfyNode],
video: Input.Video,
*,
container: VideoContainer = VideoContainer.MP4,
codec: VideoCodec = VideoCodec.H264,
max_duration: Optional[int] = None,
) -> str:
"""
Uploads a single video to ComfyUI API and returns its download URL.
Uses the specified container and codec for saving the video before upload.
"""
if max_duration is not None:
try:
actual_duration = video.get_duration()
if actual_duration > max_duration:
raise ValueError(
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
)
except Exception as e:
logging.error("Error getting video duration: %s", str(e))
raise ValueError(f"Could not verify video duration from source: {e}") from e
upload_mime_type = f"video/{container.value.lower()}"
filename = f"uploaded_video.{container.value.lower()}"
# Convert VideoInput to BytesIO using specified container/codec
video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=container, codec=codec)
video_bytes_io.seek(0)
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type)
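

# A usage sketch for the duration guard above; ``video`` stands for any
# Input.Video source (hypothetical here). Anything longer than ten seconds
# raises ValueError before bytes are encoded or uploaded.
async def _example_upload_short_video(node_cls: type[IO.ComfyNode], video: Input.Video) -> str:
    return await upload_video_to_comfyapi(node_cls, video, max_duration=10)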
async def upload_file_to_comfyapi(
cls: type[IO.ComfyNode],
file_bytes_io: BytesIO,
filename: str,
upload_mime_type: Optional[str],
wait_label: Optional[str] = "Uploading",
) -> str:
"""Uploads a single file to ComfyUI API and returns its download URL."""
if upload_mime_type is None:
request_object = UploadRequest(file_name=filename)
else:
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
create_resp = await sync_op(
cls,
endpoint=ApiEndpoint(path="/customers/storage", method="POST"),
data=request_object,
response_model=UploadResponse,
final_label_on_success=None,
monitor_progress=False,
)
await upload_file(
cls,
create_resp.upload_url,
file_bytes_io,
content_type=upload_mime_type,
wait_label=wait_label,
)
return create_resp.download_url
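

# The two-step flow above, summarized as a sketch: POST /customers/storage
# returns a pre-signed ``upload_url`` (PUT target) plus a ``download_url``
# (for later GETs); the file bytes then go directly to the signed URL rather
# than through the Comfy API itself.
async def _example_upload_text_file(node_cls: type[IO.ComfyNode]) -> str:
    payload = BytesIO(b"hello, storage")
    return await upload_file_to_comfyapi(node_cls, payload, "hello.txt", "text/plain")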
async def upload_file(
cls: type[IO.ComfyNode],
upload_url: str,
file: Union[BytesIO, str],
*,
content_type: Optional[str] = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
wait_label: Optional[str] = None,
) -> None:
"""
    Upload a file to a signed URL (e.g., an S3 pre-signed PUT) with retries,
    Comfy progress display, and interruption support.

    Args:
        cls: Node class (provides auth context + UI progress hooks).
        upload_url: Pre-signed PUT URL.
        file: BytesIO or filesystem path string.
        content_type: Explicit MIME type. If None, the Content-Type header is
            suppressed entirely, since an unexpected auto-added header can
            invalidate the signature of a pre-signed request.
        max_retries: Maximum retry attempts.
        retry_delay: Initial delay between retries, in seconds.
        retry_backoff: Exponential backoff factor applied after each retry.
        wait_label: Progress label shown in the Comfy UI.

    Raises:
        ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception
    """
if isinstance(file, BytesIO):
with contextlib.suppress(Exception):
file.seek(0)
data = file.read()
elif isinstance(file, str):
with open(file, "rb") as f:
data = f.read()
else:
raise ValueError("file must be a BytesIO or a filesystem path string")
headers: dict[str, str] = {}
skip_auto_headers: set[str] = set()
if content_type:
headers["Content-Type"] = content_type
    else:
        skip_auto_headers.add("Content-Type")  # don't let aiohttp auto-add Content-Type; it can invalidate the signed request
attempt = 0
delay = retry_delay
start_ts = time.monotonic()
op_uuid = uuid.uuid4().hex[:8]
while True:
attempt += 1
operation_id = _generate_operation_id("PUT", upload_url, attempt, op_uuid)
timeout = aiohttp.ClientTimeout(total=None)
stop_evt = asyncio.Event()
async def _monitor():
try:
while not stop_evt.is_set():
if is_processing_interrupted():
return
if wait_label:
_display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None)
await asyncio.sleep(1.0)
except asyncio.CancelledError:
return
monitor_task = asyncio.create_task(_monitor())
sess: Optional[aiohttp.ClientSession] = None
try:
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers or None,
request_params=None,
request_data=f"[File data {len(data)} bytes]",
)
except Exception as e:
logging.debug("[DEBUG] upload request logging failed: %s", e)
sess = aiohttp.ClientSession(timeout=timeout)
req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
req_task = asyncio.create_task(req)
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
if monitor_task in done and req_task in pending:
req_task.cancel()
raise ProcessingInterrupted("Upload cancelled")
try:
resp = await req_task
except asyncio.CancelledError:
raise ProcessingInterrupted("Upload cancelled") from None
async with resp:
                if resp.status >= 400:
                    body = None  # stays defined even if the body cannot be read at all
                    with contextlib.suppress(Exception):
try:
body = await resp.json()
except Exception:
body = await resp.text()
msg = f"Upload failed with status {resp.status}"
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
error_message=msg,
)
if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries:
await sleep_with_interrupt(
delay,
cls,
wait_label,
start_ts,
None,
display_callback=_display_time_progress if wait_label else None,
)
delay *= retry_backoff
continue
raise Exception(f"Failed to upload (HTTP {resp.status}).")
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content="File uploaded successfully.",
)
except Exception as e:
logging.debug("[DEBUG] upload response logging failed: %s", e)
return
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
if attempt <= max_retries:
with contextlib.suppress(Exception):
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers or None,
request_data=f"[File data {len(data)} bytes]",
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
await sleep_with_interrupt(
delay,
cls,
wait_label,
start_ts,
None,
display_callback=_display_time_progress if wait_label else None,
)
delay *= retry_backoff
continue
diag = await _diagnose_connectivity()
if diag.get("is_local_issue"):
raise LocalNetworkError(
"Unable to connect to the network. Please check your internet connection and try again."
) from e
raise ApiServerError("The API service appears unreachable at this time.") from e
finally:
stop_evt.set()
if monitor_task:
monitor_task.cancel()
with contextlib.suppress(Exception):
await monitor_task
if sess:
with contextlib.suppress(Exception):
await sess.close()
def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str:
try:
parsed = urlparse(url)
slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload").strip("/").replace("/", "_")
except Exception:
slug = "upload"
return f"{method}_{slug}_{op_uuid}_try{attempt}"

View File

@@ -2,6 +2,8 @@ import logging
from typing import Optional
import torch
from comfy_api.input.video_types import VideoInput
from comfy_api.latest import Input
@@ -28,9 +30,7 @@ def validate_image_dimensions(
if max_width is not None and width > max_width:
raise ValueError(f"Image width must be at most {max_width}px, got {width}px")
if min_height is not None and height < min_height:
raise ValueError(
f"Image height must be at least {min_height}px, got {height}px"
)
raise ValueError(f"Image height must be at least {min_height}px, got {height}px")
if max_height is not None and height > max_height:
raise ValueError(f"Image height must be at most {max_height}px, got {height}px")
@@ -44,13 +44,9 @@ def validate_image_aspect_ratio(
aspect_ratio = width / height
if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
raise ValueError(
f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}"
)
raise ValueError(f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}")
if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
raise ValueError(
f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}"
)
raise ValueError(f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}")
def validate_image_aspect_ratio_range(
@@ -58,7 +54,7 @@ def validate_image_aspect_ratio_range(
min_ratio: tuple[float, float], # e.g. (1, 4)
max_ratio: tuple[float, float], # e.g. (4, 1)
*,
strict: bool = True, # True -> (min, max); False -> [min, max]
strict: bool = True, # True -> (min, max); False -> [min, max]
) -> float:
a1, b1 = min_ratio
a2, b2 = max_ratio
@@ -85,7 +81,7 @@ def validate_aspect_ratio_closeness(
min_rel: float,
max_rel: float,
*,
strict: bool = False, # True => exclusive, False => inclusive
strict: bool = False, # True => exclusive, False => inclusive
) -> None:
w1, h1 = get_image_dimensions(start_img)
w2, h2 = get_image_dimensions(end_img)
@@ -118,9 +114,7 @@ def validate_video_dimensions(
if max_width is not None and width > max_width:
raise ValueError(f"Video width must be at most {max_width}px, got {width}px")
if min_height is not None and height < min_height:
raise ValueError(
f"Video height must be at least {min_height}px, got {height}px"
)
raise ValueError(f"Video height must be at least {min_height}px, got {height}px")
if max_height is not None and height > max_height:
raise ValueError(f"Video height must be at most {max_height}px, got {height}px")
@@ -138,13 +132,9 @@ def validate_video_duration(
epsilon = 0.0001
if min_duration is not None and min_duration - epsilon > duration:
raise ValueError(
f"Video duration must be at least {min_duration}s, got {duration}s"
)
raise ValueError(f"Video duration must be at least {min_duration}s, got {duration}s")
if max_duration is not None and duration > max_duration + epsilon:
raise ValueError(
f"Video duration must be at most {max_duration}s, got {duration}s"
)
raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s")
def get_number_of_images(images):
@@ -165,3 +155,31 @@ def validate_audio_duration(
raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur + eps:.2f}s")
if max_duration is not None and dur - eps > max_duration:
raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur - eps:.2f}s")
def validate_string(
string: str,
strip_whitespace=True,
field_name="prompt",
min_length=None,
max_length=None,
):
    if not string:
        raise Exception(f"Field '{field_name}' cannot be empty.")
if strip_whitespace:
string = string.strip()
if min_length and len(string) < min_length:
raise Exception(
f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
)
    if max_length and len(string) > max_length:
        raise Exception(
            f"Field '{field_name}' cannot be longer than {max_length} characters; was {len(string)} characters long."
        )
def validate_container_format_is_mp4(video: VideoInput) -> None:
"""Validates video container format is MP4."""
container_format = video.get_container_format()
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")