mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-20 16:36:41 +08:00
Compare commits
4 Commits
v0.21.0
...
feat/rnp-r
| Author | SHA1 | Date | |
|---|---|---|---|
| c899ea4ef7 | |||
| 428c323780 | |||
| 46063aa927 | |||
| b565dc7a6c |
@ -198,6 +198,62 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
_PRESETS_SEEDREAM_1K = [
|
||||
("(1K) 1024x1024 (1:1)", 1024, 1024),
|
||||
("(1K) 864x1152 (3:4)", 864, 1152),
|
||||
("(1K) 1152x864 (4:3)", 1152, 864),
|
||||
("(1K) 1312x736 (16:9)", 1312, 736),
|
||||
("(1K) 736x1312 (9:16)", 736, 1312),
|
||||
("(1K) 832x1248 (2:3)", 832, 1248),
|
||||
("(1K) 1248x832 (3:2)", 1248, 832),
|
||||
("(1K) 1568x672 (21:9)", 1568, 672),
|
||||
]
|
||||
|
||||
_PRESETS_SEEDREAM_2K = [
|
||||
("(2K) 2048x2048 (1:1)", 2048, 2048),
|
||||
("(2K) 1728x2304 (3:4)", 1728, 2304),
|
||||
("(2K) 2304x1728 (4:3)", 2304, 1728),
|
||||
("(2K) 2848x1600 (16:9)", 2848, 1600),
|
||||
("(2K) 1600x2848 (9:16)", 1600, 2848),
|
||||
("(2K) 1664x2496 (2:3)", 1664, 2496),
|
||||
("(2K) 2496x1664 (3:2)", 2496, 1664),
|
||||
("(2K) 3136x1344 (21:9)", 3136, 1344),
|
||||
]
|
||||
|
||||
_PRESETS_SEEDREAM_3K = [
|
||||
("(3K) 3072x3072 (1:1)", 3072, 3072),
|
||||
("(3K) 2592x3456 (3:4)", 2592, 3456),
|
||||
("(3K) 3456x2592 (4:3)", 3456, 2592),
|
||||
("(3K) 4096x2304 (16:9)", 4096, 2304),
|
||||
("(3K) 2304x4096 (9:16)", 2304, 4096),
|
||||
("(3K) 2496x3744 (2:3)", 2496, 3744),
|
||||
("(3K) 3744x2496 (3:2)", 3744, 2496),
|
||||
("(3K) 4704x2016 (21:9)", 4704, 2016),
|
||||
]
|
||||
|
||||
_PRESETS_SEEDREAM_4K = [
|
||||
("(4K) 4096x4096 (1:1)", 4096, 4096),
|
||||
("(4K) 3520x4704 (3:4)", 3520, 4704),
|
||||
("(4K) 4704x3520 (4:3)", 4704, 3520),
|
||||
("(4K) 5504x3040 (16:9)", 5504, 3040),
|
||||
("(4K) 3040x5504 (9:16)", 3040, 5504),
|
||||
("(4K) 3328x4992 (2:3)", 3328, 4992),
|
||||
("(4K) 4992x3328 (3:2)", 4992, 3328),
|
||||
("(4K) 6240x2656 (21:9)", 6240, 2656),
|
||||
]
|
||||
|
||||
_CUSTOM_PRESET = [("Custom", None, None)]
|
||||
|
||||
RECOMMENDED_PRESETS_SEEDREAM_5_LITE = (
|
||||
_PRESETS_SEEDREAM_2K + _PRESETS_SEEDREAM_3K + _PRESETS_SEEDREAM_4K + _CUSTOM_PRESET
|
||||
)
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4_5 = (
|
||||
_PRESETS_SEEDREAM_2K + _PRESETS_SEEDREAM_4K + _CUSTOM_PRESET
|
||||
)
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4_0 = (
|
||||
_PRESETS_SEEDREAM_1K + _PRESETS_SEEDREAM_2K + _PRESETS_SEEDREAM_4K + _CUSTOM_PRESET
|
||||
)
|
||||
|
||||
# Seedance 2.0 reference video pixel count limits per model and output resolution.
|
||||
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
|
||||
"dreamina-seedance-2-0-260128": {
|
||||
|
||||
@ -596,6 +596,7 @@ class Flux2ProImageNode(IO.ComfyNode):
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["width", "height"], inputs=["images"]),
|
||||
expr=cls.PRICE_BADGE_EXPR,
|
||||
),
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -674,6 +675,175 @@ class Flux2MaxImageNode(Flux2ProImageNode):
|
||||
"""
|
||||
|
||||
|
||||
_FLUX2_MODEL_ENDPOINTS = {
|
||||
"Flux.2 [pro]": "/proxy/bfl/flux-2-pro/generate",
|
||||
"Flux.2 [max]": "/proxy/bfl/flux-2-max/generate",
|
||||
}
|
||||
|
||||
|
||||
def _flux2_model_inputs():
|
||||
return [
|
||||
IO.Int.Input(
|
||||
"width",
|
||||
default=1024,
|
||||
min=256,
|
||||
max=2048,
|
||||
step=32,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"height",
|
||||
default=768,
|
||||
min=256,
|
||||
max=2048,
|
||||
step=32,
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("image"),
|
||||
names=[f"image_{i}" for i in range(1, 9)],
|
||||
min=0,
|
||||
),
|
||||
tooltip="Optional reference image(s) for image-to-image generation. Up to 8 images.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class Flux2ImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Flux2ImageNode",
|
||||
display_name="Flux.2 Image",
|
||||
category="api node/image/BFL",
|
||||
description="Generate images via Flux.2 [pro] or Flux.2 [max] from a prompt and optional reference images.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation or edit",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Flux.2 [pro]", _flux2_model_inputs()),
|
||||
IO.DynamicCombo.Option("Flux.2 [max]", _flux2_model_inputs()),
|
||||
],
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model", "model.width", "model.height"],
|
||||
input_groups=["model.images"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$isMax := widgets.model = "flux.2 [max]";
|
||||
$MP := 1024 * 1024;
|
||||
$w := $lookup(widgets, "model.width");
|
||||
$h := $lookup(widgets, "model.height");
|
||||
$outMP := $max([1, $floor((($w * $h) + $MP - 1) / $MP)]);
|
||||
$outputCost := $isMax
|
||||
? (0.07 + 0.03 * ($outMP - 1))
|
||||
: (0.03 + 0.015 * ($outMP - 1));
|
||||
$refMin := $isMax ? 0.03 : 0.015;
|
||||
$refMax := $isMax ? 0.24 : 0.12;
|
||||
$hasRefs := $lookup(inputGroups, "model.images") > 0;
|
||||
$hasRefs
|
||||
? {
|
||||
"type": "range_usd",
|
||||
"min_usd": $outputCost + $refMin,
|
||||
"max_usd": $outputCost + $refMax,
|
||||
"format": { "approximate": true }
|
||||
}
|
||||
: {"type": "usd", "usd": $outputCost}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
model_choice = model["model"]
|
||||
endpoint = _FLUX2_MODEL_ENDPOINTS[model_choice]
|
||||
width = model["width"]
|
||||
height = model["height"]
|
||||
images_dict = model.get("images") or {}
|
||||
|
||||
image_tensors: list[Input.Image] = [t for t in images_dict.values() if t is not None]
|
||||
n_images = sum(get_number_of_images(t) for t in image_tensors)
|
||||
if n_images > 8:
|
||||
raise ValueError("The current maximum number of supported images is 8.")
|
||||
|
||||
flat_tensors: list[torch.Tensor] = []
|
||||
for tensor in image_tensors:
|
||||
if len(tensor.shape) == 4:
|
||||
flat_tensors.extend(tensor[i] for i in range(tensor.shape[0]))
|
||||
else:
|
||||
flat_tensors.append(tensor)
|
||||
|
||||
reference_images: dict[str, str] = {}
|
||||
for idx, tensor in enumerate(flat_tensors):
|
||||
key_name = f"input_image_{idx + 1}" if idx else "input_image"
|
||||
reference_images[key_name] = tensor_to_base64_string(tensor, total_pixels=2048 * 2048)
|
||||
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=endpoint, method="POST"),
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
data=Flux2ProGenerateRequest(
|
||||
prompt=prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
seed=seed,
|
||||
**reference_images,
|
||||
),
|
||||
)
|
||||
|
||||
def price_extractor(_r: BaseModel) -> float | None:
|
||||
return None if initial_response.cost is None else initial_response.cost / 100
|
||||
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(initial_response.polling_url),
|
||||
response_model=BFLFluxStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
price_extractor=price_extractor,
|
||||
completed_statuses=[BFLStatus.ready],
|
||||
failed_statuses=[
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
BFLStatus.error,
|
||||
BFLStatus.task_not_found,
|
||||
],
|
||||
queued_statuses=[],
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
|
||||
|
||||
|
||||
class BFLExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -685,6 +855,7 @@ class BFLExtension(ComfyExtension):
|
||||
FluxProFillNode,
|
||||
Flux2ProImageNode,
|
||||
Flux2MaxImageNode,
|
||||
Flux2ImageNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -10,6 +10,9 @@ from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.bytedance import (
|
||||
RECOMMENDED_PRESETS,
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4,
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4_0,
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4_5,
|
||||
RECOMMENDED_PRESETS_SEEDREAM_5_LITE,
|
||||
SEEDANCE2_PRICE_PER_1K_TOKENS,
|
||||
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS,
|
||||
VIDEO_TASKS_EXECUTION_TIME,
|
||||
@ -68,6 +71,12 @@ SEEDREAM_MODELS = {
|
||||
"seedream-4-0-250828": "seedream-4-0-250828",
|
||||
}
|
||||
|
||||
SEEDREAM_PRESETS = {
|
||||
"seedream-5-0-260128": RECOMMENDED_PRESETS_SEEDREAM_5_LITE,
|
||||
"seedream-4-5-251128": RECOMMENDED_PRESETS_SEEDREAM_4_5,
|
||||
"seedream-4-0-250828": RECOMMENDED_PRESETS_SEEDREAM_4_0,
|
||||
}
|
||||
|
||||
# Long-running tasks endpoints(e.g., video)
|
||||
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
|
||||
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
|
||||
@ -562,6 +571,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
)
|
||||
""",
|
||||
),
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -651,6 +661,226 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(torch.cat([await download_url_to_image_tensor(i) for i in urls]))
|
||||
|
||||
|
||||
def _seedream_model_inputs(*, max_ref_images: int, presets: list):
|
||||
return [
|
||||
IO.Combo.Input(
|
||||
"size_preset",
|
||||
options=[label for label, _, _ in presets],
|
||||
tooltip="Pick a recommended size. Select Custom to use the width and height below.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"width",
|
||||
default=2048,
|
||||
min=1024,
|
||||
max=6240,
|
||||
step=2,
|
||||
tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"height",
|
||||
default=2048,
|
||||
min=1024,
|
||||
max=4992,
|
||||
step=2,
|
||||
tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"max_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=max_ref_images,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Maximum number of images to generate. With 1, exactly one image is produced. "
|
||||
"With >1, the model generates between 1 and max_images related images "
|
||||
"(e.g., story scenes, character variations). "
|
||||
"Total images (input + generated) cannot exceed 15.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("image"),
|
||||
names=[f"image_{i}" for i in range(1, max_ref_images + 1)],
|
||||
min=0,
|
||||
),
|
||||
tooltip=f"Optional reference image(s) for image-to-image or multi-reference generation. "
|
||||
f"Up to {max_ref_images} images.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"fail_on_partial",
|
||||
default=False,
|
||||
tooltip="If enabled, abort execution if any requested images are missing or return an error.",
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class ByteDanceSeedreamNodeV2(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceSeedreamNodeV2",
|
||||
display_name="ByteDance Seedream 4.5 & 5.0",
|
||||
category="api node/image/ByteDance",
|
||||
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt for creating or editing an image.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"seedream 5.0 lite",
|
||||
_seedream_model_inputs(max_ref_images=14, presets=RECOMMENDED_PRESETS_SEEDREAM_5_LITE),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"seedream-4-5-251128",
|
||||
_seedream_model_inputs(max_ref_images=10, presets=RECOMMENDED_PRESETS_SEEDREAM_4_5),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"seedream-4-0-250828",
|
||||
_seedream_model_inputs(max_ref_images=10, presets=RECOMMENDED_PRESETS_SEEDREAM_4_0),
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to use for generation.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=False,
|
||||
tooltip='Whether to add an "AI generated" watermark to the image.',
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||
expr="""
|
||||
(
|
||||
$price := $contains(widgets.model, "5.0 lite") ? 0.035 :
|
||||
$contains(widgets.model, "4-5") ? 0.04 : 0.03;
|
||||
{
|
||||
"type":"usd",
|
||||
"usd": $price,
|
||||
"format": { "suffix":" x images/Run", "approximate": true }
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int = 0,
|
||||
watermark: bool = False,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
model_id = SEEDREAM_MODELS[model["model"]]
|
||||
presets = SEEDREAM_PRESETS[model_id]
|
||||
|
||||
size_preset = model.get("size_preset", presets[0][0])
|
||||
width = model.get("width", 2048)
|
||||
height = model.get("height", 2048)
|
||||
max_images = model.get("max_images", 1)
|
||||
sequential_image_generation = "disabled" if max_images == 1 else "auto"
|
||||
images_dict = model.get("images") or {}
|
||||
fail_on_partial = model.get("fail_on_partial", False)
|
||||
|
||||
w = h = None
|
||||
for label, tw, th in presets:
|
||||
if label == size_preset:
|
||||
w, h = tw, th
|
||||
break
|
||||
if w is None or h is None:
|
||||
w, h = width, height
|
||||
|
||||
out_num_pixels = w * h
|
||||
mp_provided = out_num_pixels / 1_000_000.0
|
||||
if ("seedream-4-5" in model_id or "seedream-5-0" in model_id) and out_num_pixels < 3686400:
|
||||
raise ValueError(
|
||||
f"Minimum image resolution for the selected model is 3.68MP, but {mp_provided:.2f}MP provided."
|
||||
)
|
||||
if "seedream-4-0" in model_id and out_num_pixels < 921600:
|
||||
raise ValueError(
|
||||
f"Minimum image resolution that the selected model can generate is 0.92MP, "
|
||||
f"but {mp_provided:.2f}MP provided."
|
||||
)
|
||||
if out_num_pixels > 16_777_216:
|
||||
raise ValueError(
|
||||
f"Maximum image resolution for the selected model is 16.78MP, but {mp_provided:.2f}MP provided."
|
||||
)
|
||||
|
||||
image_tensors: list[Input.Image] = [t for t in images_dict.values() if t is not None]
|
||||
n_input_images = sum(get_number_of_images(t) for t in image_tensors)
|
||||
max_num_of_images = 14 if model_id == "seedream-5-0-260128" else 10
|
||||
if n_input_images > max_num_of_images:
|
||||
raise ValueError(
|
||||
f"Maximum of {max_num_of_images} reference images are supported, but {n_input_images} received."
|
||||
)
|
||||
if sequential_image_generation == "auto" and n_input_images + max_images > 15:
|
||||
raise ValueError(
|
||||
"The maximum number of generated images plus the number of reference images cannot exceed 15."
|
||||
)
|
||||
|
||||
reference_images_urls: list[str] = []
|
||||
if image_tensors:
|
||||
for tensor in image_tensors:
|
||||
validate_image_aspect_ratio(tensor, (1, 3), (3, 1))
|
||||
reference_images_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
image_tensors,
|
||||
max_images=n_input_images,
|
||||
mime_type="image/png",
|
||||
wait_label="Uploading reference images",
|
||||
)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"),
|
||||
response_model=ImageTaskCreationResponse,
|
||||
data=Seedream4TaskCreationRequest(
|
||||
model=model_id,
|
||||
prompt=prompt,
|
||||
image=reference_images_urls,
|
||||
size=f"{w}x{h}",
|
||||
seed=seed,
|
||||
sequential_image_generation=sequential_image_generation,
|
||||
sequential_image_generation_options=Seedream4Options(max_images=max_images),
|
||||
watermark=watermark,
|
||||
),
|
||||
)
|
||||
if len(response.data) == 1:
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response)))
|
||||
urls = [str(d["url"]) for d in response.data if isinstance(d, dict) and "url" in d]
|
||||
if fail_on_partial and len(urls) < len(response.data):
|
||||
raise RuntimeError(f"Only {len(urls)} of {len(response.data)} images were generated before error.")
|
||||
return IO.NodeOutput(torch.cat([await download_url_to_image_tensor(i) for i in urls]))
|
||||
|
||||
|
||||
class ByteDanceTextToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@ -2105,6 +2335,7 @@ class ByteDanceExtension(ComfyExtension):
|
||||
return [
|
||||
ByteDanceImageNode,
|
||||
ByteDanceSeedreamNode,
|
||||
ByteDanceSeedreamNodeV2,
|
||||
ByteDanceTextToVideoNode,
|
||||
ByteDanceImageToVideoNode,
|
||||
ByteDanceFirstLastFrameNode,
|
||||
|
||||
@ -162,6 +162,61 @@ class GrokImageNode(IO.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
_GROK_IMAGE_EDIT_ASPECT_RATIO_OPTIONS = [
|
||||
"auto",
|
||||
"1:1",
|
||||
"2:3",
|
||||
"3:2",
|
||||
"3:4",
|
||||
"4:3",
|
||||
"9:16",
|
||||
"16:9",
|
||||
"9:19.5",
|
||||
"19.5:9",
|
||||
"9:20",
|
||||
"20:9",
|
||||
"1:2",
|
||||
"2:1",
|
||||
]
|
||||
|
||||
|
||||
def _grok_image_edit_model_inputs(*, max_ref_images: int, with_aspect_ratio: bool):
|
||||
inputs = [
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("image"),
|
||||
names=[f"image_{i}" for i in range(1, max_ref_images + 1)],
|
||||
min=1,
|
||||
),
|
||||
tooltip=(
|
||||
"Reference image to edit."
|
||||
if max_ref_images == 1
|
||||
else f"Reference image(s) to edit. Up to {max_ref_images} images."
|
||||
),
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["1K", "2K"]),
|
||||
IO.Int.Input(
|
||||
"number_of_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=10,
|
||||
step=1,
|
||||
tooltip="Number of edited images to generate",
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
),
|
||||
]
|
||||
if with_aspect_ratio:
|
||||
inputs.append(
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=_GROK_IMAGE_EDIT_ASPECT_RATIO_OPTIONS,
|
||||
tooltip="Only allowed when multiple images are connected.",
|
||||
)
|
||||
)
|
||||
return inputs
|
||||
|
||||
|
||||
class GrokImageEditNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@ -256,6 +311,7 @@ class GrokImageEditNode(IO.ComfyNode):
|
||||
)
|
||||
""",
|
||||
),
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -303,6 +359,143 @@ class GrokImageEditNode(IO.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
class GrokImageEditNodeV2(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GrokImageEditNodeV2",
|
||||
display_name="Grok Image Edit",
|
||||
category="api node/image/Grok",
|
||||
description="Modify an existing image based on a text prompt",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="The text prompt used to generate the image",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"grok-imagine-image-quality",
|
||||
_grok_image_edit_model_inputs(max_ref_images=3, with_aspect_ratio=True),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"grok-imagine-image-pro",
|
||||
_grok_image_edit_model_inputs(max_ref_images=1, with_aspect_ratio=False),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"grok-imagine-image",
|
||||
_grok_image_edit_model_inputs(max_ref_images=3, with_aspect_ratio=True),
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model", "model.resolution", "model.number_of_images"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$isQualityModel := widgets.model = "grok-imagine-image-quality";
|
||||
$isPro := $contains(widgets.model, "pro");
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$n := $lookup(widgets, "model.number_of_images");
|
||||
$rate := $isQualityModel
|
||||
? ($res = "1k" ? 0.05 : 0.07)
|
||||
: ($isPro ? 0.07 : 0.02);
|
||||
$base := $isQualityModel ? 0.01 : 0.002;
|
||||
$output := $rate * $n;
|
||||
$isPro
|
||||
? {"type":"usd","usd": $base + $output}
|
||||
: {"type":"range_usd","min_usd": $base + $output, "max_usd": 3 * $base + $output}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
model_id = model["model"]
|
||||
resolution = model["resolution"]
|
||||
number_of_images = model["number_of_images"]
|
||||
images_dict = model.get("images") or {}
|
||||
aspect_ratio = model.get("aspect_ratio", "auto")
|
||||
|
||||
image_tensors: list[Input.Image] = [t for t in images_dict.values() if t is not None]
|
||||
n_images = sum(get_number_of_images(t) for t in image_tensors)
|
||||
if n_images < 1:
|
||||
raise ValueError("At least one image is required for editing.")
|
||||
if model_id == "grok-imagine-image-pro" and n_images > 1:
|
||||
raise ValueError("The pro model supports only 1 input image.")
|
||||
if model_id != "grok-imagine-image-pro" and n_images > 3:
|
||||
raise ValueError("A maximum of 3 input images is supported.")
|
||||
if aspect_ratio != "auto" and n_images == 1:
|
||||
raise ValueError(
|
||||
"Custom aspect ratio is only allowed when multiple images are connected to the image input."
|
||||
)
|
||||
|
||||
flat_tensors: list[torch.Tensor] = []
|
||||
for tensor in image_tensors:
|
||||
if len(tensor.shape) == 4:
|
||||
flat_tensors.extend(tensor[i] for i in range(tensor.shape[0]))
|
||||
else:
|
||||
flat_tensors.append(tensor)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"),
|
||||
data=ImageEditRequest(
|
||||
model=model_id,
|
||||
images=[
|
||||
InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(i)}") for i in flat_tensors
|
||||
],
|
||||
prompt=prompt,
|
||||
resolution=resolution.lower(),
|
||||
n=number_of_images,
|
||||
seed=seed,
|
||||
aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio,
|
||||
),
|
||||
response_model=ImageGenerationResponse,
|
||||
price_extractor=_extract_grok_price,
|
||||
)
|
||||
if len(response.data) == 1:
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
|
||||
return IO.NodeOutput(
|
||||
torch.cat(
|
||||
[await download_url_to_image_tensor(i) for i in [str(d.url) for d in response.data if d.url]],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class GrokVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@ -737,6 +930,7 @@ class GrokExtension(ComfyExtension):
|
||||
return [
|
||||
GrokImageNode,
|
||||
GrokImageEditNode,
|
||||
GrokImageEditNodeV2,
|
||||
GrokVideoNode,
|
||||
GrokVideoReferenceNode,
|
||||
GrokVideoEditNode,
|
||||
|
||||
@ -27,6 +27,7 @@ from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_bytesio,
|
||||
downscale_image_tensor,
|
||||
get_number_of_images,
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
@ -372,6 +373,7 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
display_name="OpenAI GPT Image 2",
|
||||
category="api node/image/OpenAI",
|
||||
description="Generates images synchronously via OpenAI's GPT Image endpoint.",
|
||||
is_deprecated=True,
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -640,6 +642,316 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
return IO.NodeOutput(await validate_and_cast_response(response))
|
||||
|
||||
|
||||
def _gpt_image_shared_inputs():
|
||||
"""Inputs shared by all GPT Image models (quality + reference images + mask)."""
|
||||
return [
|
||||
IO.Combo.Input(
|
||||
"quality",
|
||||
default="low",
|
||||
options=["low", "medium", "high"],
|
||||
tooltip="Image quality, affects cost and generation time.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("image"),
|
||||
names=[f"image_{i}" for i in range(1, 17)],
|
||||
min=0,
|
||||
),
|
||||
tooltip="Optional reference image(s) for image editing. Up to 16 images.",
|
||||
),
|
||||
IO.Mask.Input(
|
||||
"mask",
|
||||
optional=True,
|
||||
tooltip="Optional mask for inpainting (white areas will be replaced). "
|
||||
"Requires exactly one reference image.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _gpt_image_legacy_model_inputs():
|
||||
"""Per-model widget set for legacy gpt-image-1 / gpt-image-1.5 (4 base sizes, transparent bg allowed)."""
|
||||
return [
|
||||
IO.Combo.Input(
|
||||
"size",
|
||||
default="auto",
|
||||
options=["auto", "1024x1024", "1024x1536", "1536x1024"],
|
||||
tooltip="Image size.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"background",
|
||||
default="auto",
|
||||
options=["auto", "opaque", "transparent"],
|
||||
tooltip="Return image with or without background.",
|
||||
),
|
||||
*_gpt_image_shared_inputs(),
|
||||
]
|
||||
|
||||
|
||||
class OpenAIGPTImageNodeV2(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIGPTImageNodeV2",
|
||||
display_name="OpenAI GPT Image 2",
|
||||
category="api node/image/OpenAI",
|
||||
description="Generates images via OpenAI's GPT Image endpoint.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
default="",
|
||||
multiline=True,
|
||||
tooltip="Text prompt for GPT Image",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"gpt-image-2",
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"size",
|
||||
default="auto",
|
||||
options=[
|
||||
"auto",
|
||||
"1024x1024",
|
||||
"1024x1536",
|
||||
"1536x1024",
|
||||
"2048x2048",
|
||||
"2048x1152",
|
||||
"1152x2048",
|
||||
"3840x2160",
|
||||
"2160x3840",
|
||||
"Custom",
|
||||
],
|
||||
tooltip="Image size. Select 'Custom' to use the custom width and height.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"custom_width",
|
||||
default=1024,
|
||||
min=1024,
|
||||
max=3840,
|
||||
step=16,
|
||||
tooltip="Used only when `size` is 'Custom'. Must be a multiple of 16.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"custom_height",
|
||||
default=1024,
|
||||
min=1024,
|
||||
max=3840,
|
||||
step=16,
|
||||
tooltip="Used only when `size` is 'Custom'. Must be a multiple of 16.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"background",
|
||||
default="auto",
|
||||
options=["auto", "opaque"],
|
||||
tooltip="Return image with or without background.",
|
||||
),
|
||||
*_gpt_image_shared_inputs(),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option("gpt-image-1.5", _gpt_image_legacy_model_inputs()),
|
||||
IO.DynamicCombo.Option("gpt-image-1", _gpt_image_legacy_model_inputs()),
|
||||
],
|
||||
),
|
||||
IO.Int.Input(
|
||||
"n",
|
||||
default=1,
|
||||
min=1,
|
||||
max=8,
|
||||
step=1,
|
||||
tooltip="How many images to generate",
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="not implemented yet in backend",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.quality", "n"]),
|
||||
expr="""
|
||||
(
|
||||
$ranges := {
|
||||
"gpt-image-1": {
|
||||
"low": [0.011, 0.02],
|
||||
"medium": [0.042, 0.07],
|
||||
"high": [0.167, 0.25]
|
||||
},
|
||||
"gpt-image-1.5": {
|
||||
"low": [0.009, 0.02],
|
||||
"medium": [0.034, 0.062],
|
||||
"high": [0.133, 0.22]
|
||||
},
|
||||
"gpt-image-2": {
|
||||
"low": [0.0048, 0.019],
|
||||
"medium": [0.041, 0.168],
|
||||
"high": [0.165, 0.67]
|
||||
}
|
||||
};
|
||||
$range := $lookup($lookup($ranges, widgets.model), $lookup(widgets, "model.quality"));
|
||||
$nRaw := widgets.n;
|
||||
$n := ($nRaw != null and $nRaw != 0) ? $nRaw : 1;
|
||||
($n = 1)
|
||||
? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1], "format": {"approximate": true}}
|
||||
: {
|
||||
"type":"range_usd",
|
||||
"min_usd": $range[0] * $n,
|
||||
"max_usd": $range[1] * $n,
|
||||
"format": { "suffix": "/Run", "approximate": true }
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
n: int,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
|
||||
model_id = model["model"]
|
||||
size = model["size"]
|
||||
background = model["background"]
|
||||
quality = model["quality"]
|
||||
custom_width = model.get("custom_width", 1024)
|
||||
custom_height = model.get("custom_height", 1024)
|
||||
|
||||
images_dict = model.get("images") or {}
|
||||
image_tensors: list[Input.Image] = [t for t in images_dict.values() if t is not None]
|
||||
n_images = sum(get_number_of_images(t) for t in image_tensors)
|
||||
mask = model.get("mask")
|
||||
|
||||
if mask is not None and n_images == 0:
|
||||
raise ValueError("Cannot use a mask without an input image")
|
||||
|
||||
if size == "Custom":
|
||||
if custom_width % 16 != 0 or custom_height % 16 != 0:
|
||||
raise ValueError(
|
||||
f"Custom width and height must be multiples of 16, got {custom_width}x{custom_height}"
|
||||
)
|
||||
if max(custom_width, custom_height) > 3840:
|
||||
raise ValueError(
|
||||
f"Custom resolution max edge must be <= 3840, got {custom_width}x{custom_height}"
|
||||
)
|
||||
ratio = max(custom_width, custom_height) / min(custom_width, custom_height)
|
||||
if ratio > 3:
|
||||
raise ValueError(
|
||||
f"Custom resolution aspect ratio must not exceed 3:1, got {custom_width}x{custom_height}"
|
||||
)
|
||||
total_pixels = custom_width * custom_height
|
||||
if not 655_360 <= total_pixels <= 8_294_400:
|
||||
raise ValueError(
|
||||
f"Custom resolution total pixels must be between 655,360 and 8,294,400, got {total_pixels}"
|
||||
)
|
||||
size = f"{custom_width}x{custom_height}"
|
||||
|
||||
if model_id == "gpt-image-1":
|
||||
price_extractor = calculate_tokens_price_image_1
|
||||
elif model_id == "gpt-image-1.5":
|
||||
price_extractor = calculate_tokens_price_image_1_5
|
||||
elif model_id == "gpt-image-2":
|
||||
price_extractor = calculate_tokens_price_image_2_0
|
||||
else:
|
||||
raise ValueError(f"Unknown model: {model_id}")
|
||||
|
||||
if image_tensors:
|
||||
flat: list[torch.Tensor] = []
|
||||
for tensor in image_tensors:
|
||||
if len(tensor.shape) == 4:
|
||||
flat.extend(tensor[i : i + 1] for i in range(tensor.shape[0]))
|
||||
else:
|
||||
flat.append(tensor.unsqueeze(0))
|
||||
|
||||
files = []
|
||||
for i, single_image in enumerate(flat):
|
||||
scaled_image = downscale_image_tensor(single_image, total_pixels=2048 * 2048).squeeze()
|
||||
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
img_byte_arr = BytesIO()
|
||||
img.save(img_byte_arr, format="PNG")
|
||||
img_byte_arr.seek(0)
|
||||
|
||||
if len(flat) == 1:
|
||||
files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png")))
|
||||
else:
|
||||
files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
|
||||
|
||||
if mask is not None:
|
||||
if len(flat) != 1:
|
||||
raise Exception("Cannot use a mask with multiple image")
|
||||
ref_image = flat[0]
|
||||
if mask.shape[1:] != ref_image.shape[1:-1]:
|
||||
raise Exception("Mask and Image must be the same size")
|
||||
_, height, width = mask.shape
|
||||
rgba_mask = torch.zeros(height, width, 4, device="cpu")
|
||||
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
|
||||
scaled_mask = downscale_image_tensor(
|
||||
rgba_mask.unsqueeze(0), total_pixels=2048 * 2048
|
||||
).squeeze()
|
||||
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
|
||||
mask_img = Image.fromarray(mask_np)
|
||||
mask_img_byte_arr = BytesIO()
|
||||
mask_img.save(mask_img_byte_arr, format="PNG")
|
||||
mask_img_byte_arr.seek(0)
|
||||
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/openai/images/edits", method="POST"),
|
||||
response_model=OpenAIImageGenerationResponse,
|
||||
data=OpenAIImageEditRequest(
|
||||
model=model_id,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
background=background,
|
||||
n=n,
|
||||
size=size,
|
||||
moderation="low",
|
||||
),
|
||||
content_type="multipart/form-data",
|
||||
files=files,
|
||||
price_extractor=price_extractor,
|
||||
)
|
||||
else:
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/openai/images/generations", method="POST"),
|
||||
response_model=OpenAIImageGenerationResponse,
|
||||
data=OpenAIImageGenerationRequest(
|
||||
model=model_id,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
background=background,
|
||||
n=n,
|
||||
size=size,
|
||||
moderation="low",
|
||||
),
|
||||
price_extractor=price_extractor,
|
||||
)
|
||||
return IO.NodeOutput(await validate_and_cast_response(response))
|
||||
|
||||
|
||||
class OpenAIChatNode(IO.ComfyNode):
|
||||
"""
|
||||
Node to generate text responses from an OpenAI model.
|
||||
@ -999,6 +1311,7 @@ class OpenAIExtension(ComfyExtension):
|
||||
OpenAIDalle2,
|
||||
OpenAIDalle3,
|
||||
OpenAIGPTImage1,
|
||||
OpenAIGPTImageNodeV2,
|
||||
OpenAIChatNode,
|
||||
OpenAIInputFiles,
|
||||
OpenAIChatConfig,
|
||||
|
||||
@ -51,7 +51,7 @@ class ApiEndpoint:
|
||||
|
||||
@dataclass
|
||||
class _RequestConfig:
|
||||
node_cls: type[IO.ComfyNode]
|
||||
node_cls: type[IO.ComfyNode] | None
|
||||
endpoint: ApiEndpoint
|
||||
timeout: float
|
||||
content_type: str
|
||||
@ -70,6 +70,17 @@ class _RequestConfig:
|
||||
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None
|
||||
response_header_validator: Callable[[dict[str, str]], None] | None = None
|
||||
base_url: str | None = None
|
||||
auth_headers: dict[str, str] | None = None
|
||||
allow_304: bool = False
|
||||
error_parser: Callable[[int, Any], Exception | None] | None = None
|
||||
# Optional callback to render a per-second progress label while
|
||||
# waiting out a rate-limit / SERVER_BUSY / MAINTENANCE retry. Called
|
||||
# with ``(status, body, retry_after_s)`` and should return the label
|
||||
# string used by ``_display_time_progress`` (which renders it as
|
||||
# ``Status: <label>\nTime elapsed: Ns``). Returning ``None`` keeps
|
||||
# the default ``cfg.wait_label``.
|
||||
rate_limit_label: Callable[[int, Any, float], str | None] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -84,13 +95,40 @@ class _PollUIState:
|
||||
|
||||
|
||||
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
|
||||
|
||||
|
||||
def _parse_retry_after(raw: str | None) -> float | None:
|
||||
"""RFC 7231 Retry-After: seconds-int or HTTP-date.
|
||||
|
||||
Returns the wait time in seconds, clamped to non-negative. Returns
|
||||
``None`` for unparseable / missing values so the caller can fall
|
||||
back to the local backoff schedule.
|
||||
"""
|
||||
if not raw:
|
||||
return None
|
||||
raw = raw.strip()
|
||||
try:
|
||||
return max(0.0, float(raw))
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
# HTTP-date form (rare in practice for our servers, but cheap to support).
|
||||
try:
|
||||
from email.utils import parsedate_to_datetime
|
||||
dt = parsedate_to_datetime(raw)
|
||||
if dt is None:
|
||||
return None
|
||||
import datetime as _dt
|
||||
now = _dt.datetime.now(tz=dt.tzinfo) if dt.tzinfo else _dt.datetime.utcnow()
|
||||
return max(0.0, (dt - now).total_seconds())
|
||||
except Exception:
|
||||
return None
|
||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
|
||||
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
|
||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]
|
||||
|
||||
|
||||
async def sync_op(
|
||||
cls: type[IO.ComfyNode],
|
||||
cls: type[IO.ComfyNode] | None,
|
||||
endpoint: ApiEndpoint,
|
||||
*,
|
||||
response_model: type[M],
|
||||
@ -110,6 +148,9 @@ async def sync_op(
|
||||
monitor_progress: bool = True,
|
||||
max_retries_on_rate_limit: int = 16,
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
||||
rate_limit_label: Callable[[int, Any, float], str | None] | None = None,
|
||||
base_url: str | None = None,
|
||||
auth_headers: dict[str, str] | None = None,
|
||||
) -> M:
|
||||
raw = await sync_op_raw(
|
||||
cls,
|
||||
@ -131,6 +172,9 @@ async def sync_op(
|
||||
monitor_progress=monitor_progress,
|
||||
max_retries_on_rate_limit=max_retries_on_rate_limit,
|
||||
is_rate_limited=is_rate_limited,
|
||||
rate_limit_label=rate_limit_label,
|
||||
base_url=base_url,
|
||||
auth_headers=auth_headers,
|
||||
)
|
||||
if not isinstance(raw, dict):
|
||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||
@ -138,7 +182,7 @@ async def sync_op(
|
||||
|
||||
|
||||
async def poll_op(
|
||||
cls: type[IO.ComfyNode],
|
||||
cls: type[IO.ComfyNode] | None,
|
||||
poll_endpoint: ApiEndpoint,
|
||||
*,
|
||||
response_model: type[M],
|
||||
@ -159,6 +203,11 @@ async def poll_op(
|
||||
cancel_endpoint: ApiEndpoint | None = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
extra_text: str | None = None,
|
||||
base_url: str | None = None,
|
||||
auth_headers: dict[str, str] | None = None,
|
||||
error_parser: Callable[[int, Any], Exception | None] | None = None,
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
||||
rate_limit_label: Callable[[int, Any, float], str | None] | None = None,
|
||||
) -> M:
|
||||
raw = await poll_op_raw(
|
||||
cls,
|
||||
@ -180,6 +229,11 @@ async def poll_op(
|
||||
cancel_endpoint=cancel_endpoint,
|
||||
cancel_timeout=cancel_timeout,
|
||||
extra_text=extra_text,
|
||||
base_url=base_url,
|
||||
auth_headers=auth_headers,
|
||||
error_parser=error_parser,
|
||||
is_rate_limited=is_rate_limited,
|
||||
rate_limit_label=rate_limit_label,
|
||||
)
|
||||
if not isinstance(raw, dict):
|
||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||
@ -187,7 +241,7 @@ async def poll_op(
|
||||
|
||||
|
||||
async def sync_op_raw(
|
||||
cls: type[IO.ComfyNode],
|
||||
cls: type[IO.ComfyNode] | None,
|
||||
endpoint: ApiEndpoint,
|
||||
*,
|
||||
price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
|
||||
@ -207,13 +261,26 @@ async def sync_op_raw(
|
||||
monitor_progress: bool = True,
|
||||
max_retries_on_rate_limit: int = 16,
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
||||
rate_limit_label: Callable[[int, Any, float], str | None] | None = None,
|
||||
response_header_validator: Callable[[dict[str, str]], None] | None = None,
|
||||
) -> dict[str, Any] | bytes:
|
||||
base_url: str | None = None,
|
||||
auth_headers: dict[str, str] | None = None,
|
||||
allow_304: bool = False,
|
||||
error_parser: Callable[[int, Any], Exception | None] | None = None,
|
||||
) -> dict[str, Any] | bytes | None:
|
||||
"""
|
||||
Make a single network request.
|
||||
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
|
||||
- If as_binary=True: returns bytes.
|
||||
- response_header_validator: optional callback receiving response headers dict
|
||||
- base_url: override the default api.comfy.org base for this request.
|
||||
- auth_headers: pre-built Authorization/X-API-KEY dict; bypasses get_auth_header.
|
||||
- allow_304: when True, an HTTP 304 response returns ``None`` instead of raising.
|
||||
- error_parser: when set, called on every >=400 response with
|
||||
``(status, body)``; if it returns an Exception, that exception
|
||||
is raised immediately and the retry/friendly-message path is
|
||||
skipped. Used by RNP to surface structured ``RnpProtocolError``
|
||||
envelopes that would otherwise be flattened to "API Error: ...".
|
||||
"""
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.model_dump(exclude_none=True)
|
||||
@ -239,13 +306,18 @@ async def sync_op_raw(
|
||||
price_extractor=price_extractor,
|
||||
max_retries_on_rate_limit=max_retries_on_rate_limit,
|
||||
is_rate_limited=is_rate_limited,
|
||||
rate_limit_label=rate_limit_label,
|
||||
response_header_validator=response_header_validator,
|
||||
base_url=base_url,
|
||||
auth_headers=auth_headers,
|
||||
allow_304=allow_304,
|
||||
error_parser=error_parser,
|
||||
)
|
||||
return await _request_base(cfg, expect_binary=as_binary)
|
||||
|
||||
|
||||
async def poll_op_raw(
|
||||
cls: type[IO.ComfyNode],
|
||||
cls: type[IO.ComfyNode] | None,
|
||||
poll_endpoint: ApiEndpoint,
|
||||
*,
|
||||
status_extractor: Callable[[dict[str, Any]], str | int | None],
|
||||
@ -265,6 +337,11 @@ async def poll_op_raw(
|
||||
cancel_endpoint: ApiEndpoint | None = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
extra_text: str | None = None,
|
||||
base_url: str | None = None,
|
||||
auth_headers: dict[str, str] | None = None,
|
||||
error_parser: Callable[[int, Any], Exception | None] | None = None,
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
||||
rate_limit_label: Callable[[int, Any, float], str | None] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing,
|
||||
@ -272,6 +349,14 @@ async def poll_op_raw(
|
||||
|
||||
Uses default complete, failed and queued states assumption.
|
||||
|
||||
``error_parser`` and ``is_rate_limited`` are forwarded to each
|
||||
per-poll ``sync_op_raw`` call so callers can surface a typed
|
||||
exception for >=400 responses (e.g. an RNP structured-error
|
||||
envelope) and treat protocol-specific 5xx codes (e.g. RNP
|
||||
``SERVER_BUSY`` / ``MAINTENANCE``) like a 429 — both honour
|
||||
``Retry-After`` and consume the rate-limit retry counter instead
|
||||
of falling through the generic 5xx exponential-backoff path.
|
||||
|
||||
Returns the final JSON response from the poll endpoint.
|
||||
"""
|
||||
completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
|
||||
@ -286,6 +371,22 @@ async def poll_op_raw(
|
||||
state = _PollUIState(started=started, estimated_duration=estimated_duration)
|
||||
stop_ticker = asyncio.Event()
|
||||
|
||||
# Wrap the user's rate_limit_label so a SERVER_BUSY/MAINTENANCE/429
|
||||
# wait inside the per-poll sync_op_raw also flips the outer ticker's
|
||||
# status_label — otherwise _ticker keeps writing "Status: Queued"
|
||||
# over our message every second. The next successful poll resets
|
||||
# status_label from the response, so no manual restore is needed.
|
||||
user_rate_limit_label = rate_limit_label
|
||||
|
||||
def _wrapped_rate_limit_label(status: int, body: Any, retry_after_s: float) -> str | None:
|
||||
label: str | None = None
|
||||
if user_rate_limit_label is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
label = user_rate_limit_label(status, body, retry_after_s)
|
||||
if label:
|
||||
state.status_label = label
|
||||
return label
|
||||
|
||||
async def _ticker():
|
||||
"""Emit a UI update every second while polling is in progress."""
|
||||
try:
|
||||
@ -327,6 +428,11 @@ async def poll_op_raw(
|
||||
as_binary=False,
|
||||
final_label_on_success=None,
|
||||
monitor_progress=False,
|
||||
base_url=base_url,
|
||||
auth_headers=auth_headers,
|
||||
error_parser=error_parser,
|
||||
is_rate_limited=is_rate_limited,
|
||||
rate_limit_label=_wrapped_rate_limit_label,
|
||||
)
|
||||
if not isinstance(resp_json, dict):
|
||||
raise Exception("Polling endpoint returned non-JSON response.")
|
||||
@ -343,6 +449,8 @@ async def poll_op_raw(
|
||||
as_binary=False,
|
||||
final_label_on_success=None,
|
||||
monitor_progress=False,
|
||||
base_url=base_url,
|
||||
auth_headers=auth_headers,
|
||||
)
|
||||
raise
|
||||
|
||||
@ -419,6 +527,8 @@ async def poll_op_raw(
|
||||
as_binary=False,
|
||||
final_label_on_success=None,
|
||||
monitor_progress=False,
|
||||
base_url=base_url,
|
||||
auth_headers=auth_headers,
|
||||
)
|
||||
raise
|
||||
if not is_queued:
|
||||
@ -433,6 +543,16 @@ async def poll_op_raw(
|
||||
except (LocalNetworkError, ApiServerError):
|
||||
raise
|
||||
except Exception as e:
|
||||
# Let typed protocol errors raised by ``error_parser`` (e.g.
|
||||
# RnpProtocolError) bubble unchanged so callers can pattern-
|
||||
# match on ``.code`` to drive resume / fallback logic. Any
|
||||
# exception that exposes a string ``code`` attribute counts as
|
||||
# "typed" — duck-typing avoids importing the typed-error class
|
||||
# into this generic util layer. Everything else gets the
|
||||
# friendlier wrapper for back-compat with existing api-node
|
||||
# callers that surface the wrapped message directly to users.
|
||||
if isinstance(getattr(e, "code", None), str):
|
||||
raise
|
||||
raise Exception(f"Polling aborted due to error: {e}") from e
|
||||
finally:
|
||||
stop_ticker.set()
|
||||
@ -441,12 +561,16 @@ async def poll_op_raw(
|
||||
|
||||
|
||||
def _display_text(
|
||||
node_cls: type[IO.ComfyNode],
|
||||
node_cls: type[IO.ComfyNode] | None,
|
||||
text: str | None,
|
||||
*,
|
||||
status: str | int | None = None,
|
||||
price: float | None = None,
|
||||
) -> None:
|
||||
# Skip when there's no node to address — RNP / bootstrap callers
|
||||
# pass cls=None on requests that aren't tied to a workflow node.
|
||||
if node_cls is None:
|
||||
return
|
||||
display_lines: list[str] = []
|
||||
if status:
|
||||
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
|
||||
@ -461,7 +585,7 @@ def _display_text(
|
||||
|
||||
|
||||
def _display_time_progress(
|
||||
node_cls: type[IO.ComfyNode],
|
||||
node_cls: type[IO.ComfyNode] | None,
|
||||
status: str | int | None,
|
||||
elapsed_seconds: int,
|
||||
estimated_total: int | None = None,
|
||||
@ -481,7 +605,7 @@ def _display_time_progress(
|
||||
_display_text(node_cls, text, status=status, price=price)
|
||||
|
||||
|
||||
async def _diagnose_connectivity() -> dict[str, bool]:
|
||||
async def _diagnose_connectivity(base_url: str | None = None) -> dict[str, bool]:
|
||||
"""Best-effort connectivity diagnostics to distinguish local vs. server issues."""
|
||||
results = {
|
||||
"internet_accessible": False,
|
||||
@ -515,7 +639,7 @@ async def _diagnose_connectivity() -> dict[str, bool]:
|
||||
if not results["internet_accessible"]:
|
||||
return results
|
||||
|
||||
parsed = urlparse(default_base_url())
|
||||
parsed = urlparse(base_url or default_base_url())
|
||||
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
|
||||
with contextlib.suppress(ClientError, OSError):
|
||||
async with session.get(health_url) as resp:
|
||||
@ -604,10 +728,11 @@ def _snapshot_request_body_for_logging(
|
||||
|
||||
async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
"""Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors."""
|
||||
resolved_base_url = cfg.base_url or default_base_url()
|
||||
url = cfg.endpoint.path
|
||||
parsed_url = urlparse(url)
|
||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
|
||||
url = urljoin(resolved_base_url.rstrip("/") + "/", url.lstrip("/"))
|
||||
|
||||
method = cfg.endpoint.method
|
||||
params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)
|
||||
@ -645,7 +770,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
|
||||
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||
payload_headers.update(get_auth_header(cfg.node_cls))
|
||||
if cfg.auth_headers is not None:
|
||||
payload_headers.update(cfg.auth_headers)
|
||||
elif cfg.node_cls is not None:
|
||||
payload_headers.update(get_auth_header(cfg.node_cls))
|
||||
payload_headers["Comfy-Env"] = get_deploy_environment()
|
||||
if cfg.endpoint.headers:
|
||||
payload_headers.update(cfg.endpoint.headers)
|
||||
@ -726,6 +854,21 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
# Otherwise, request finished
|
||||
resp = await req_task
|
||||
async with resp:
|
||||
if cfg.allow_304 and resp.status == 304:
|
||||
resp_headers = {k.lower(): v for k, v in resp.headers.items()}
|
||||
if cfg.response_header_validator:
|
||||
cfg.response_header_validator(resp_headers)
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=resp_headers,
|
||||
response_content=None,
|
||||
)
|
||||
return None
|
||||
if resp.status >= 400:
|
||||
try:
|
||||
body = await resp.json()
|
||||
@ -737,12 +880,33 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
is_rl = resp.status == 429 or (
|
||||
cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
|
||||
)
|
||||
sleep_label = cfg.wait_label
|
||||
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
|
||||
rate_limit_attempts += 1
|
||||
wait_time = min(rate_limit_delay, 30.0)
|
||||
# Honor server-provided Retry-After when present
|
||||
# (clamped to keep a runaway header from blocking
|
||||
# the executor for hours), otherwise fall back to
|
||||
# the local exponential backoff.
|
||||
retry_after_s = _parse_retry_after(resp.headers.get("Retry-After"))
|
||||
if retry_after_s is not None:
|
||||
wait_time = min(retry_after_s, 300.0)
|
||||
else:
|
||||
wait_time = min(rate_limit_delay, 30.0)
|
||||
rate_limit_delay *= cfg.retry_backoff
|
||||
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
|
||||
should_retry = True
|
||||
# Let callers (e.g. RNP) render a friendlier
|
||||
# per-second label like
|
||||
# "Server busy, retrying in 30s..." while we
|
||||
# sleep — surfaced via send_progress_text by
|
||||
# _display_time_progress every second.
|
||||
if cfg.rate_limit_label is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
custom = cfg.rate_limit_label(
|
||||
resp.status, body, wait_time
|
||||
)
|
||||
if custom:
|
||||
sleep_label = custom
|
||||
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
|
||||
wait_time = delay
|
||||
delay *= cfg.retry_backoff
|
||||
@ -767,15 +931,45 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
response_content=body,
|
||||
error_message=f"HTTP {resp.status} ({retry_label}, will retry in {wait_time:.1f}s)",
|
||||
)
|
||||
# Stop the in-flight monitor so the per-second
|
||||
# progress label flips from cfg.wait_label
|
||||
# ("Waiting for server") to the rate-limit copy
|
||||
# ("Server busy, retrying in Ns...") for the
|
||||
# duration of this sleep — otherwise the two
|
||||
# writers race and the user sees alternating
|
||||
# text every tick.
|
||||
stop_event.set()
|
||||
if monitor_task:
|
||||
monitor_task.cancel()
|
||||
with contextlib.suppress(Exception):
|
||||
await monitor_task
|
||||
monitor_task = None
|
||||
await sleep_with_interrupt(
|
||||
wait_time,
|
||||
cfg.node_cls,
|
||||
cfg.wait_label if cfg.monitor_progress else None,
|
||||
sleep_label if cfg.monitor_progress else None,
|
||||
start_time if cfg.monitor_progress else None,
|
||||
cfg.estimated_total,
|
||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||
)
|
||||
continue
|
||||
# Retries either weren't applicable or have been exhausted —
|
||||
# give the caller's error_parser a chance to surface a
|
||||
# structured exception (e.g. RNP RnpProtocolError) before
|
||||
# we flatten the response with _friendly_http_message.
|
||||
if cfg.error_parser is not None:
|
||||
custom_exc = cfg.error_parser(resp.status, body)
|
||||
if custom_exc is not None:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=f"{type(custom_exc).__name__}: {custom_exc}",
|
||||
)
|
||||
raise custom_exc
|
||||
msg = _friendly_http_message(resp.status, body)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
@ -878,7 +1072,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
)
|
||||
delay *= cfg.retry_backoff
|
||||
continue
|
||||
diag = await _diagnose_connectivity()
|
||||
diag = await _diagnose_connectivity(resolved_base_url)
|
||||
if not diag["internet_accessible"]:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
@ -903,7 +1097,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
error_message=f"ApiServerError: {str(e)}",
|
||||
)
|
||||
raise ApiServerError(
|
||||
f"The API server at {default_base_url()} is currently unreachable. "
|
||||
f"The API server at {resolved_base_url} is currently unreachable. "
|
||||
f"The service may be experiencing issues."
|
||||
) from e
|
||||
finally:
|
||||
|
||||
Reference in New Issue
Block a user