Compare commits

..

1 Commits

Author SHA1 Message Date
b9abd21cb7 feat: make API node retry parameters configurable via environment variables
Adds COMFY_API_MAX_RETRIES, COMFY_API_RETRY_DELAY, and
COMFY_API_RETRY_BACKOFF environment variables that override the default
retry parameters for all API node HTTP requests (sync_op, sync_op_raw,
upload_file, download_url_to_bytesio).

Users in regions with unstable networks (e.g. behind the GFW in China)
can increase the retry budget to tolerate longer network interruptions:

  COMFY_API_MAX_RETRIES=10 COMFY_API_RETRY_DELAY=2.0 python main.py

Defaults remain unchanged (3 retries, 1.0s delay, 2.0x backoff) when
the env vars are not set.
2026-04-19 10:48:48 +00:00
18 changed files with 539 additions and 956 deletions

View File

@ -195,9 +195,7 @@ The portable above currently comes with python 3.13 and pytorch cuda 13.0. Updat
#### Alternative Downloads:
[Portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
[Experimental portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z)
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).

View File

@ -4,6 +4,9 @@ import math
import torch
import torchaudio
import comfy.model_management
import comfy.model_patcher
import comfy.utils as utils
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
@ -40,6 +43,30 @@ class AudioVAEComponentConfig:
return cls(autoencoder=audio_config, vocoder=vocoder_config)
class ModelDeviceManager:
"""Manages device placement and GPU residency for the composed model."""
def __init__(self, module: torch.nn.Module):
load_device = comfy.model_management.get_torch_device()
offload_device = comfy.model_management.vae_offload_device()
self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)
def ensure_model_loaded(self) -> None:
comfy.model_management.free_memory(
self.patcher.model_size(),
self.patcher.load_device,
)
comfy.model_management.load_model_gpu(self.patcher)
def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(self.patcher.load_device)
@property
def load_device(self):
return self.patcher.load_device
class AudioLatentNormalizer:
"""Applies per-channel statistics in patch space and restores original layout."""
@ -105,17 +132,23 @@ class AudioPreprocessor:
class AudioVAE(torch.nn.Module):
"""High-level Audio VAE wrapper exposing encode and decode entry points."""
def __init__(self, metadata: dict):
def __init__(self, state_dict: dict, metadata: dict):
super().__init__()
component_config = AudioVAEComponentConfig.from_metadata(metadata)
vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True)
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
if "bwe" in component_config.vocoder:
self.vocoder = VocoderWithBWE(config=component_config.vocoder)
else:
self.vocoder = Vocoder(config=component_config.vocoder)
self.autoencoder.load_state_dict(vae_sd, strict=False)
self.vocoder.load_state_dict(vocoder_sd, strict=False)
autoencoder_config = self.autoencoder.get_config()
self.normalizer = AudioLatentNormalizer(
AudioPatchifier(
@ -135,12 +168,18 @@ class AudioVAE(torch.nn.Module):
n_fft=autoencoder_config["n_fft"],
)
def encode(self, audio, sample_rate=44100) -> torch.Tensor:
self.device_manager = ModelDeviceManager(self)
def encode(self, audio: dict) -> torch.Tensor:
"""Encode a waveform dictionary into normalized latent tensors."""
waveform = audio
waveform_sample_rate = sample_rate
waveform = audio["waveform"]
waveform_sample_rate = audio["sample_rate"]
input_device = waveform.device
# Ensure that Audio VAE is loaded on the correct device.
self.device_manager.ensure_model_loaded()
waveform = self.device_manager.move_to_load_device(waveform)
expected_channels = self.autoencoder.encoder.in_channels
if waveform.shape[1] != expected_channels:
if waveform.shape[1] == 1:
@ -151,7 +190,7 @@ class AudioVAE(torch.nn.Module):
)
mel_spec = self.preprocessor.waveform_to_mel(
waveform, waveform_sample_rate, device=waveform.device
waveform, waveform_sample_rate, device=self.device_manager.load_device
)
latents = self.autoencoder.encode(mel_spec)
@ -165,13 +204,17 @@ class AudioVAE(torch.nn.Module):
"""Decode normalized latent tensors into an audio waveform."""
original_shape = latents.shape
# Ensure that Audio VAE is loaded on the correct device.
self.device_manager.ensure_model_loaded()
latents = self.device_manager.move_to_load_device(latents)
latents = self.normalizer.denormalize(latents)
target_shape = self.target_shape_from_latents(original_shape)
mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)
waveform = self.run_vocoder(mel_spec)
return waveform
return self.device_manager.move_to_load_device(waveform)
def target_shape_from_latents(self, latents_shape):
batch, _, time, _ = latents_shape

View File

@ -12,7 +12,6 @@ from .ldm.cascade.stage_c_coder import StageC_coder
from .ldm.audio.autoencoder import AudioOobleckVAE
import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.lightricks.vae.audio_vae
import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2
@ -806,24 +805,6 @@ class VAE:
self.downscale_index_formula = (4, 8, 8)
self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype))
self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
elif "vocoder.resblocks.0.convs1.0.weight" in sd or "vocoder.vocoder.resblocks.0.convs1.0.weight" in sd: # LTX Audio
sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder."})
self.first_stage_model = comfy.ldm.lightricks.vae.audio_vae.AudioVAE(metadata=metadata)
self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
self.latent_channels = self.first_stage_model.latent_channels
self.audio_sample_rate_output = self.first_stage_model.output_sample_rate
self.autoencoder = self.first_stage_model.autoencoder # TODO: remove hack for ltxv custom nodes
self.output_channels = 2
self.pad_channel_value = "replicate"
self.upscale_ratio = 4096
self.downscale_ratio = 4096
self.latent_dim = 2
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float32]
self.disable_offload = True
self.extra_1d_channel = 16
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None

View File

@ -158,17 +158,10 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [
("Custom", None, None),
]
# Seedance 2.0 reference video pixel count limits per model and output resolution.
# Seedance 2.0 reference video pixel count limits per model.
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
"dreamina-seedance-2-0-260128": {
"480p": {"min": 409_600, "max": 927_408},
"720p": {"min": 409_600, "max": 927_408},
"1080p": {"min": 409_600, "max": 2_073_600},
},
"dreamina-seedance-2-0-fast-260128": {
"480p": {"min": 409_600, "max": 927_408},
"720p": {"min": 409_600, "max": 927_408},
},
"dreamina-seedance-2-0-260128": {"min": 409_600, "max": 927_408},
"dreamina-seedance-2-0-fast-260128": {"min": 409_600, "max": 927_408},
}
# The time in this dictionary are given for 10 seconds duration.

View File

@ -35,7 +35,6 @@ from comfy_api_nodes.util import (
get_number_of_images,
image_tensor_pair_to_batch,
poll_op,
resize_video_to_pixel_budget,
sync_op,
upload_audio_to_comfyapi,
upload_image_to_comfyapi,
@ -70,12 +69,9 @@ DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-2504
logger = logging.getLogger(__name__)
def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: str, index: int) -> None:
"""Validate reference video pixel count against Seedance 2.0 model limits for the selected resolution."""
model_limits = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id)
if not model_limits:
return
limits = model_limits.get(resolution)
def _validate_ref_video_pixels(video: Input.Video, model_id: str, index: int) -> None:
"""Validate reference video pixel count against Seedance 2.0 model limits."""
limits = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id)
if not limits:
return
try:
@ -1377,14 +1373,6 @@ def _seedance2_reference_inputs(resolutions: list[str]):
min=0,
),
),
IO.Boolean.Input(
"auto_downscale",
default=False,
advanced=True,
optional=True,
tooltip="Automatically downscale reference videos that exceed the model's pixel budget "
"for the selected resolution. Aspect ratio is preserved; videos already within limits are untouched.",
),
]
@ -1492,23 +1480,10 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
model_id = SEEDANCE_MODELS[model["model"]]
has_video_input = len(reference_videos) > 0
if model.get("auto_downscale") and reference_videos:
max_px = (
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {})
.get(model["resolution"], {})
.get("max")
)
if max_px:
for key in reference_videos:
reference_videos[key] = resize_video_to_pixel_budget(
reference_videos[key], max_px
)
total_video_duration = 0.0
for i, key in enumerate(reference_videos, 1):
video = reference_videos[key]
_validate_ref_video_pixels(video, model_id, model["resolution"], i)
_validate_ref_video_pixels(video, model_id, i)
try:
dur = video.get_duration()
if dur < 1.8:

View File

