Compare commits

..

6 Commits

Author SHA1 Message Date
f273f6771b Merge branch 'master' into deepme987/add-execution-environment-api 2026-04-21 05:00:19 +05:30
543e9fba64 fix: pin SQLAlchemy>=2.0 in requirements.txt (fixes #13036) (#13316) 2026-04-20 15:30:23 -07:00
fc5f4a996b Add link to Intel portable to Readme. (#13477) 2026-04-19 20:26:12 -04:00
bdf444df06 Merge branch 'master' into deepme987/add-execution-environment-api 2026-03-25 23:10:31 -07:00
e773b69b03 Merge branch 'master' into deepme987/add-execution-environment-api 2026-03-20 10:14:47 -07:00
81651606a6 feat: add execution environment API for managed deployments
Adds api.environment.get() to the public ComfyAPI — returns
"local" (default), "cloud", or "remote" based on the
COMFY_EXECUTION_ENVIRONMENT env var.

Custom nodes use this to adapt behavior for managed deployments
(e.g. skip model downloads when models are pre-provisioned).
2026-03-20 10:05:06 -07:00
7 changed files with 100 additions and 348 deletions

View File

@ -195,7 +195,9 @@ The portable above currently comes with python 3.13 and pytorch cuda 13.0. Updat
#### Alternative Downloads:
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
[Portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
[Experimental portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z)
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).

View File

@ -15,6 +15,7 @@ from comfy_execution.progress import get_progress_state, PreviewImageTuple
from PIL import Image
from comfy.cli_args import args
import numpy as np
import os
class ComfyAPI_latest(ComfyAPIBase):
@ -25,6 +26,7 @@ class ComfyAPI_latest(ComfyAPIBase):
super().__init__()
self.node_replacement = self.NodeReplacement()
self.execution = self.Execution()
self.environment = self.Environment()
self.caching = self.Caching()
class NodeReplacement(ProxiedSingleton):
@ -85,6 +87,27 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display,
)
class Environment(ProxiedSingleton):
"""
Query the current execution environment.
Managed deployments set the ``COMFY_EXECUTION_ENVIRONMENT`` env var
so custom nodes can adapt their behaviour at runtime.
Example::
from comfy_api.latest import api
env = api.environment.get() # "local" | "cloud" | "remote"
"""
_VALID = {"local", "cloud", "remote"}
async def get(self) -> str:
"""Return the execution environment: ``"local"``, ``"cloud"``, or ``"remote"``."""
value = os.environ.get("COMFY_EXECUTION_ENVIRONMENT", "local").lower().strip()
return value if value in self._VALID else "local"
class Caching(ProxiedSingleton):
"""
External cache provider API for sharing cached node outputs

View File

@ -2,7 +2,6 @@ import asyncio
import contextlib
import json
import logging
import os
import time
import uuid
from collections.abc import Callable, Iterable
@ -33,30 +32,6 @@ from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInte
M = TypeVar("M", bound=BaseModel)
def _env_int(key: str, default: int) -> int:
try:
return int(os.environ[key])
except (KeyError, ValueError):
return default
def _env_float(key: str, default: float) -> float:
try:
return float(os.environ[key])
except (KeyError, ValueError):
return default
@dataclass(frozen=True)
class _RetryDefaults:
max_retries: int = _env_int("COMFY_API_MAX_RETRIES", 3)
retry_delay: float = _env_float("COMFY_API_RETRY_DELAY", 1.0)
retry_backoff: float = _env_float("COMFY_API_RETRY_BACKOFF", 2.0)
RETRY_DEFAULTS = _RetryDefaults()
class ApiEndpoint:
def __init__(
self,
@ -103,21 +78,11 @@ class _PollUIState:
price: float | None = None
estimated_duration: int | None = None
base_processing_elapsed: float = 0.0 # sum of completed active intervals
active_since: float | None = (
None # start time of current active interval (None if queued)
)
active_since: float | None = None # start time of current active interval (None if queued)
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
COMPLETED_STATUSES = [
"succeeded",
"succeed",
"success",
"completed",
"finished",
"done",
"complete",
]
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]
@ -133,9 +98,9 @@ async def sync_op(
content_type: str = "application/json",
timeout: float = 3600.0,
multipart_parser: Callable | None = None,
max_retries: int = RETRY_DEFAULTS.max_retries,
retry_delay: float = RETRY_DEFAULTS.retry_delay,
retry_backoff: float = RETRY_DEFAULTS.retry_backoff,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
wait_label: str = "Waiting for server",
estimated_duration: int | None = None,
final_label_on_success: str | None = "Completed",
@ -166,9 +131,7 @@ async def sync_op(
is_rate_limited=is_rate_limited,
)
if not isinstance(raw, dict):
raise Exception(
"Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)."
)
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
return _validate_or_raise(response_model, raw)
@ -215,9 +178,7 @@ async def poll_op(
cancel_timeout=cancel_timeout,
)
if not isinstance(raw, dict):
raise Exception(
"Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)."
)
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
return _validate_or_raise(response_model, raw)
@ -231,9 +192,9 @@ async def sync_op_raw(
content_type: str = "application/json",
timeout: float = 3600.0,
multipart_parser: Callable | None = None,
max_retries: int = RETRY_DEFAULTS.max_retries,
retry_delay: float = RETRY_DEFAULTS.retry_delay,
retry_backoff: float = RETRY_DEFAULTS.retry_backoff,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
wait_label: str = "Waiting for server",
estimated_duration: int | None = None,
as_binary: bool = False,
@ -308,15 +269,9 @@ async def poll_op_raw(
Returns the final JSON response from the poll endpoint.
"""
completed_states = _normalize_statuses(
COMPLETED_STATUSES if completed_statuses is None else completed_statuses
)
failed_states = _normalize_statuses(
FAILED_STATUSES if failed_statuses is None else failed_statuses
)
queued_states = _normalize_statuses(
QUEUED_STATUSES if queued_statuses is None else queued_statuses
)
completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses)
queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses)
started = time.monotonic()
consumed_attempts = 0 # counts only non-queued polls
@ -334,9 +289,7 @@ async def poll_op_raw(
break
now = time.monotonic()
proc_elapsed = state.base_processing_elapsed + (
(now - state.active_since)
if state.active_since is not None
else 0.0
(now - state.active_since) if state.active_since is not None else 0.0
)
_display_time_progress(
cls,
@ -408,15 +361,11 @@ async def poll_op_raw(
is_queued = status in queued_states
if is_queued:
if (
state.active_since is not None
): # If we just moved from active -> queued, close the active interval
if state.active_since is not None: # If we just moved from active -> queued, close the active interval
state.base_processing_elapsed += now_ts - state.active_since
state.active_since = None
else:
if (
state.active_since is None
): # If we just moved from queued -> active, open a new active interval
if state.active_since is None: # If we just moved from queued -> active, open a new active interval
state.active_since = now_ts
state.is_queued = is_queued
@ -493,9 +442,7 @@ def _display_text(
) -> None:
display_lines: list[str] = []
if status:
display_lines.append(
f"Status: {status.capitalize() if isinstance(status, str) else status}"
)
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
if price is not None:
p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".")
if p != "0":
@ -503,9 +450,7 @@ def _display_text(
if text is not None:
display_lines.append(text)
if display_lines:
PromptServer.instance.send_progress_text(
"\n".join(display_lines), get_node_id(node_cls)
)
PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls))
def _display_time_progress(
@ -519,11 +464,7 @@ def _display_time_progress(
processing_elapsed_seconds: int | None = None,
) -> None:
if estimated_total is not None and estimated_total > 0 and is_queued is False:
pe = (
processing_elapsed_seconds
if processing_elapsed_seconds is not None
else elapsed_seconds
)
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
remaining = max(0, int(estimated_total) - int(pe))
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
else:
@ -562,9 +503,7 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
raise ValueError("files tuple must be (filename, file[, content_type])")
def _merge_params(
endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None
) -> dict[str, Any]:
def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]:
params = dict(endpoint_params or {})
if method.upper() == "GET" and data:
for k, v in data.items():
@ -627,14 +566,8 @@ def _snapshot_request_body_for_logging(
filename = file_obj[0]
else:
filename = getattr(file_obj, "name", field_name)
file_fields.append(
{"field": field_name, "filename": str(filename or "")}
)
return {
"_multipart": True,
"form_fields": form_fields,
"file_fields": file_fields,
}
file_fields.append({"field": field_name, "filename": str(filename or "")})
return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields}
if content_type == "application/x-www-form-urlencoded":
return data or {}
return data or {}
@ -648,9 +581,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
method = cfg.endpoint.method
params = _merge_params(
cfg.endpoint.query_params, method, cfg.data if method == "GET" else None
)
params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)
async def _monitor(stop_evt: asyncio.Event, start_ts: float):
"""Every second: update elapsed time and signal interruption."""
@ -660,20 +591,13 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
return
if cfg.monitor_progress:
_display_time_progress(
cfg.node_cls,
cfg.wait_label,
int(time.monotonic() - start_ts),
cfg.estimated_total,
cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total
)
await asyncio.sleep(1.0)
except asyncio.CancelledError:
return # normal shutdown
start_time = (
cfg.progress_origin_ts
if cfg.progress_origin_ts is not None
else time.monotonic()
)
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
attempt = 0
delay = cfg.retry_delay
rate_limit_attempts = 0
@ -690,9 +614,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
payload_headers = (
{"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
)
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
payload_headers.update(get_auth_header(cfg.node_cls))
if cfg.endpoint.headers:
@ -701,9 +623,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload_kw: dict[str, Any] = {"headers": payload_headers}
if method == "GET":
payload_headers.pop("Content-Type", None)
request_body_log = _snapshot_request_body_for_logging(
cfg.content_type, method, cfg.data, cfg.files
)
request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files)
try:
if cfg.monitor_progress:
monitor_task = asyncio.create_task(_monitor(stop_event, start_time))
@ -717,23 +637,16 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
if cfg.multipart_parser and cfg.data:
form = cfg.multipart_parser(cfg.data)
if not isinstance(form, aiohttp.FormData):
raise ValueError(
"multipart_parser must return aiohttp.FormData"
)
raise ValueError("multipart_parser must return aiohttp.FormData")
else:
form = aiohttp.FormData(default_to_multipart=True)
if cfg.data:
for k, v in cfg.data.items():
if v is None:
continue
form.add_field(
k,
str(v) if not isinstance(v, (bytes, bytearray)) else v,
)
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
if cfg.files:
file_iter = (
cfg.files if isinstance(cfg.files, list) else cfg.files.items()
)
file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items()
for field_name, file_obj in file_iter:
if file_obj is None:
continue
@ -747,17 +660,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
if isinstance(file_value, BytesIO):
with contextlib.suppress(Exception):
file_value.seek(0)
form.add_field(
field_name,
file_value,
filename=filename,
content_type=content_type,
)
form.add_field(field_name, file_value, filename=filename, content_type=content_type)
payload_kw["data"] = form
elif (
cfg.content_type == "application/x-www-form-urlencoded"
and method != "GET"
):
elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET":
payload_headers["Content-Type"] = "application/x-www-form-urlencoded"
payload_kw["data"] = cfg.data or {}
elif method != "GET":
@ -780,9 +685,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
tasks = {req_task}
if monitor_task:
tasks.add(monitor_task)
done, pending = await asyncio.wait(
tasks, return_when=asyncio.FIRST_COMPLETED
)
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
if monitor_task and monitor_task in done:
# Interrupted cancel the request and abort
@ -802,8 +705,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
wait_time = 0.0
retry_label = ""
is_rl = resp.status == 429 or (
cfg.is_rate_limited is not None
and cfg.is_rate_limited(resp.status, body)
cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
)
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
rate_limit_attempts += 1
@ -811,10 +713,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
rate_limit_delay *= cfg.retry_backoff
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
should_retry = True
elif (
resp.status in _RETRY_STATUS
and (attempt - rate_limit_attempts) <= cfg.max_retries
):
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
wait_time = delay
delay *= cfg.retry_backoff
retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
@ -844,9 +743,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress
if cfg.monitor_progress
else None,
display_callback=_display_time_progress if cfg.monitor_progress else None,
)
continue
msg = _friendly_http_message(resp.status, body)
@ -873,10 +770,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
raise ProcessingInterrupted("Task cancelled")
if cfg.monitor_progress:
_display_time_progress(
cfg.node_cls,
cfg.wait_label,
int(now - start_time),
cfg.estimated_total,
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
)
bytes_payload = bytes(buff)
resp_headers = {k.lower(): v for k, v in resp.headers.items()}
@ -906,15 +800,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload = json.loads(text) if text else {}
except json.JSONDecodeError:
payload = {"_raw": text}
response_content_to_log = (
payload if isinstance(payload, dict) else text
)
response_content_to_log = payload if isinstance(payload, dict) else text
with contextlib.suppress(Exception):
extracted_price = (
cfg.price_extractor(payload)
if cfg.price_extractor
else None
)
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
request_logger.log_request_response(
@ -956,9 +844,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress
if cfg.monitor_progress
else None,
display_callback=_display_time_progress if cfg.monitor_progress else None,
)
delay *= cfg.retry_backoff
continue
@ -999,11 +885,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
if sess:
with contextlib.suppress(Exception):
await sess.close()
if (
operation_succeeded
and cfg.monitor_progress
and cfg.final_label_on_success
):
if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success:
_display_time_progress(
cfg.node_cls,
status=cfg.final_label_on_success,

View File

@ -22,7 +22,7 @@ from ._helpers import (
sleep_with_interrupt,
to_aiohttp_url,
)
from .client import RETRY_DEFAULTS, _diagnose_connectivity
from .client import _diagnose_connectivity
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
from .conversions import bytesio_to_image_tensor
@ -34,9 +34,9 @@ async def download_url_to_bytesio(
dest: BytesIO | IO[bytes] | str | Path | None,
*,
timeout: float | None = None,
max_retries: int = max(5, RETRY_DEFAULTS.max_retries),
retry_delay: float = RETRY_DEFAULTS.retry_delay,
retry_backoff: float = RETRY_DEFAULTS.retry_backoff,
max_retries: int = 5,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
cls: type[COMFY_IO.ComfyNode] = None,
) -> None:
"""Stream-download a URL to `dest`.
@ -53,9 +53,7 @@ async def download_url_to_bytesio(
ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception (HTTP and other errors)
"""
if not isinstance(dest, (str, Path)) and not hasattr(dest, "write"):
raise ValueError(
"dest must be a path (str|Path) or a binary-writable object providing .write()."
)
raise ValueError("dest must be a path (str|Path) or a binary-writable object providing .write().")
attempt = 0
delay = retry_delay
@ -64,9 +62,7 @@ async def download_url_to_bytesio(
parsed_url = urlparse(url)
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
if cls is None:
raise ValueError(
"For relative 'cloud' paths, the `cls` parameter is required."
)
raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.")
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
headers = get_auth_header(cls)
@ -84,9 +80,7 @@ async def download_url_to_bytesio(
try:
with contextlib.suppress(Exception):
request_logger.log_request_response(
operation_id=op_id, request_method="GET", request_url=url
)
request_logger.log_request_response(operation_id=op_id, request_method="GET", request_url=url)
session = aiohttp.ClientSession(timeout=timeout_cfg)
stop_evt = asyncio.Event()
@ -102,12 +96,8 @@ async def download_url_to_bytesio(
monitor_task = asyncio.create_task(_monitor())
req_task = asyncio.create_task(
session.get(to_aiohttp_url(url), headers=headers)
)
done, pending = await asyncio.wait(
{req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED
)
req_task = asyncio.create_task(session.get(to_aiohttp_url(url), headers=headers))
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
if monitor_task in done and req_task in pending:
req_task.cancel()
@ -127,11 +117,7 @@ async def download_url_to_bytesio(
body = await resp.json()
except (ContentTypeError, ValueError):
text = await resp.text()
body = (
text
if len(text) <= 4096
else f"[text {len(text)} bytes]"
)
body = text if len(text) <= 4096 else f"[text {len(text)} bytes]"
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
@ -160,9 +146,7 @@ async def download_url_to_bytesio(
written = 0
while True:
try:
chunk = await asyncio.wait_for(
resp.content.read(1024 * 1024), timeout=1.0
)
chunk = await asyncio.wait_for(resp.content.read(1024 * 1024), timeout=1.0)
except asyncio.TimeoutError:
chunk = b""
except asyncio.CancelledError:
@ -211,9 +195,7 @@ async def download_url_to_bytesio(
raise LocalNetworkError(
"Unable to connect to the network. Please check your internet connection and try again."
) from e
raise ApiServerError(
"The remote service appears unreachable at this time."
) from e
raise ApiServerError("The remote service appears unreachable at this time.") from e
finally:
if stop_evt is not None:
stop_evt.set()
@ -255,9 +237,7 @@ async def download_url_to_video_output(
) -> InputImpl.VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output."""
result = BytesIO()
await download_url_to_bytesio(
video_url, result, timeout=timeout, max_retries=max_retries, cls=cls
)
await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
return InputImpl.VideoFromFile(result)
@ -276,11 +256,7 @@ async def download_url_as_bytesio(
def _generate_operation_id(method: str, url: str, attempt: int) -> str:
try:
parsed = urlparse(url)
slug = (
(parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download")
.strip("/")
.replace("/", "_")
)
slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download").strip("/").replace("/", "_")
except Exception:
slug = "download"
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"

View File

@ -15,7 +15,6 @@ from comfy_api.latest import IO, Input, Types
from . import request_logger
from ._helpers import is_processing_interrupted, sleep_with_interrupt
from .client import (
RETRY_DEFAULTS,
ApiEndpoint,
_diagnose_connectivity,
_display_time_progress,
@ -78,17 +77,13 @@ async def upload_images_to_comfyapi(
for idx in range(num_to_upload):
tensor = tensors[idx]
img_io = tensor_to_bytesio(
tensor, total_pixels=total_pixels, mime_type=mime_type
)
img_io = tensor_to_bytesio(tensor, total_pixels=total_pixels, mime_type=mime_type)
effective_label = wait_label
if wait_label and show_batch_index and num_to_upload > 1:
effective_label = f"{wait_label} ({idx + 1}/{num_to_upload})"
url = await upload_file_to_comfyapi(
cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts
)
url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts)
download_urls.append(url)
return download_urls
@ -130,12 +125,8 @@ async def upload_audio_to_comfyapi(
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(
audio_data_np, sample_rate, container_format, codec_name
)
return await upload_file_to_comfyapi(
cls, audio_bytes_io, f"{uuid.uuid4()}.{container_format}", mime_type
)
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
return await upload_file_to_comfyapi(cls, audio_bytes_io, f"{uuid.uuid4()}.{container_format}", mime_type)
async def upload_video_to_comfyapi(
@ -170,9 +161,7 @@ async def upload_video_to_comfyapi(
video.save_to(video_bytes_io, format=container, codec=codec)
video_bytes_io.seek(0)
return await upload_file_to_comfyapi(
cls, video_bytes_io, filename, upload_mime_type, wait_label
)
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label)
_3D_MIME_TYPES = {
@ -208,9 +197,7 @@ async def upload_file_to_comfyapi(
if upload_mime_type is None:
request_object = UploadRequest(file_name=filename)
else:
request_object = UploadRequest(
file_name=filename, content_type=upload_mime_type
)
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
create_resp = await sync_op(
cls,
endpoint=ApiEndpoint(path="/customers/storage", method="POST"),
@ -236,9 +223,9 @@ async def upload_file(
file: BytesIO | str,
*,
content_type: str | None = None,
max_retries: int = RETRY_DEFAULTS.max_retries,
retry_delay: float = RETRY_DEFAULTS.retry_delay,
retry_backoff: float = RETRY_DEFAULTS.retry_backoff,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
wait_label: str | None = None,
progress_origin_ts: float | None = None,
) -> None:
@ -263,15 +250,11 @@ async def upload_file(
if content_type:
headers["Content-Type"] = content_type
else:
skip_auto_headers.add(
"Content-Type"
) # Don't let aiohttp add Content-Type, it can break the signed request
skip_auto_headers.add("Content-Type") # Don't let aiohttp add Content-Type, it can break the signed request
attempt = 0
delay = retry_delay
start_ts = (
progress_origin_ts if progress_origin_ts is not None else time.monotonic()
)
start_ts = progress_origin_ts if progress_origin_ts is not None else time.monotonic()
op_uuid = uuid.uuid4().hex[:8]
while True:
attempt += 1
@ -285,9 +268,7 @@ async def upload_file(
if is_processing_interrupted():
return
if wait_label:
_display_time_progress(
cls, wait_label, int(time.monotonic() - start_ts), None
)
_display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None)
await asyncio.sleep(1.0)
except asyncio.CancelledError:
return
@ -305,17 +286,10 @@ async def upload_file(
)
sess = aiohttp.ClientSession(timeout=timeout)
req = sess.put(
upload_url,
data=data,
headers=headers,
skip_auto_headers=skip_auto_headers,
)
req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
req_task = asyncio.create_task(req)
done, pending = await asyncio.wait(
{req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED
)
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
if monitor_task in done and req_task in pending:
req_task.cancel()
@ -343,19 +317,14 @@ async def upload_file(
response_content=body,
error_message=msg,
)
if (
resp.status in {408, 429, 500, 502, 503, 504}
and attempt <= max_retries
):
if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries:
await sleep_with_interrupt(
delay,
cls,
wait_label,
start_ts,
None,
display_callback=_display_time_progress
if wait_label
else None,
display_callback=_display_time_progress if wait_label else None,
)
delay *= retry_backoff
continue
@ -397,9 +366,7 @@ async def upload_file(
raise LocalNetworkError(
"Unable to connect to the network. Please check your internet connection and try again."
) from e
raise ApiServerError(
"The API service appears unreachable at this time."
) from e
raise ApiServerError("The API service appears unreachable at this time.") from e
finally:
stop_evt.set()
if monitor_task:
@ -414,11 +381,7 @@ async def upload_file(
def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str:
try:
parsed = urlparse(url)
slug = (
(parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload")
.strip("/")
.replace("/", "_")
)
slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload").strip("/").replace("/", "_")
except Exception:
slug = "upload"
return f"{method}_{slug}_{op_uuid}_try{attempt}"

View File

@ -19,7 +19,7 @@ scipy
tqdm
psutil
alembic
SQLAlchemy
SQLAlchemy>=2.0
filelock
av>=14.2.0
comfy-kitchen>=0.2.8

View File

@ -1,94 +0,0 @@
"""Tests for configurable retry defaults via environment variables.
Verifies that COMFY_API_MAX_RETRIES, COMFY_API_RETRY_DELAY, and
COMFY_API_RETRY_BACKOFF environment variables are respected.
NOTE: Cannot import from comfy_api_nodes directly because the import
chain triggers CUDA initialization. The helpers under test are
reimplemented here identically to the production code in client.py.
"""
from __future__ import annotations
import os
from dataclasses import dataclass
from unittest.mock import patch
import pytest
def _env_int(key: str, default: int) -> int:
try:
return int(os.environ[key])
except (KeyError, ValueError):
return default
def _env_float(key: str, default: float) -> float:
try:
return float(os.environ[key])
except (KeyError, ValueError):
return default
@dataclass(frozen=True)
class _RetryDefaults:
max_retries: int = _env_int("COMFY_API_MAX_RETRIES", 3)
retry_delay: float = _env_float("COMFY_API_RETRY_DELAY", 1.0)
retry_backoff: float = _env_float("COMFY_API_RETRY_BACKOFF", 2.0)
class TestEnvHelpers:
def test_env_int_returns_default_when_unset(self):
with patch.dict(os.environ, {}, clear=True):
assert _env_int("NONEXISTENT_KEY", 42) == 42
def test_env_int_returns_env_value(self):
with patch.dict(os.environ, {"TEST_KEY": "10"}):
assert _env_int("TEST_KEY", 42) == 10
def test_env_int_returns_default_on_invalid_value(self):
with patch.dict(os.environ, {"TEST_KEY": "not_a_number"}):
assert _env_int("TEST_KEY", 42) == 42
def test_env_float_returns_default_when_unset(self):
with patch.dict(os.environ, {}, clear=True):
assert _env_float("NONEXISTENT_KEY", 1.5) == 1.5
def test_env_float_returns_env_value(self):
with patch.dict(os.environ, {"TEST_KEY": "2.5"}):
assert _env_float("TEST_KEY", 1.5) == 2.5
def test_env_float_returns_default_on_invalid_value(self):
with patch.dict(os.environ, {"TEST_KEY": "bad"}):
assert _env_float("TEST_KEY", 1.5) == 1.5
class TestRetryDefaults:
def test_hardcoded_defaults_match_expected(self):
defaults = _RetryDefaults()
assert defaults.max_retries == 3
assert defaults.retry_delay == 1.0
assert defaults.retry_backoff == 2.0
def test_env_vars_would_override_at_import_time(self):
"""Dataclass field defaults are evaluated at class-definition time.
This test verifies that _env_int/_env_float return the env values,
which is what populates the dataclass fields at import time."""
with patch.dict(os.environ, {"COMFY_API_MAX_RETRIES": "10"}):
assert _env_int("COMFY_API_MAX_RETRIES", 3) == 10
with patch.dict(os.environ, {"COMFY_API_RETRY_DELAY": "3.0"}):
assert _env_float("COMFY_API_RETRY_DELAY", 1.0) == 3.0
with patch.dict(os.environ, {"COMFY_API_RETRY_BACKOFF": "1.5"}):
assert _env_float("COMFY_API_RETRY_BACKOFF", 2.0) == 1.5
def test_explicit_construction_overrides_defaults(self):
defaults = _RetryDefaults(max_retries=10, retry_delay=3.0, retry_backoff=1.5)
assert defaults.max_retries == 10
assert defaults.retry_delay == 3.0
assert defaults.retry_backoff == 1.5
def test_frozen_dataclass(self):
defaults = _RetryDefaults()
with pytest.raises(AttributeError):
defaults.max_retries = 999