@ -357,18 +357,13 @@ def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) ->
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0
def calculate_tokens_price_image_2(response: OpenAIImageGenerationResponse) -> float | None:
# https://platform.openai.com/docs/pricing - gpt-image-2: input $8/1M, output $30/1M
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 30.0)) / 1_000_000.0
class OpenAIGPTImage1(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="OpenAIGPTImage1",
display_name="OpenAI GPT Image 1 & 1.5",
display_name="OpenAI GPT Image 1.5",
category="api node/image/OpenAI",
description="Generates images synchronously via OpenAI's GPT Image endpoint.",
inputs=[
@ -447,22 +442,14 @@ class OpenAIGPTImage1(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["quality", "n", "model"]),
depends_on=IO.PriceBadgeDepends(widgets=["quality", "n"]),
expr="""
(
$m := widgets.model;
$ranges :=
$contains($m, "gpt-image-1.5")
? {
"low": [0.009, 0.016],
"medium": [0.037, 0.056],
"high": [0.134, 0.240]
}
: {
"low": [0.011, 0.020],
"medium": [0.046, 0.070],
"high": [0.167, 0.300]
};
$ranges := {
"low": [0.011, 0.02],
"medium": [0.046, 0.07],
"high": [0.167, 0.3]
};
$range := $lookup($ranges, widgets.quality);
$n := widgets.n;
($n = 1)
@ -577,261 +564,6 @@ class OpenAIGPTImage1(IO.ComfyNode):
return IO.NodeOutput(await validate_and_cast_response(response))
_GPT_IMAGE_2_SIZES = [
"auto",
"1024x1024",
"1536x1024",
"1024x1536",
"2048x2048",
"2048x1152",
"3840x2160",
"2160x3840",
]
def _resolve_gpt_image_2_size(size: str, custom_width: int, custom_height: int) -> str:
if custom_width <= 0 or custom_height <= 0:
return size
w, h = custom_width, custom_height
if max(w, h) > 3840:
raise ValueError(f"Maximum edge length must be ≤ 3840px, got {max(w, h)}")
if w % 16 != 0 or h % 16 != 0:
raise ValueError(f"Both edges must be multiples of 16px, got {w}x{h}")
if max(w, h) / min(w, h) > 3:
raise ValueError(f"Long-to-short edge ratio must not exceed 3:1, got {max(w, h) / min(w, h):.2f}:1")
total = w * h
if total < 655_360 or total > 8_294_400:
raise ValueError(f"Total pixels must be between 655,360 and 8,294,400, got {total:,}")
return f"{w}x{h}"
class OpenAIGPTImage2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="OpenAIGPTImage2",
display_name="OpenAI GPT Image 2",
category="api node/image/OpenAI",
description="Generates images synchronously via OpenAI's GPT-Image-2 endpoint.",
inputs=[
IO.String.Input(
"prompt",
default="",
multiline=True,
tooltip="Text prompt for GPT Image 2",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2**31 - 1,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="not implemented yet in backend",
optional=True,
),
IO.Combo.Input(
"quality",
default="auto",
options=["auto", "low", "medium", "high"],
tooltip="Image quality. 'auto' lets the model decide based on the prompt. Square images are typically fastest.",
optional=True,
),
IO.Combo.Input(
"background",
default="auto",
options=["auto", "opaque"],
tooltip="Background style. GPT-Image-2 does not support transparent backgrounds.",
optional=True,
),
IO.Combo.Input(
"size",
default="auto",
options=_GPT_IMAGE_2_SIZES,
tooltip="Output image dimensions. Ignored when custom_width and custom_height are both non-zero.",
optional=True,
),
IO.Int.Input(
"custom_width",
default=0,
min=0,
max=3840,
step=16,
display_mode=IO.NumberDisplay.number,
tooltip="Custom output width in pixels. Set to 0 (default) to use the size preset. When both width and height are non-zero, they override the size preset. Slider enforces multiples of 16 and max edge 3840px. Additional constraints checked at generation: ratio ≤ 3:1, total pixels 655,3608,294,400.",
optional=True,
),
IO.Int.Input(
"custom_height",
default=0,
min=0,
max=3840,
step=16,
display_mode=IO.NumberDisplay.number,
tooltip="Custom output height in pixels. Set to 0 (default) to use the size preset. When both width and height are non-zero, they override the size preset. Slider enforces multiples of 16 and max edge 3840px. Additional constraints checked at generation: ratio ≤ 3:1, total pixels 655,3608,294,400.",
optional=True,
),
IO.Int.Input(
"num_images",
default=1,
min=1,
max=8,
step=1,
tooltip="Number of images to generate per run.",
display_mode=IO.NumberDisplay.number,
optional=True,
),
IO.Image.Input(
"image",
tooltip="Optional reference image for image editing.",
optional=True,
),
IO.Mask.Input(
"mask",
tooltip="Optional mask for inpainting (white areas will be replaced).",
optional=True,
),
IO.Combo.Input(
"model",
options=["gpt-image-2"],
default="gpt-image-2",
tooltip="Model used for image generation.",
optional=True,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["quality", "num_images"]),
expr="""
(
$ranges := {
"low": [0.005, 0.010],
"medium": [0.041, 0.060],
"high": [0.165, 0.250]
};
$q := widgets.quality;
$n := widgets.num_images;
$n := ($n != null and $n != 0) ? $n : 1;
$range := $lookup($ranges, $q);
$lo := $range ? $range[0] : 0.005;
$hi := $range ? $range[1] : 0.250;
($n = 1)
? {"type":"range_usd","min_usd": $lo, "max_usd": $hi, "format": {"approximate": ($range ? false : true)}}
: {
"type":"range_usd",
"min_usd": $lo,
"max_usd": $hi,
"format": {"approximate": ($range ? false : true), "suffix": " x " & $string($n) & "/Run"}
}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
seed: int = 0,
quality: str = "auto",
background: str = "auto",
image: Input.Image | None = None,
mask: Input.Image | None = None,
num_images: int = 1,
size: str = "auto",
custom_width: int = 0,
custom_height: int = 0,
model: str = "gpt-image-2",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
if mask is not None and image is None:
raise ValueError("Cannot use a mask without an input image")
resolved_size = _resolve_gpt_image_2_size(size, custom_width, custom_height)
if image is not None:
files = []
batch_size = image.shape[0]
for i in range(batch_size):
single_image = image[i : i + 1]
scaled_image = downscale_image_tensor(single_image, total_pixels=2048 * 2048).squeeze()
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
img_byte_arr = BytesIO()
img.save(img_byte_arr, format="PNG")
img_byte_arr.seek(0)
if batch_size == 1:
files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png")))
else:
files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
if mask is not None:
if image.shape[0] != 1:
raise Exception("Cannot use a mask with multiple image")
if mask.shape[1:] != image.shape[1:-1]:
raise Exception("Mask and Image must be the same size")
_, height, width = mask.shape
rgba_mask = torch.zeros(height, width, 4, device="cpu")
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048 * 2048).squeeze()
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
mask_img = Image.fromarray(mask_np)
mask_img_byte_arr = BytesIO()
mask_img.save(mask_img_byte_arr, format="PNG")
mask_img_byte_arr.seek(0)
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/openai/images/edits", method="POST"),
response_model=OpenAIImageGenerationResponse,
data=OpenAIImageEditRequest(
model=model,
prompt=prompt,
quality=quality,
background=background,
n=num_images,
size=resolved_size,
moderation="low",
),
content_type="multipart/form-data",
files=files,
price_extractor=calculate_tokens_price_image_2,
)
else:
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/openai/images/generations", method="POST"),
response_model=OpenAIImageGenerationResponse,
data=OpenAIImageGenerationRequest(
model=model,
prompt=prompt,
quality=quality,
background=background,
n=num_images,
size=resolved_size,
moderation="low",
),
price_extractor=calculate_tokens_price_image_2,
)
return IO.NodeOutput(await validate_and_cast_response(response))
class OpenAIChatNode(IO.ComfyNode):
"""
Node to generate text responses from an OpenAI model.
@ -1181,7 +913,6 @@ class OpenAIExtension(ComfyExtension):
OpenAIDalle2,
OpenAIDalle3,
OpenAIGPTImage1,
OpenAIGPTImage2,
OpenAIChatNode,
OpenAIInputFiles,
OpenAIChatConfig,

View File

@ -24,9 +24,8 @@ from comfy_api_nodes.util import (
AVERAGE_DURATION_VIDEO_GEN = 32
MODELS_MAP = {
"veo-2.0-generate-001": "veo-2.0-generate-001",
"veo-3.1-generate": "veo-3.1-generate-001",
"veo-3.1-fast-generate": "veo-3.1-fast-generate-001",
"veo-3.1-lite": "veo-3.1-lite-generate-001",
"veo-3.1-generate": "veo-3.1-generate-preview",
"veo-3.1-fast-generate": "veo-3.1-fast-generate-preview",
"veo-3.0-generate-001": "veo-3.0-generate-001",
"veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001",
}
@ -248,8 +247,17 @@ class VeoVideoGenerationNode(IO.ComfyNode):
raise Exception("Video generation completed but no video was returned")
class Veo3VideoGenerationNode(IO.ComfyNode):
"""Generates videos from text prompts using Google's Veo 3 API."""
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
"""
Generates videos from text prompts using Google's Veo 3 API.
Supported models:
- veo-3.0-generate-001
- veo-3.0-fast-generate-001
This node extends the base Veo node with Veo 3 specific features including
audio generation and fixed 8-second duration.
"""
@classmethod
def define_schema(cls):
@ -271,13 +279,6 @@ class Veo3VideoGenerationNode(IO.ComfyNode):
default="16:9",
tooltip="Aspect ratio of the output video",
),
IO.Combo.Input(
"resolution",
options=["720p", "1080p", "4k"],
default="720p",
tooltip="Output video resolution. 4K is not available for veo-3.1-lite and veo-3.0 models.",
optional=True,
),
IO.String.Input(
"negative_prompt",
multiline=True,
@ -288,11 +289,11 @@ class Veo3VideoGenerationNode(IO.ComfyNode):
IO.Int.Input(
"duration_seconds",
default=8,
min=4,
min=8,
max=8,
step=2,
step=1,
display_mode=IO.NumberDisplay.number,
tooltip="Duration of the output video in seconds",
tooltip="Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
optional=True,
),
IO.Boolean.Input(
@ -331,10 +332,10 @@ class Veo3VideoGenerationNode(IO.ComfyNode):
options=[
"veo-3.1-generate",
"veo-3.1-fast-generate",
"veo-3.1-lite",
"veo-3.0-generate-001",
"veo-3.0-fast-generate-001",
],
default="veo-3.0-generate-001",
tooltip="Veo 3 model to use for video generation",
optional=True,
),
@ -355,111 +356,21 @@ class Veo3VideoGenerationNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "resolution", "duration_seconds"]),
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio"]),
expr="""
(
$m := widgets.model;
$r := widgets.resolution;
$a := widgets.generate_audio;
$seconds := widgets.duration_seconds;
$pps :=
$contains($m, "lite")
? ($r = "1080p" ? ($a ? 0.08 : 0.05) : ($a ? 0.05 : 0.03))
: $contains($m, "3.1-fast")
? ($r = "4k" ? ($a ? 0.30 : 0.25) : $r = "1080p" ? ($a ? 0.12 : 0.10) : ($a ? 0.10 : 0.08))
: $contains($m, "3.1-generate")
? ($r = "4k" ? ($a ? 0.60 : 0.40) : ($a ? 0.40 : 0.20))
: $contains($m, "3.0-fast")
? ($a ? 0.15 : 0.10)
: ($a ? 0.40 : 0.20);
{"type":"usd","usd": $pps * $seconds}
($contains($m,"veo-3.0-fast-generate-001") or $contains($m,"veo-3.1-fast-generate"))
? {"type":"usd","usd": ($a ? 1.2 : 0.8)}
: ($contains($m,"veo-3.0-generate-001") or $contains($m,"veo-3.1-generate"))
? {"type":"usd","usd": ($a ? 3.2 : 1.6)}
: {"type":"range_usd","min_usd":0.8,"max_usd":3.2}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt,
aspect_ratio="16:9",
resolution="720p",
negative_prompt="",
duration_seconds=8,
enhance_prompt=True,
person_generation="ALLOW",
seed=0,
image=None,
model="veo-3.0-generate-001",
generate_audio=False,
):
if "lite" in model and resolution == "4k":
raise Exception("4K resolution is not supported by the veo-3.1-lite model.")
model = MODELS_MAP[model]
instances = [{"prompt": prompt}]
if image is not None:
image_base64 = tensor_to_base64_string(image)
if image_base64:
instances[0]["image"] = {"bytesBase64Encoded": image_base64, "mimeType": "image/png"}
parameters = {
"aspectRatio": aspect_ratio,
"personGeneration": person_generation,
"durationSeconds": duration_seconds,
"enhancePrompt": True,
"generateAudio": generate_audio,
}
if negative_prompt:
parameters["negativePrompt"] = negative_prompt
if seed > 0:
parameters["seed"] = seed
if "veo-3.1" in model:
parameters["resolution"] = resolution
initial_response = await sync_op(
cls,
ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"),
response_model=VeoGenVidResponse,
data=VeoGenVidRequest(
instances=instances,
parameters=parameters,
),
)
poll_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
response_model=VeoGenVidPollResponse,
status_extractor=lambda r: "completed" if r.done else "pending",
data=VeoGenVidPollRequest(operationName=initial_response.name),
poll_interval=9.0,
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
)
if poll_response.error:
raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})")
response = poll_response.response
filtered_count = response.raiMediaFilteredCount
if filtered_count:
reasons = response.raiMediaFilteredReasons or []
reason_part = f": {reasons[0]}" if reasons else ""
raise Exception(
f"Content blocked by Google's Responsible AI filters{reason_part} "
f"({filtered_count} video{'s' if filtered_count != 1 else ''} filtered)."
)
if response.videos:
video = response.videos[0]
if video.bytesBase64Encoded:
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
if video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
raise Exception("Video returned but no data or URL was provided")
raise Exception("Video generation completed but no video was returned")
class Veo3FirstLastFrameNode(IO.ComfyNode):
@ -483,7 +394,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
default="",
tooltip="Negative text prompt to guide what to avoid in the video",
),
IO.Combo.Input("resolution", options=["720p", "1080p", "4k"]),
IO.Combo.Input("resolution", options=["720p", "1080p"]),
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "9:16"],
@ -513,7 +424,8 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
IO.Image.Input("last_frame", tooltip="End frame"),
IO.Combo.Input(
"model",
options=["veo-3.1-generate", "veo-3.1-fast-generate", "veo-3.1-lite"],
options=["veo-3.1-generate", "veo-3.1-fast-generate"],
default="veo-3.1-fast-generate",
),
IO.Boolean.Input(
"generate_audio",
@ -531,20 +443,26 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration", "resolution"]),
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration"]),
expr="""
(
$prices := {
"veo-3.1-fast-generate": { "audio": 0.15, "no_audio": 0.10 },
"veo-3.1-generate": { "audio": 0.40, "no_audio": 0.20 }
};
$m := widgets.model;
$r := widgets.resolution;
$ga := widgets.generate_audio;
$ga := (widgets.generate_audio = "true");
$seconds := widgets.duration;
$pps :=
$contains($m, "lite")
? ($r = "1080p" ? ($ga ? 0.08 : 0.05) : ($ga ? 0.05 : 0.03))
: $contains($m, "fast")
? ($r = "4k" ? ($ga ? 0.30 : 0.25) : $r = "1080p" ? ($ga ? 0.12 : 0.10) : ($ga ? 0.10 : 0.08))
: ($r = "4k" ? ($ga ? 0.60 : 0.40) : ($ga ? 0.40 : 0.20));
{"type":"usd","usd": $pps * $seconds}
$modelKey :=
$contains($m, "veo-3.1-fast-generate") ? "veo-3.1-fast-generate" :
$contains($m, "veo-3.1-generate") ? "veo-3.1-generate" :
"";
$audioKey := $ga ? "audio" : "no_audio";
$modelPrices := $lookup($prices, $modelKey);
$pps := $lookup($modelPrices, $audioKey);
($pps != null)
? {"type":"usd","usd": $pps * $seconds}
: {"type":"range_usd","min_usd": 0.4, "max_usd": 3.2}
)
""",
),
@ -564,9 +482,6 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
model: str,
generate_audio: bool,
):
if "lite" in model and resolution == "4k":
raise Exception("4K resolution is not supported by the veo-3.1-lite model.")
model = MODELS_MAP[model]
initial_response = await sync_op(
cls,
@ -604,7 +519,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
data=VeoGenVidPollRequest(
operationName=initial_response.name,
),
poll_interval=9.0,
poll_interval=5.0,
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
)

View File

@ -19,7 +19,6 @@ from .conversions import (
image_tensor_pair_to_batch,
pil_to_bytesio,
resize_mask_to_image,
resize_video_to_pixel_budget,
tensor_to_base64_string,
tensor_to_bytesio,
tensor_to_pil,
@ -91,7 +90,6 @@ __all__ = [
"image_tensor_pair_to_batch",
"pil_to_bytesio",
"resize_mask_to_image",
"resize_video_to_pixel_budget",
"tensor_to_base64_string",
"tensor_to_bytesio",
"tensor_to_pil",

View File

@ -2,6 +2,7 @@ import asyncio
import contextlib
import json
import logging
import os
import time
import uuid
from collections.abc import Callable, Iterable
@ -32,6 +33,30 @@ from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInte
M = TypeVar("M", bound=BaseModel)
def _env_int(key: str, default: int) -> int:
try:
return int(os.environ[key])
except (KeyError, ValueError):
return default
def _env_float(key: str, default: float) -> float:
try:
return float(os.environ[key])
except (KeyError, ValueError):
return default
@dataclass(frozen=True)
class _RetryDefaults:
max_retries: int = _env_int("COMFY_API_MAX_RETRIES", 3)
retry_delay: float = _env_float("COMFY_API_RETRY_DELAY", 1.0)
retry_backoff: float = _env_float("COMFY_API_RETRY_BACKOFF", 2.0)
RETRY_DEFAULTS = _RetryDefaults()
class ApiEndpoint:
def __init__(
self,
@ -78,11 +103,21 @@ class _PollUIState:
price: float | None = None
estimated_duration: int | None = None
base_processing_elapsed: float = 0.0 # sum of completed active intervals
active_since: float | None = None # start time of current active interval (None if queued)
active_since: float | None = (
None # start time of current active interval (None if queued)
)
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
COMPLETED_STATUSES = [
"succeeded",
"succeed",
"success",
"completed",
"finished",
"done",
"complete",
]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]
@ -98,9 +133,9 @@ async def sync_op(
content_type: str = "application/json",
timeout: float = 3600.0,
multipart_parser: Callable | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
max_retries: int = RETRY_DEFAULTS.max_retries,
retry_delay: float = RETRY_DEFAULTS.retry_delay,
retry_backoff: float = RETRY_DEFAULTS.retry_backoff,
wait_label: str = "Waiting for server",
estimated_duration: int | None = None,
final_label_on_success: str | None = "Completed",
@ -131,7 +166,9 @@ async def sync_op(
is_rate_limited=is_rate_limited,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
raise Exception(
"Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)."
)
return _validate_or_raise(response_model, raw)
@ -178,7 +215,9 @@ async def poll_op(
cancel_timeout=cancel_timeout,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
raise Exception(
"Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text)."
)
return _validate_or_raise(response_model, raw)
@ -192,9 +231,9 @@ async def sync_op_raw(
content_type: str = "application/json",
timeout: float = 3600.0,
multipart_parser: Callable | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
max_retries: int = RETRY_DEFAULTS.max_retries,
retry_delay: float = RETRY_DEFAULTS.retry_delay,
retry_backoff: float = RETRY_DEFAULTS.retry_backoff,
wait_label: str = "Waiting for server",
estimated_duration: int | None = None,
as_binary: bool = False,
@ -269,9 +308,15 @@ async def poll_op_raw(
Returns the final JSON response from the poll endpoint.
"""
completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses)
queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses)
completed_states = _normalize_statuses(
COMPLETED_STATUSES if completed_statuses is None else completed_statuses
)
failed_states = _normalize_statuses(
FAILED_STATUSES if failed_statuses is None else failed_statuses
)
queued_states = _normalize_statuses(
QUEUED_STATUSES if queued_statuses is None else queued_statuses
)
started = time.monotonic()
consumed_attempts = 0 # counts only non-queued polls
@ -289,7 +334,9 @@ async def poll_op_raw(
break
now = time.monotonic()
proc_elapsed = state.base_processing_elapsed + (
(now - state.active_since) if state.active_since is not None else 0.0
(now - state.active_since)
if state.active_since is not None
else 0.0
)
_display_time_progress(
cls,
@ -361,11 +408,15 @@ async def poll_op_raw(
is_queued = status in queued_states
if is_queued:
if state.active_since is not None: # If we just moved from active -> queued, close the active interval
if (
state.active_since is not None
): # If we just moved from active -> queued, close the active interval
state.base_processing_elapsed += now_ts - state.active_since
state.active_since = None
else:
if state.active_since is None: # If we just moved from queued -> active, open a new active interval
if (
state.active_since is None
): # If we just moved from queued -> active, open a new active interval
state.active_since = now_ts
state.is_queued = is_queued
@ -442,7 +493,9 @@ def _display_text(
) -> None:
display_lines: list[str] = []
if status:
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
display_lines.append(
f"Status: {status.capitalize() if isinstance(status, str) else status}"
)
if price is not None:
p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".")
if p != "0":
@ -450,7 +503,9 @@ def _display_text(
if text is not None:
display_lines.append(text)
if display_lines:
PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls))
PromptServer.instance.send_progress_text(
"\n".join(display_lines), get_node_id(node_cls)
)
def _display_time_progress(
@ -464,7 +519,11 @@ def _display_time_progress(
processing_elapsed_seconds: int | None = None,
) -> None:
if estimated_total is not None and estimated_total > 0 and is_queued is False:
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
pe = (
processing_elapsed_seconds
if processing_elapsed_seconds is not None
else elapsed_seconds
)
remaining = max(0, int(estimated_total) - int(pe))
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
else:
@ -503,7 +562,9 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
raise ValueError("files tuple must be (filename, file[, content_type])")
def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]:
def _merge_params(
endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None
) -> dict[str, Any]:
params = dict(endpoint_params or {})
if method.upper() == "GET" and data:
for k, v in data.items():
@ -566,8 +627,14 @@ def _snapshot_request_body_for_logging(
filename = file_obj[0]
else:
filename = getattr(file_obj, "name", field_name)
file_fields.append({"field": field_name, "filename": str(filename or "")})
return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields}
file_fields.append(
{"field": field_name, "filename": str(filename or "")}
)
return {
"_multipart": True,
"form_fields": form_fields,
"file_fields": file_fields,
}
if content_type == "application/x-www-form-urlencoded":
return data or {}
return data or {}
@ -581,7 +648,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
method = cfg.endpoint.method
params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)
params = _merge_params(
cfg.endpoint.query_params, method, cfg.data if method == "GET" else None
)
async def _monitor(stop_evt: asyncio.Event, start_ts: float):
"""Every second: update elapsed time and signal interruption."""
@ -591,13 +660,20 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
return
if cfg.monitor_progress:
_display_time_progress(
cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total
cfg.node_cls,
cfg.wait_label,
int(time.monotonic() - start_ts),
cfg.estimated_total,
)
await asyncio.sleep(1.0)
except asyncio.CancelledError:
return # normal shutdown
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
start_time = (
cfg.progress_origin_ts
if cfg.progress_origin_ts is not None
else time.monotonic()
)
attempt = 0
delay = cfg.retry_delay
rate_limit_attempts = 0
@ -614,7 +690,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
payload_headers = (
{"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
)
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
payload_headers.update(get_auth_header(cfg.node_cls))
if cfg.endpoint.headers:
@ -623,7 +701,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload_kw: dict[str, Any] = {"headers": payload_headers}
if method == "GET":
payload_headers.pop("Content-Type", None)
request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files)
request_body_log = _snapshot_request_body_for_logging(
cfg.content_type, method, cfg.data, cfg.files
)
try:
if cfg.monitor_progress:
monitor_task = asyncio.create_task(_monitor(stop_event, start_time))
@ -637,16 +717,23 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
if cfg.multipart_parser and cfg.data:
form = cfg.multipart_parser(cfg.data)
if not isinstance(form, aiohttp.FormData):
raise ValueError("multipart_parser must return aiohttp.FormData")
raise ValueError(
"multipart_parser must return aiohttp.FormData"
)
else:
form = aiohttp.FormData(default_to_multipart=True)
if cfg.data:
for k, v in cfg.data.items():
if v is None:
continue
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
form.add_field(
k,
str(v) if not isinstance(v, (bytes, bytearray)) else v,
)
if cfg.files:
file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items()
file_iter = (
cfg.files if isinstance(cfg.files, list) else cfg.files.items()
)
for field_name, file_obj in file_iter:
if file_obj is None:
continue
@ -660,9 +747,17 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
if isinstance(file_value, BytesIO):
with contextlib.suppress(Exception):
file_value.seek(0)
form.add_field(field_name, file_value, filename=filename, content_type=content_type)
form.add_field(
field_name,
file_value,
filename=filename,
content_type=content_type,
)
payload_kw["data"] = form
elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET":
elif (
cfg.content_type == "application/x-www-form-urlencoded"
and method != "GET"
):
payload_headers["Content-Type"] = "application/x-www-form-urlencoded"
payload_kw["data"] = cfg.data or {}
elif method != "GET":
@ -685,7 +780,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
tasks = {req_task}
if monitor_task:
tasks.add(monitor_task)
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
done, pending = await asyncio.wait(
tasks, return_when=asyncio.FIRST_COMPLETED
)
if monitor_task and monitor_task in done:
# Interrupted cancel the request and abort
@ -705,7 +802,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
wait_time = 0.0
retry_label = ""
is_rl = resp.status == 429 or (
cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
cfg.is_rate_limited is not None
and cfg.is_rate_limited(resp.status, body)
)
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
rate_limit_attempts += 1
@ -713,7 +811,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
rate_limit_delay *= cfg.retry_backoff
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
should_retry = True
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
elif (
resp.status in _RETRY_STATUS
and (attempt - rate_limit_attempts) <= cfg.max_retries
):
wait_time = delay
delay *= cfg.retry_backoff
retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
@ -743,7 +844,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None,
display_callback=_display_time_progress
if cfg.monitor_progress
else None,
)
continue
msg = _friendly_http_message(resp.status, body)
@ -770,7 +873,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
raise ProcessingInterrupted("Task cancelled")
if cfg.monitor_progress:
_display_time_progress(
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
cfg.node_cls,
cfg.wait_label,
int(now - start_time),
cfg.estimated_total,
)
bytes_payload = bytes(buff)
resp_headers = {k.lower(): v for k, v in resp.headers.items()}
@ -800,9 +906,15 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload = json.loads(text) if text else {}
except json.JSONDecodeError:
payload = {"_raw": text}
response_content_to_log = payload if isinstance(payload, dict) else text
response_content_to_log = (
payload if isinstance(payload, dict) else text
)
with contextlib.suppress(Exception):
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
extracted_price = (
cfg.price_extractor(payload)
if cfg.price_extractor
else None
)
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
request_logger.log_request_response(
@ -844,7 +956,9 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None,
display_callback=_display_time_progress
if cfg.monitor_progress
else None,
)
delay *= cfg.retry_backoff
continue
@ -885,7 +999,11 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
if sess:
with contextlib.suppress(Exception):
await sess.close()
if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success:
if (
operation_succeeded
and cfg.monitor_progress
and cfg.final_label_on_success
):
_display_time_progress(
cfg.node_cls,
status=cfg.final_label_on_success,

View File

@ -129,38 +129,22 @@ def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
return img_byte_arr
def _compute_downscale_dims(src_w: int, src_h: int, total_pixels: int) -> tuple[int, int] | None:
"""Return downscaled (w, h) with even dims fitting ``total_pixels``, or None if already fits.
Source aspect ratio is preserved; output may drift by a fraction of a percent because both dimensions
are rounded down to even values (many codecs require divisible-by-2).
"""
pixels = src_w * src_h
if pixels <= total_pixels:
return None
scale = math.sqrt(total_pixels / pixels)
new_w = max(2, int(src_w * scale))
new_h = max(2, int(src_h * scale))
new_w -= new_w % 2
new_h -= new_h % 2
return new_w, new_h
def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor:
"""Downscale input image tensor to roughly the specified total pixels.
Output dimensions are rounded down to even values so that the result is guaranteed to fit within ``total_pixels``
and is compatible with codecs that require even dimensions (e.g. yuv420p).
"""
"""Downscale input image tensor to roughly the specified total pixels."""
samples = image.movedim(-1, 1)
dims = _compute_downscale_dims(samples.shape[3], samples.shape[2], int(total_pixels))
if dims is None:
total = int(total_pixels)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
if scale_by >= 1:
return image
new_w, new_h = dims
return common_upscale(samples, new_w, new_h, "lanczos", "disabled").movedim(1, -1)
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = common_upscale(samples, width, height, "lanczos", "disabled")
s = s.movedim(1, -1)
return s
def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) -> torch.Tensor:
def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) -> torch.Tensor:
"""Downscale input image tensor so the largest dimension is at most max_side pixels."""
samples = image.movedim(-1, 1)
height, width = samples.shape[2], samples.shape[3]
@ -415,72 +399,6 @@ def trim_video(video: Input.Video, duration_sec: float) -> Input.Video:
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
def resize_video_to_pixel_budget(video: Input.Video, total_pixels: int) -> Input.Video:
"""Downscale a video to fit within ``total_pixels`` (w * h), preserving aspect ratio.
Returns the original video object untouched when it already fits. Preserves frame rate, duration, and audio.
Aspect ratio is preserved up to a fraction of a percent (even-dim rounding).
"""
src_w, src_h = video.get_dimensions()
scale_dims = _compute_downscale_dims(src_w, src_h, total_pixels)
if scale_dims is None:
return video
return _apply_video_scale(video, scale_dims)
def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input.Video:
"""Re-encode ``video`` scaled to ``scale_dims`` with a single decode/encode pass."""
out_w, out_h = scale_dims
output_buffer = BytesIO()
input_container = None
output_container = None
try:
input_source = video.get_stream_source()
input_container = av.open(input_source, mode="r")
output_container = av.open(output_buffer, mode="w", format="mp4")
video_stream = output_container.add_stream("h264", rate=video.get_frame_rate())
video_stream.width = out_w
video_stream.height = out_h
video_stream.pix_fmt = "yuv420p"
audio_stream = None
for stream in input_container.streams:
if isinstance(stream, av.AudioStream):
audio_stream = output_container.add_stream("aac", rate=stream.sample_rate)
audio_stream.sample_rate = stream.sample_rate
audio_stream.layout = stream.layout
break
for frame in input_container.decode(video=0):
frame = frame.reformat(width=out_w, height=out_h, format="yuv420p")
for packet in video_stream.encode(frame):
output_container.mux(packet)
for packet in video_stream.encode():
output_container.mux(packet)
if audio_stream is not None:
input_container.seek(0)
for audio_frame in input_container.decode(audio=0):
for packet in audio_stream.encode(audio_frame):
output_container.mux(packet)
for packet in audio_stream.encode():
output_container.mux(packet)
output_container.close()
input_container.close()
output_buffer.seek(0)
return InputImpl.VideoFromFile(output_buffer)
except Exception as e:
if input_container is not None:
input_container.close()
if output_container is not None:
output_container.close()
raise RuntimeError(f"Failed to resize video: {str(e)}") from e
def _f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
if wav.dtype.is_floating_point:

View File

@ -22,7 +22,7 @@ from ._helpers import (
sleep_with_interrupt,
to_aiohttp_url,
)
from .client import _diagnose_connectivity
from .client import RETRY_DEFAULTS, _diagnose_connectivity
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
from .conversions import bytesio_to_image_tensor
@ -34,9 +34,9 @@ async def download_url_to_bytesio(
dest: BytesIO | IO[bytes] | str | Path | None,
*,
timeout: float | None = None,
max_retries: int = 5,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
max_retries: int = max(5, RETRY_DEFAULTS.max_retries),
retry_delay: float = RETRY_DEFAULTS.retry_delay,
retry_backoff: float = RETRY_DEFAULTS.retry_backoff,
cls: type[COMFY_IO.ComfyNode] = None,
) -> None:
"""Stream-download a URL to `dest`.
@ -53,7 +53,9 @@ async def download_url_to_bytesio(
ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception (HTTP and other errors)
"""
if not isinstance(dest, (str, Path)) and not hasattr(dest, "write"):
raise ValueError("dest must be a path (str|Path) or a binary-writable object providing .write().")
raise ValueError(
"dest must be a path (str|Path) or a binary-writable object providing .write()."
)
attempt = 0
delay = retry_delay
@ -62,7 +64,9 @@ async def download_url_to_bytesio(
parsed_url = urlparse(url)
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
if cls is None:
raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.")
raise ValueError(
"For relative 'cloud' paths, the `cls` parameter is required."
)
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
headers = get_auth_header(cls)
@ -80,7 +84,9 @@ async def download_url_to_bytesio(
try:
with contextlib.suppress(Exception):
request_logger.log_request_response(operation_id=op_id, request_method="GET", request_url=url)
request_logger.log_request_response(
operation_id=op_id, request_method="GET", request_url=url
)
session = aiohttp.ClientSession(timeout=timeout_cfg)
stop_evt = asyncio.Event()
@ -96,8 +102,12 @@ async def download_url_to_bytesio(
monitor_task = asyncio.create_task(_monitor())
req_task = asyncio.create_task(session.get(to_aiohttp_url(url), headers=headers))
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
req_task = asyncio.create_task(
session.get(to_aiohttp_url(url), headers=headers)
)
done, pending = await asyncio.wait(
{req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED
)
if monitor_task in done and req_task in pending:
req_task.cancel()
@ -117,7 +127,11 @@ async def download_url_to_bytesio(
body = await resp.json()
except (ContentTypeError, ValueError):
text = await resp.text()
body = text if len(text) <= 4096 else f"[text {len(text)} bytes]"
body = (
text
if len(text) <= 4096
else f"[text {len(text)} bytes]"
)
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
@ -146,7 +160,9 @@ async def download_url_to_bytesio(
written = 0
while True:
try:
chunk = await asyncio.wait_for(resp.content.read(1024 * 1024), timeout=1.0)
chunk = await asyncio.wait_for(
resp.content.read(1024 * 1024), timeout=1.0
)
except asyncio.TimeoutError:
chunk = b""
except asyncio.CancelledError:
@ -195,7 +211,9 @@ async def download_url_to_bytesio(
raise LocalNetworkError(
"Unable to connect to the network. Please check your internet connection and try again."
) from e
raise ApiServerError("The remote service appears unreachable at this time.") from e
raise ApiServerError(
"The remote service appears unreachable at this time."
) from e
finally:
if stop_evt is not None:
stop_evt.set()
@ -237,7 +255,9 @@ async def download_url_to_video_output(
) -> InputImpl.VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output."""
result = BytesIO()
await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
await download_url_to_bytesio(
video_url, result, timeout=timeout, max_retries=max_retries, cls=cls
)
return InputImpl.VideoFromFile(result)
@ -256,7 +276,11 @@ async def download_url_as_bytesio(
def _generate_operation_id(method: str, url: str, attempt: int) -> str:
try:
parsed = urlparse(url)
slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download").strip("/").replace("/", "_")
slug = (
(parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download")
.strip("/")
.replace("/", "_")
)
except Exception:
slug = "download"
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"

View File

@ -15,6 +15,7 @@ from comfy_api.latest import IO, Input, Types
from . import request_logger
from ._helpers import is_processing_interrupted, sleep_with_interrupt
from .client import (
RETRY_DEFAULTS,
ApiEndpoint,
_diagnose_connectivity,
_display_time_progress,
@ -77,13 +78,17 @@ async def upload_images_to_comfyapi(
for idx in range(num_to_upload):
tensor = tensors[idx]
img_io = tensor_to_bytesio(tensor, total_pixels=total_pixels, mime_type=mime_type)
img_io = tensor_to_bytesio(
tensor, total_pixels=total_pixels, mime_type=mime_type
)
effective_label = wait_label
if wait_label and show_batch_index and num_to_upload > 1:
effective_label = f"{wait_label} ({idx + 1}/{num_to_upload})"
url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts)
url = await upload_file_to_comfyapi(
cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts
)
download_urls.append(url)
return download_urls
@ -125,8 +130,12 @@ async def upload_audio_to_comfyapi(
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
return await upload_file_to_comfyapi(cls, audio_bytes_io, f"{uuid.uuid4()}.{container_format}", mime_type)
audio_bytes_io = audio_ndarray_to_bytesio(
audio_data_np, sample_rate, container_format, codec_name
)
return await upload_file_to_comfyapi(
cls, audio_bytes_io, f"{uuid.uuid4()}.{container_format}", mime_type
)
async def upload_video_to_comfyapi(
@ -161,7 +170,9 @@ async def upload_video_to_comfyapi(
video.save_to(video_bytes_io, format=container, codec=codec)
video_bytes_io.seek(0)
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label)
return await upload_file_to_comfyapi(
cls, video_bytes_io, filename, upload_mime_type, wait_label
)
_3D_MIME_TYPES = {
@ -197,7 +208,9 @@ async def upload_file_to_comfyapi(
if upload_mime_type is None:
request_object = UploadRequest(file_name=filename)
else:
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
request_object = UploadRequest(
file_name=filename, content_type=upload_mime_type
)
create_resp = await sync_op(
cls,
endpoint=ApiEndpoint(path="/customers/storage", method="POST"),
@ -223,9 +236,9 @@ async def upload_file(
file: BytesIO | str,
*,
content_type: str | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
max_retries: int = RETRY_DEFAULTS.max_retries,
retry_delay: float = RETRY_DEFAULTS.retry_delay,
retry_backoff: float = RETRY_DEFAULTS.retry_backoff,
wait_label: str | None = None,
progress_origin_ts: float | None = None,
) -> None:
@ -250,11 +263,15 @@ async def upload_file(
if content_type:
headers["Content-Type"] = content_type
else:
skip_auto_headers.add("Content-Type") # Don't let aiohttp add Content-Type, it can break the signed request
skip_auto_headers.add(
"Content-Type"
) # Don't let aiohttp add Content-Type, it can break the signed request
attempt = 0
delay = retry_delay
start_ts = progress_origin_ts if progress_origin_ts is not None else time.monotonic()
start_ts = (
progress_origin_ts if progress_origin_ts is not None else time.monotonic()
)
op_uuid = uuid.uuid4().hex[:8]
while True:
attempt += 1
@ -268,7 +285,9 @@ async def upload_file(
if is_processing_interrupted():
return
if wait_label:
_display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None)
_display_time_progress(
cls, wait_label, int(time.monotonic() - start_ts), None
)
await asyncio.sleep(1.0)
except asyncio.CancelledError:
return
@ -286,10 +305,17 @@ async def upload_file(
)
sess = aiohttp.ClientSession(timeout=timeout)
req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
req = sess.put(
upload_url,
data=data,
headers=headers,
skip_auto_headers=skip_auto_headers,
)
req_task = asyncio.create_task(req)
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
done, pending = await asyncio.wait(
{req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED
)
if monitor_task in done and req_task in pending:
req_task.cancel()
@ -317,14 +343,19 @@ async def upload_file(
response_content=body,
error_message=msg,
)
if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries:
if (
resp.status in {408, 429, 500, 502, 503, 504}
and attempt <= max_retries
):
await sleep_with_interrupt(
delay,
cls,
wait_label,
start_ts,
None,
display_callback=_display_time_progress if wait_label else None,
display_callback=_display_time_progress
if wait_label
else None,
)
delay *= retry_backoff
continue
@ -366,7 +397,9 @@ async def upload_file(
raise LocalNetworkError(
"Unable to connect to the network. Please check your internet connection and try again."
) from e
raise ApiServerError("The API service appears unreachable at this time.") from e
raise ApiServerError(
"The API service appears unreachable at this time."
) from e
finally:
stop_evt.set()
if monitor_task:
@ -381,7 +414,11 @@ async def upload_file(
def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str:
try:
parsed = urlparse(url)
slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload").strip("/").replace("/", "_")
slug = (
(parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload")
.strip("/")
.replace("/", "_")
)
except Exception:
slug = "upload"
return f"{method}_{slug}_{op_uuid}_try{attempt}"

View File

@ -3,136 +3,136 @@ from typing_extensions import override
import comfy.model_management
import node_helpers
from comfy_api.latest import ComfyExtension, IO
from comfy_api.latest import ComfyExtension, io
class TextEncodeAceStepAudio(IO.ComfyNode):
class TextEncodeAceStepAudio(io.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
return io.Schema(
node_id="TextEncodeAceStepAudio",
category="conditioning",
inputs=[
IO.Clip.Input("clip"),
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
IO.String.Input("lyrics", multiline=True, dynamic_prompts=True),
IO.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01),
io.Clip.Input("clip"),
io.String.Input("tags", multiline=True, dynamic_prompts=True),
io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
io.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01),
],
outputs=[IO.Conditioning.Output()],
outputs=[io.Conditioning.Output()],
)
@classmethod
def execute(cls, clip, tags, lyrics, lyrics_strength) -> IO.NodeOutput:
def execute(cls, clip, tags, lyrics, lyrics_strength) -> io.NodeOutput:
tokens = clip.tokenize(tags, lyrics=lyrics)
conditioning = clip.encode_from_tokens_scheduled(tokens)
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
return IO.NodeOutput(conditioning)
return io.NodeOutput(conditioning)
class TextEncodeAceStepAudio15(IO.ComfyNode):
class TextEncodeAceStepAudio15(io.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
return io.Schema(
node_id="TextEncodeAceStepAudio1.5",
category="conditioning",
inputs=[
IO.Clip.Input("clip"),
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
IO.String.Input("lyrics", multiline=True, dynamic_prompts=True),
IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
IO.Int.Input("bpm", default=120, min=10, max=300),
IO.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
IO.Combo.Input("timesignature", options=['2', '3', '4', '6']),
IO.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
IO.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
IO.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
IO.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
IO.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
IO.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
IO.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
IO.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
io.Clip.Input("clip"),
io.String.Input("tags", multiline=True, dynamic_prompts=True),
io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
io.Int.Input("bpm", default=120, min=10, max=300),
io.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
io.Combo.Input("timesignature", options=['2', '3', '4', '6']),
io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
io.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
io.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
],
outputs=[IO.Conditioning.Output()],
outputs=[io.Conditioning.Output()],
)
@classmethod
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> IO.NodeOutput:
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
conditioning = clip.encode_from_tokens_scheduled(tokens)
return IO.NodeOutput(conditioning)
return io.NodeOutput(conditioning)
class EmptyAceStepLatentAudio(IO.ComfyNode):
class EmptyAceStepLatentAudio(io.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
return io.Schema(
node_id="EmptyAceStepLatentAudio",
display_name="Empty Ace Step 1.0 Latent Audio",
category="latent/audio",
inputs=[
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
IO.Int.Input(
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
io.Int.Input(
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
),
],
outputs=[IO.Latent.Output()],
outputs=[io.Latent.Output()],
)
@classmethod
def execute(cls, seconds, batch_size) -> IO.NodeOutput:
def execute(cls, seconds, batch_size) -> io.NodeOutput:
length = int(seconds * 44100 / 512 / 8)
latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return IO.NodeOutput({"samples": latent, "type": "audio"})
return io.NodeOutput({"samples": latent, "type": "audio"})
class EmptyAceStep15LatentAudio(IO.ComfyNode):
class EmptyAceStep15LatentAudio(io.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
return io.Schema(
node_id="EmptyAceStep1.5LatentAudio",
display_name="Empty Ace Step 1.5 Latent Audio",
category="latent/audio",
inputs=[
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
IO.Int.Input(
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
io.Int.Input(
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
),
],
outputs=[IO.Latent.Output()],
outputs=[io.Latent.Output()],
)
@classmethod
def execute(cls, seconds, batch_size) -> IO.NodeOutput:
def execute(cls, seconds, batch_size) -> io.NodeOutput:
length = round((seconds * 48000 / 1920))
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return IO.NodeOutput({"samples": latent, "type": "audio"})
return io.NodeOutput({"samples": latent, "type": "audio"})
class ReferenceAudio(IO.ComfyNode):
class ReferenceAudio(io.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
return io.Schema(
node_id="ReferenceTimbreAudio",
display_name="Reference Audio",
category="advanced/conditioning/audio",
is_experimental=True,
description="This node sets the reference audio for ace step 1.5",
inputs=[
IO.Conditioning.Input("conditioning"),
IO.Latent.Input("latent", optional=True),
io.Conditioning.Input("conditioning"),
io.Latent.Input("latent", optional=True),
],
outputs=[
IO.Conditioning.Output(),
io.Conditioning.Output(),
]
)
@classmethod
def execute(cls, conditioning, latent=None) -> IO.NodeOutput:
def execute(cls, conditioning, latent=None) -> io.NodeOutput:
if latent is not None:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_audio_timbre_latents": [latent["samples"]]}, append=True)
return IO.NodeOutput(conditioning)
return io.NodeOutput(conditioning)
class AceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TextEncodeAceStepAudio,
EmptyAceStepLatentAudio,

View File

@ -104,7 +104,7 @@ def vae_decode_audio(vae, samples, tile=None, overlap=None):
std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
vae_sample_rate = getattr(vae, "audio_sample_rate_output", getattr(vae, "audio_sample_rate", 44100))
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}

View File

@ -3,8 +3,9 @@ import comfy.utils
import comfy.model_management
import torch
from comfy.ldm.lightricks.vae.audio_vae import AudioVAE
from comfy_api.latest import ComfyExtension, io
from comfy_extras.nodes_audio import VAEEncodeAudio
class LTXVAudioVAELoader(io.ComfyNode):
@classmethod
@ -27,14 +28,10 @@ class LTXVAudioVAELoader(io.ComfyNode):
def execute(cls, ckpt_name: str) -> io.NodeOutput:
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder.", "vocoder.": "vocoder."}, filter_keys=True)
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
vae.throw_exception_if_invalid()
return io.NodeOutput(vae)
return io.NodeOutput(AudioVAE(sd, metadata))
class LTXVAudioVAEEncode(VAEEncodeAudio):
class LTXVAudioVAEEncode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
@ -53,8 +50,15 @@ class LTXVAudioVAEEncode(VAEEncodeAudio):
)
@classmethod
def execute(cls, audio, audio_vae) -> io.NodeOutput:
return super().execute(audio_vae, audio)
def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput:
audio_latents = audio_vae.encode(audio)
return io.NodeOutput(
{
"samples": audio_latents,
"sample_rate": int(audio_vae.sample_rate),
"type": "audio",
}
)
class LTXVAudioVAEDecode(io.ComfyNode):
@ -76,12 +80,12 @@ class LTXVAudioVAEDecode(io.ComfyNode):
)
@classmethod
def execute(cls, samples, audio_vae) -> io.NodeOutput:
def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput:
audio_latent = samples["samples"]
if audio_latent.is_nested:
audio_latent = audio_latent.unbind()[-1]
audio = audio_vae.decode(audio_latent).movedim(-1, 1).to(audio_latent.device)
output_audio_sample_rate = audio_vae.first_stage_model.output_sample_rate
audio = audio_vae.decode(audio_latent).to(audio_latent.device)
output_audio_sample_rate = audio_vae.output_sample_rate
return io.NodeOutput(
{
"waveform": audio,
@ -139,17 +143,17 @@ class LTXVEmptyLatentAudio(io.ComfyNode):
frames_number: int,
frame_rate: int,
batch_size: int,
audio_vae,
audio_vae: AudioVAE,
) -> io.NodeOutput:
"""Generate empty audio latents matching the reference pipeline structure."""
assert audio_vae is not None, "Audio VAE model is required"
z_channels = audio_vae.latent_channels
audio_freq = audio_vae.first_stage_model.latent_frequency_bins
sampling_rate = int(audio_vae.first_stage_model.sample_rate)
audio_freq = audio_vae.latent_frequency_bins
sampling_rate = int(audio_vae.sample_rate)
num_audio_latents = audio_vae.first_stage_model.num_of_latents_from_frames(frames_number, frame_rate)
num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate)
audio_latents = torch.zeros(
(batch_size, z_channels, num_audio_latents, audio_freq),

View File

@ -1,4 +1,4 @@
comfyui-frontend-package==1.42.14
comfyui-frontend-package==1.42.11
comfyui-workflow-templates==0.9.57
comfyui-embedded-docs==0.4.3
torch
@ -19,7 +19,7 @@ scipy
tqdm
psutil
alembic
SQLAlchemy>=2.0
SQLAlchemy
filelock
av>=14.2.0
comfy-kitchen>=0.2.8

View File

@ -1,246 +0,0 @@
import pytest
from comfy_api_nodes.nodes_openai import (
OpenAIGPTImage1,
OpenAIGPTImage2,
_GPT_IMAGE_2_SIZES,
_resolve_gpt_image_2_size,
calculate_tokens_price_image_1,
calculate_tokens_price_image_1_5,
calculate_tokens_price_image_2,
)
from comfy_api_nodes.apis.openai import OpenAIImageGenerationResponse, Usage
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_response(input_tokens: int, output_tokens: int) -> OpenAIImageGenerationResponse:
return OpenAIImageGenerationResponse(
data=[],
usage=Usage(input_tokens=input_tokens, output_tokens=output_tokens),
)
# ---------------------------------------------------------------------------
# Price extractor tests
# ---------------------------------------------------------------------------
def test_price_image_1_formula():
response = _make_response(input_tokens=1_000_000, output_tokens=1_000_000)
assert calculate_tokens_price_image_1(response) == pytest.approx(50.0)
def test_price_image_1_5_formula():
response = _make_response(input_tokens=1_000_000, output_tokens=1_000_000)
assert calculate_tokens_price_image_1_5(response) == pytest.approx(40.0)
def test_price_image_2_formula():
response = _make_response(input_tokens=1_000_000, output_tokens=1_000_000)
assert calculate_tokens_price_image_2(response) == pytest.approx(38.0)
def test_price_image_2_cheaper_than_1():
response = _make_response(input_tokens=500, output_tokens=196)
assert calculate_tokens_price_image_2(response) < calculate_tokens_price_image_1(response)
def test_price_image_2_cheaper_output_than_1_5():
# gpt-image-2 output rate ($30/1M) is lower than gpt-image-1.5 ($32/1M)
response = _make_response(input_tokens=0, output_tokens=1_000_000)
assert calculate_tokens_price_image_2(response) < calculate_tokens_price_image_1_5(response)
# ---------------------------------------------------------------------------
# _resolve_gpt_image_2_size tests
# ---------------------------------------------------------------------------
def test_resolve_preset_passthrough_when_custom_zero():
# 0/0 means "use size preset"
assert _resolve_gpt_image_2_size("1024x1024", 0, 0) == "1024x1024"
assert _resolve_gpt_image_2_size("auto", 0, 0) == "auto"
assert _resolve_gpt_image_2_size("3840x2160", 0, 0) == "3840x2160"
def test_resolve_preset_passthrough_when_only_one_dim_set():
# only one dimension set → still use preset
assert _resolve_gpt_image_2_size("auto", 1024, 0) == "auto"
assert _resolve_gpt_image_2_size("auto", 0, 1024) == "auto"
def test_resolve_custom_overrides_preset():
assert _resolve_gpt_image_2_size("auto", 1024, 1024) == "1024x1024"
assert _resolve_gpt_image_2_size("1024x1024", 2048, 1152) == "2048x1152"
assert _resolve_gpt_image_2_size("auto", 3840, 2160) == "3840x2160"
def test_resolve_custom_rejects_edge_too_large():
with pytest.raises(ValueError, match="3840"):
_resolve_gpt_image_2_size("auto", 4096, 1024)
def test_resolve_custom_rejects_non_multiple_of_16():
with pytest.raises(ValueError, match="multiple of 16"):
_resolve_gpt_image_2_size("auto", 1025, 1024)
def test_resolve_custom_rejects_bad_ratio():
with pytest.raises(ValueError, match="ratio"):
_resolve_gpt_image_2_size("auto", 3840, 1024) # 3.75:1 > 3:1
def test_resolve_custom_rejects_too_few_pixels():
with pytest.raises(ValueError, match="Total pixels"):
_resolve_gpt_image_2_size("auto", 16, 16)
def test_resolve_custom_rejects_too_many_pixels():
# 3840x2176 exceeds 8,294,400
with pytest.raises(ValueError, match="Total pixels"):
_resolve_gpt_image_2_size("auto", 3840, 2176)
# ---------------------------------------------------------------------------
# OpenAIGPTImage1 schema tests
# ---------------------------------------------------------------------------
class TestOpenAIGPTImage1Schema:
def setup_method(self):
self.schema = OpenAIGPTImage1.define_schema()
def test_node_id(self):
assert self.schema.node_id == "OpenAIGPTImage1"
def test_display_name(self):
assert self.schema.display_name == "OpenAI GPT Image 1 & 1.5"
def test_model_options_exclude_gpt_image_2(self):
model_input = next(i for i in self.schema.inputs if i.name == "model")
assert "gpt-image-2" not in model_input.options
def test_model_options_include_legacy_models(self):
model_input = next(i for i in self.schema.inputs if i.name == "model")
assert "gpt-image-1" in model_input.options
assert "gpt-image-1.5" in model_input.options
def test_has_background_with_transparent(self):
bg_input = next(i for i in self.schema.inputs if i.name == "background")
assert "transparent" in bg_input.options
# ---------------------------------------------------------------------------
# OpenAIGPTImage2 schema tests
# ---------------------------------------------------------------------------
class TestOpenAIGPTImage2Schema:
def setup_method(self):
self.schema = OpenAIGPTImage2.define_schema()
def test_node_id(self):
assert self.schema.node_id == "OpenAIGPTImage2"
def test_display_name(self):
assert self.schema.display_name == "OpenAI GPT Image 2"
def test_category(self):
assert "OpenAI" in self.schema.category
def test_no_transparent_background(self):
bg_input = next(i for i in self.schema.inputs if i.name == "background")
assert "transparent" not in bg_input.options
def test_background_options(self):
bg_input = next(i for i in self.schema.inputs if i.name == "background")
assert set(bg_input.options) == {"auto", "opaque"}
def test_quality_options(self):
quality_input = next(i for i in self.schema.inputs if i.name == "quality")
assert set(quality_input.options) == {"auto", "low", "medium", "high"}
def test_quality_default_is_auto(self):
quality_input = next(i for i in self.schema.inputs if i.name == "quality")
assert quality_input.default == "auto"
def test_all_popular_sizes_present(self):
size_input = next(i for i in self.schema.inputs if i.name == "size")
for size in ["1024x1024", "1536x1024", "1024x1536", "2048x2048", "2048x1152", "3840x2160", "2160x3840"]:
assert size in size_input.options, f"Missing size: {size}"
def test_no_custom_size_option(self):
size_input = next(i for i in self.schema.inputs if i.name == "size")
assert "custom" not in size_input.options
def test_size_default_is_auto(self):
size_input = next(i for i in self.schema.inputs if i.name == "size")
assert size_input.default == "auto"
def test_custom_width_and_height_inputs_exist(self):
input_names = [i.name for i in self.schema.inputs]
assert "custom_width" in input_names
assert "custom_height" in input_names
def test_custom_width_height_default_zero(self):
width_input = next(i for i in self.schema.inputs if i.name == "custom_width")
height_input = next(i for i in self.schema.inputs if i.name == "custom_height")
assert width_input.default == 0
assert height_input.default == 0
def test_custom_width_height_step_is_16(self):
width_input = next(i for i in self.schema.inputs if i.name == "custom_width")
height_input = next(i for i in self.schema.inputs if i.name == "custom_height")
assert width_input.step == 16
assert height_input.step == 16
def test_custom_width_height_max_is_3840(self):
width_input = next(i for i in self.schema.inputs if i.name == "custom_width")
height_input = next(i for i in self.schema.inputs if i.name == "custom_height")
assert width_input.max == 3840
assert height_input.max == 3840
def test_uses_num_images_not_n(self):
input_names = [i.name for i in self.schema.inputs]
assert "num_images" in input_names
assert "n" not in input_names
def test_model_input_shows_gpt_image_2(self):
model_input = next(i for i in self.schema.inputs if i.name == "model")
assert model_input.options == ["gpt-image-2"]
assert model_input.default == "gpt-image-2"
def test_has_image_and_mask_inputs(self):
input_names = [i.name for i in self.schema.inputs]
assert "image" in input_names
assert "mask" in input_names
def test_is_api_node(self):
assert self.schema.is_api_node is True
def test_sizes_match_constant(self):
size_input = next(i for i in self.schema.inputs if i.name == "size")
assert size_input.options == _GPT_IMAGE_2_SIZES
# ---------------------------------------------------------------------------
# OpenAIGPTImage2 execute validation tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_execute_raises_on_empty_prompt():
with pytest.raises(Exception):
await OpenAIGPTImage2.execute(prompt=" ")
@pytest.mark.asyncio
async def test_execute_raises_mask_without_image():
import torch
mask = torch.ones(1, 64, 64)
with pytest.raises(ValueError, match="mask without an input image"):
await OpenAIGPTImage2.execute(prompt="test", mask=mask)
@pytest.mark.asyncio
async def test_execute_raises_invalid_custom_size():
with pytest.raises(ValueError):
await OpenAIGPTImage2.execute(prompt="test", custom_width=4096, custom_height=1024)

View File

@ -0,0 +1,94 @@
"""Tests for configurable retry defaults via environment variables.
Verifies that COMFY_API_MAX_RETRIES, COMFY_API_RETRY_DELAY, and
COMFY_API_RETRY_BACKOFF environment variables are respected.
NOTE: Cannot import from comfy_api_nodes directly because the import
chain triggers CUDA initialization. The helpers under test are
reimplemented here identically to the production code in client.py.
"""
from __future__ import annotations
import os
from dataclasses import dataclass
from unittest.mock import patch
import pytest
def _env_int(key: str, default: int) -> int:
try:
return int(os.environ[key])
except (KeyError, ValueError):
return default
def _env_float(key: str, default: float) -> float:
try:
return float(os.environ[key])
except (KeyError, ValueError):
return default
@dataclass(frozen=True)
class _RetryDefaults:
max_retries: int = _env_int("COMFY_API_MAX_RETRIES", 3)
retry_delay: float = _env_float("COMFY_API_RETRY_DELAY", 1.0)
retry_backoff: float = _env_float("COMFY_API_RETRY_BACKOFF", 2.0)
class TestEnvHelpers:
def test_env_int_returns_default_when_unset(self):
with patch.dict(os.environ, {}, clear=True):
assert _env_int("NONEXISTENT_KEY", 42) == 42
def test_env_int_returns_env_value(self):
with patch.dict(os.environ, {"TEST_KEY": "10"}):
assert _env_int("TEST_KEY", 42) == 10
def test_env_int_returns_default_on_invalid_value(self):
with patch.dict(os.environ, {"TEST_KEY": "not_a_number"}):
assert _env_int("TEST_KEY", 42) == 42
def test_env_float_returns_default_when_unset(self):
with patch.dict(os.environ, {}, clear=True):
assert _env_float("NONEXISTENT_KEY", 1.5) == 1.5
def test_env_float_returns_env_value(self):
with patch.dict(os.environ, {"TEST_KEY": "2.5"}):
assert _env_float("TEST_KEY", 1.5) == 2.5
def test_env_float_returns_default_on_invalid_value(self):
with patch.dict(os.environ, {"TEST_KEY": "bad"}):
assert _env_float("TEST_KEY", 1.5) == 1.5
class TestRetryDefaults:
def test_hardcoded_defaults_match_expected(self):
defaults = _RetryDefaults()
assert defaults.max_retries == 3
assert defaults.retry_delay == 1.0
assert defaults.retry_backoff == 2.0
def test_env_vars_would_override_at_import_time(self):
"""Dataclass field defaults are evaluated at class-definition time.
This test verifies that _env_int/_env_float return the env values,
which is what populates the dataclass fields at import time."""
with patch.dict(os.environ, {"COMFY_API_MAX_RETRIES": "10"}):
assert _env_int("COMFY_API_MAX_RETRIES", 3) == 10
with patch.dict(os.environ, {"COMFY_API_RETRY_DELAY": "3.0"}):
assert _env_float("COMFY_API_RETRY_DELAY", 1.0) == 3.0
with patch.dict(os.environ, {"COMFY_API_RETRY_BACKOFF": "1.5"}):
assert _env_float("COMFY_API_RETRY_BACKOFF", 2.0) == 1.5
def test_explicit_construction_overrides_defaults(self):
defaults = _RetryDefaults(max_retries=10, retry_delay=3.0, retry_backoff=1.5)
assert defaults.max_retries == 10
assert defaults.retry_delay == 3.0
assert defaults.retry_backoff == 1.5
def test_frozen_dataclass(self):
defaults = _RetryDefaults()
with pytest.raises(AttributeError):
defaults.max_retries = 999