Compare commits


9 Commits

Author SHA1 Message Date
e9f8aa9346 feat(api-nodes): add dedicated OpenAI GPT-Image-2 node
- Add `OpenAIGPTImage2` node (`node_id: OpenAIGPTImage2`) with settings
  specific to gpt-image-2: quality auto/low/medium/high, background
  auto/opaque (transparent not supported), all 8 popular size presets,
  and custom width/height inputs (step=16, max=3840) that override the
  size preset when both are non-zero
- Add `_resolve_gpt_image_2_size` helper that enforces API constraints:
  max edge ≤ 3840px, multiples of 16, ratio ≤ 3:1, total pixels
  655,360–8,294,400
- Add `calculate_tokens_price_image_2` using correct gpt-image-2 rates
  ($8/1M input, $30/1M output); price badge shows range per quality
  tier with approximate flag for auto quality
- Rename `OpenAIGPTImage1` display name to "OpenAI GPT Image 1 & 1.5",
  remove gpt-image-2 from its model dropdown, and update its price badge
  to be model-aware with correct per-model ranges
- Add unit tests covering price formulas, size resolution logic, and
  schema correctness for both nodes

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-22 05:23:17 +02:00
43a1263b60 Add gpt-image-2 as version option (#13501) 2026-04-21 17:58:59 -07:00
102773cd2c Bump comfyui-frontend-package to 1.42.14 (#13493) 2026-04-21 11:35:45 -07:00
1e1d4f1254 [Partner Nodes] added 4K resolution for Veo models; added Veo 3 Lite model (#13330)
* feat(api nodes): added 4K resolution for Veo models; added Veo 3 Lite model

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* increase poll_interval from 5 to 9

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-04-21 11:27:35 -07:00
eb22225387 Support standalone LTXV audio VAEs (#13499) 2026-04-21 10:46:37 -07:00
b38dd0ff23 feat(api-nodes): add automatic downscaling of videos for ByteDance 2 nodes (#13465) 2026-04-21 10:45:10 -07:00
ad94d47221 Make the ltx audio vae more native. (#13486) 2026-04-21 11:02:42 -04:00
e75f775ae8 Bump comfyui-frontend-package to 1.42.12 (#13489) 2026-04-21 00:43:11 -07:00
c514890325 Refactor io to IO in nodes_ace.py (#13485) 2026-04-20 21:59:26 -04:00
14 changed files with 879 additions and 214 deletions

View File

@ -4,9 +4,6 @@ import math
import torch
import torchaudio
import comfy.model_management
import comfy.model_patcher
import comfy.utils as utils
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
@ -43,30 +40,6 @@ class AudioVAEComponentConfig:
return cls(autoencoder=audio_config, vocoder=vocoder_config)
class ModelDeviceManager:
"""Manages device placement and GPU residency for the composed model."""
def __init__(self, module: torch.nn.Module):
load_device = comfy.model_management.get_torch_device()
offload_device = comfy.model_management.vae_offload_device()
self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)
def ensure_model_loaded(self) -> None:
comfy.model_management.free_memory(
self.patcher.model_size(),
self.patcher.load_device,
)
comfy.model_management.load_model_gpu(self.patcher)
def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(self.patcher.load_device)
@property
def load_device(self):
return self.patcher.load_device
class AudioLatentNormalizer:
"""Applies per-channel statistics in patch space and restores original layout."""
@ -132,23 +105,17 @@ class AudioPreprocessor:
class AudioVAE(torch.nn.Module):
"""High-level Audio VAE wrapper exposing encode and decode entry points."""
def __init__(self, state_dict: dict, metadata: dict):
def __init__(self, metadata: dict):
super().__init__()
component_config = AudioVAEComponentConfig.from_metadata(metadata)
vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True)
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
if "bwe" in component_config.vocoder:
self.vocoder = VocoderWithBWE(config=component_config.vocoder)
else:
self.vocoder = Vocoder(config=component_config.vocoder)
self.autoencoder.load_state_dict(vae_sd, strict=False)
self.vocoder.load_state_dict(vocoder_sd, strict=False)
autoencoder_config = self.autoencoder.get_config()
self.normalizer = AudioLatentNormalizer(
AudioPatchifier(
@ -168,18 +135,12 @@ class AudioVAE(torch.nn.Module):
n_fft=autoencoder_config["n_fft"],
)
self.device_manager = ModelDeviceManager(self)
def encode(self, audio: dict) -> torch.Tensor:
def encode(self, audio, sample_rate=44100) -> torch.Tensor:
"""Encode a waveform dictionary into normalized latent tensors."""
waveform = audio["waveform"]
waveform_sample_rate = audio["sample_rate"]
waveform = audio
waveform_sample_rate = sample_rate
input_device = waveform.device
# Ensure that Audio VAE is loaded on the correct device.
self.device_manager.ensure_model_loaded()
waveform = self.device_manager.move_to_load_device(waveform)
expected_channels = self.autoencoder.encoder.in_channels
if waveform.shape[1] != expected_channels:
if waveform.shape[1] == 1:
@ -190,7 +151,7 @@ class AudioVAE(torch.nn.Module):
)
mel_spec = self.preprocessor.waveform_to_mel(
waveform, waveform_sample_rate, device=self.device_manager.load_device
waveform, waveform_sample_rate, device=waveform.device
)
latents = self.autoencoder.encode(mel_spec)
@ -204,17 +165,13 @@ class AudioVAE(torch.nn.Module):
"""Decode normalized latent tensors into an audio waveform."""
original_shape = latents.shape
# Ensure that Audio VAE is loaded on the correct device.
self.device_manager.ensure_model_loaded()
latents = self.device_manager.move_to_load_device(latents)
latents = self.normalizer.denormalize(latents)
target_shape = self.target_shape_from_latents(original_shape)
mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)
waveform = self.run_vocoder(mel_spec)
return self.device_manager.move_to_load_device(waveform)
return waveform
def target_shape_from_latents(self, latents_shape):
batch, _, time, _ = latents_shape

View File

@ -12,6 +12,7 @@ from .ldm.cascade.stage_c_coder import StageC_coder
from .ldm.audio.autoencoder import AudioOobleckVAE
import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.lightricks.vae.audio_vae
import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2
@ -805,6 +806,24 @@ class VAE:
self.downscale_index_formula = (4, 8, 8)
self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype))
self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
elif "vocoder.resblocks.0.convs1.0.weight" in sd or "vocoder.vocoder.resblocks.0.convs1.0.weight" in sd: # LTX Audio
sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder."})
self.first_stage_model = comfy.ldm.lightricks.vae.audio_vae.AudioVAE(metadata=metadata)
self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
self.latent_channels = self.first_stage_model.latent_channels
self.audio_sample_rate_output = self.first_stage_model.output_sample_rate
self.autoencoder = self.first_stage_model.autoencoder # TODO: remove hack for ltxv custom nodes
self.output_channels = 2
self.pad_channel_value = "replicate"
self.upscale_ratio = 4096
self.downscale_ratio = 4096
self.latent_dim = 2
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float32]
self.disable_offload = True
self.extra_1d_channel = 16
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None

View File

@ -15,7 +15,6 @@ from comfy_execution.progress import get_progress_state, PreviewImageTuple
from PIL import Image
from comfy.cli_args import args
import numpy as np
import os
class ComfyAPI_latest(ComfyAPIBase):
@ -26,7 +25,6 @@ class ComfyAPI_latest(ComfyAPIBase):
super().__init__()
self.node_replacement = self.NodeReplacement()
self.execution = self.Execution()
self.environment = self.Environment()
self.caching = self.Caching()
class NodeReplacement(ProxiedSingleton):
@ -87,27 +85,6 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display,
)
class Environment(ProxiedSingleton):
"""
Query the current execution environment.
Managed deployments set the ``COMFY_EXECUTION_ENVIRONMENT`` env var
so custom nodes can adapt their behaviour at runtime.
Example::
from comfy_api.latest import api
env = api.environment.get() # "local" | "cloud" | "remote"
"""
_VALID = {"local", "cloud", "remote"}
async def get(self) -> str:
"""Return the execution environment: ``"local"``, ``"cloud"``, or ``"remote"``."""
value = os.environ.get("COMFY_EXECUTION_ENVIRONMENT", "local").lower().strip()
return value if value in self._VALID else "local"
class Caching(ProxiedSingleton):
"""
External cache provider API for sharing cached node outputs

View File

@ -158,10 +158,17 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [
("Custom", None, None),
]
# Seedance 2.0 reference video pixel count limits per model.
# Seedance 2.0 reference video pixel count limits per model and output resolution.
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
"dreamina-seedance-2-0-260128": {"min": 409_600, "max": 927_408},
"dreamina-seedance-2-0-fast-260128": {"min": 409_600, "max": 927_408},
"dreamina-seedance-2-0-260128": {
"480p": {"min": 409_600, "max": 927_408},
"720p": {"min": 409_600, "max": 927_408},
"1080p": {"min": 409_600, "max": 2_073_600},
},
"dreamina-seedance-2-0-fast-260128": {
"480p": {"min": 409_600, "max": 927_408},
"720p": {"min": 409_600, "max": 927_408},
},
}
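# For example, with the limits above a 1080p request on
# "dreamina-seedance-2-0-260128" accepts reference videos up to 2,073,600 px
# (1920x1080), while the fast variant defines no 1080p entry, so the validator
# skips the pixel check for that combination.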
# The times in this dictionary are given for a 10 second duration.

View File

@ -35,6 +35,7 @@ from comfy_api_nodes.util import (
get_number_of_images,
image_tensor_pair_to_batch,
poll_op,
resize_video_to_pixel_budget,
sync_op,
upload_audio_to_comfyapi,
upload_image_to_comfyapi,
@ -69,9 +70,12 @@ DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-2504
logger = logging.getLogger(__name__)
def _validate_ref_video_pixels(video: Input.Video, model_id: str, index: int) -> None:
"""Validate reference video pixel count against Seedance 2.0 model limits."""
limits = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id)
def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: str, index: int) -> None:
"""Validate reference video pixel count against Seedance 2.0 model limits for the selected resolution."""
model_limits = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id)
if not model_limits:
return
limits = model_limits.get(resolution)
if not limits:
return
try:
@ -1373,6 +1377,14 @@ def _seedance2_reference_inputs(resolutions: list[str]):
min=0,
),
),
IO.Boolean.Input(
"auto_downscale",
default=False,
advanced=True,
optional=True,
tooltip="Automatically downscale reference videos that exceed the model's pixel budget "
"for the selected resolution. Aspect ratio is preserved; videos already within limits are untouched.",
),
]
@ -1480,10 +1492,23 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
model_id = SEEDANCE_MODELS[model["model"]]
has_video_input = len(reference_videos) > 0
if model.get("auto_downscale") and reference_videos:
max_px = (
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {})
.get(model["resolution"], {})
.get("max")
)
if max_px:
for key in reference_videos:
reference_videos[key] = resize_video_to_pixel_budget(
reference_videos[key], max_px
)
total_video_duration = 0.0
for i, key in enumerate(reference_videos, 1):
video = reference_videos[key]
_validate_ref_video_pixels(video, model_id, i)
_validate_ref_video_pixels(video, model_id, model["resolution"], i)
try:
dur = video.get_duration()
if dur < 1.8:

View File

@ -357,13 +357,18 @@ def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) ->
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0
def calculate_tokens_price_image_2(response: OpenAIImageGenerationResponse) -> float | None:
# https://platform.openai.com/docs/pricing - gpt-image-2: input $8/1M, output $30/1M
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 30.0)) / 1_000_000.0
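# For example, a response reporting 500 input tokens and 6,240 output tokens
# (hypothetical counts) would be priced at
#     (500 * 8.0 + 6_240 * 30.0) / 1_000_000 = 0.1912 USD,
# i.e. input at $8 per million tokens and output at $30 per million tokens.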
class OpenAIGPTImage1(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="OpenAIGPTImage1",
display_name="OpenAI GPT Image 1.5",
display_name="OpenAI GPT Image 1 & 1.5",
category="api node/image/OpenAI",
description="Generates images synchronously via OpenAI's GPT Image endpoint.",
inputs=[
@ -442,14 +447,22 @@ class OpenAIGPTImage1(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["quality", "n"]),
depends_on=IO.PriceBadgeDepends(widgets=["quality", "n", "model"]),
expr="""
(
$ranges := {
"low": [0.011, 0.02],
"medium": [0.046, 0.07],
"high": [0.167, 0.3]
};
$m := widgets.model;
$ranges :=
$contains($m, "gpt-image-1.5")
? {
"low": [0.009, 0.016],
"medium": [0.037, 0.056],
"high": [0.134, 0.240]
}
: {
"low": [0.011, 0.020],
"medium": [0.046, 0.070],
"high": [0.167, 0.300]
};
$range := $lookup($ranges, widgets.quality);
$n := widgets.n;
($n = 1)
@ -564,6 +577,261 @@ class OpenAIGPTImage1(IO.ComfyNode):
return IO.NodeOutput(await validate_and_cast_response(response))
_GPT_IMAGE_2_SIZES = [
"auto",
"1024x1024",
"1536x1024",
"1024x1536",
"2048x2048",
"2048x1152",
"3840x2160",
"2160x3840",
]
def _resolve_gpt_image_2_size(size: str, custom_width: int, custom_height: int) -> str:
if custom_width <= 0 or custom_height <= 0:
return size
w, h = custom_width, custom_height
if max(w, h) > 3840:
raise ValueError(f"Maximum edge length must be ≤ 3840px, got {max(w, h)}")
if w % 16 != 0 or h % 16 != 0:
raise ValueError(f"Both edges must be multiples of 16px, got {w}x{h}")
if max(w, h) / min(w, h) > 3:
raise ValueError(f"Long-to-short edge ratio must not exceed 3:1, got {max(w, h) / min(w, h):.2f}:1")
total = w * h
if total < 655_360 or total > 8_294_400:
raise ValueError(f"Total pixels must be between 655,360 and 8,294,400, got {total:,}")
return f"{w}x{h}"
class OpenAIGPTImage2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="OpenAIGPTImage2",
display_name="OpenAI GPT Image 2",
category="api node/image/OpenAI",
description="Generates images synchronously via OpenAI's GPT-Image-2 endpoint.",
inputs=[
IO.String.Input(
"prompt",
default="",
multiline=True,
tooltip="Text prompt for GPT Image 2",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2**31 - 1,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="not implemented yet in backend",
optional=True,
),
IO.Combo.Input(
"quality",
default="auto",
options=["auto", "low", "medium", "high"],
tooltip="Image quality. 'auto' lets the model decide based on the prompt. Square images are typically fastest.",
optional=True,
),
IO.Combo.Input(
"background",
default="auto",
options=["auto", "opaque"],
tooltip="Background style. GPT-Image-2 does not support transparent backgrounds.",
optional=True,
),
IO.Combo.Input(
"size",
default="auto",
options=_GPT_IMAGE_2_SIZES,
tooltip="Output image dimensions. Ignored when custom_width and custom_height are both non-zero.",
optional=True,
),
IO.Int.Input(
"custom_width",
default=0,
min=0,
max=3840,
step=16,
display_mode=IO.NumberDisplay.number,
tooltip="Custom output width in pixels. Set to 0 (default) to use the size preset. When both width and height are non-zero, they override the size preset. Slider enforces multiples of 16 and max edge 3840px. Additional constraints checked at generation: ratio ≤ 3:1, total pixels 655,3608,294,400.",
optional=True,
),
IO.Int.Input(
"custom_height",
default=0,
min=0,
max=3840,
step=16,
display_mode=IO.NumberDisplay.number,
tooltip="Custom output height in pixels. Set to 0 (default) to use the size preset. When both width and height are non-zero, they override the size preset. Slider enforces multiples of 16 and max edge 3840px. Additional constraints checked at generation: ratio ≤ 3:1, total pixels 655,3608,294,400.",
optional=True,
),
IO.Int.Input(
"num_images",
default=1,
min=1,
max=8,
step=1,
tooltip="Number of images to generate per run.",
display_mode=IO.NumberDisplay.number,
optional=True,
),
IO.Image.Input(
"image",
tooltip="Optional reference image for image editing.",
optional=True,
),
IO.Mask.Input(
"mask",
tooltip="Optional mask for inpainting (white areas will be replaced).",
optional=True,
),
IO.Combo.Input(
"model",
options=["gpt-image-2"],
default="gpt-image-2",
tooltip="Model used for image generation.",
optional=True,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["quality", "num_images"]),
expr="""
(
$ranges := {
"low": [0.005, 0.010],
"medium": [0.041, 0.060],
"high": [0.165, 0.250]
};
$q := widgets.quality;
$n := widgets.num_images;
$n := ($n != null and $n != 0) ? $n : 1;
$range := $lookup($ranges, $q);
$lo := $range ? $range[0] : 0.005;
$hi := $range ? $range[1] : 0.250;
($n = 1)
? {"type":"range_usd","min_usd": $lo, "max_usd": $hi, "format": {"approximate": ($range ? false : true)}}
: {
"type":"range_usd",
"min_usd": $lo,
"max_usd": $hi,
"format": {"approximate": ($range ? false : true), "suffix": " x " & $string($n) & "/Run"}
}
)
""",
),
)
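# Reading the price badge expression above: quality "medium" resolves to the
# $0.041-$0.060 per-image range, and with num_images = 3 the badge renders
# that range with the suffix " x 3/Run". An unrecognized quality (including
# "auto") falls back to the widest bounds, $0.005-$0.250, marked approximate.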
@classmethod
async def execute(
cls,
prompt: str,
seed: int = 0,
quality: str = "auto",
background: str = "auto",
image: Input.Image | None = None,
mask: Input.Image | None = None,
num_images: int = 1,
size: str = "auto",
custom_width: int = 0,
custom_height: int = 0,
model: str = "gpt-image-2",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
if mask is not None and image is None:
raise ValueError("Cannot use a mask without an input image")
resolved_size = _resolve_gpt_image_2_size(size, custom_width, custom_height)
if image is not None:
files = []
batch_size = image.shape[0]
for i in range(batch_size):
single_image = image[i : i + 1]
scaled_image = downscale_image_tensor(single_image, total_pixels=2048 * 2048).squeeze()
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
img_byte_arr = BytesIO()
img.save(img_byte_arr, format="PNG")
img_byte_arr.seek(0)
if batch_size == 1:
files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png")))
else:
files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
if mask is not None:
if image.shape[0] != 1:
raise Exception("Cannot use a mask with multiple image")
if mask.shape[1:] != image.shape[1:-1]:
raise Exception("Mask and Image must be the same size")
_, height, width = mask.shape
rgba_mask = torch.zeros(height, width, 4, device="cpu")
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048 * 2048).squeeze()
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
mask_img = Image.fromarray(mask_np)
mask_img_byte_arr = BytesIO()
mask_img.save(mask_img_byte_arr, format="PNG")
mask_img_byte_arr.seek(0)
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/openai/images/edits", method="POST"),
response_model=OpenAIImageGenerationResponse,
data=OpenAIImageEditRequest(
model=model,
prompt=prompt,
quality=quality,
background=background,
n=num_images,
size=resolved_size,
moderation="low",
),
content_type="multipart/form-data",
files=files,
price_extractor=calculate_tokens_price_image_2,
)
else:
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/openai/images/generations", method="POST"),
response_model=OpenAIImageGenerationResponse,
data=OpenAIImageGenerationRequest(
model=model,
prompt=prompt,
quality=quality,
background=background,
n=num_images,
size=resolved_size,
moderation="low",
),
price_extractor=calculate_tokens_price_image_2,
)
return IO.NodeOutput(await validate_and_cast_response(response))
class OpenAIChatNode(IO.ComfyNode):
"""
Node to generate text responses from an OpenAI model.
@ -913,6 +1181,7 @@ class OpenAIExtension(ComfyExtension):
OpenAIDalle2,
OpenAIDalle3,
OpenAIGPTImage1,
OpenAIGPTImage2,
OpenAIChatNode,
OpenAIInputFiles,
OpenAIChatConfig,

View File

@ -24,8 +24,9 @@ from comfy_api_nodes.util import (
AVERAGE_DURATION_VIDEO_GEN = 32
MODELS_MAP = {
"veo-2.0-generate-001": "veo-2.0-generate-001",
"veo-3.1-generate": "veo-3.1-generate-preview",
"veo-3.1-fast-generate": "veo-3.1-fast-generate-preview",
"veo-3.1-generate": "veo-3.1-generate-001",
"veo-3.1-fast-generate": "veo-3.1-fast-generate-001",
"veo-3.1-lite": "veo-3.1-lite-generate-001",
"veo-3.0-generate-001": "veo-3.0-generate-001",
"veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001",
}
@ -247,17 +248,8 @@ class VeoVideoGenerationNode(IO.ComfyNode):
raise Exception("Video generation completed but no video was returned")
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
"""
Generates videos from text prompts using Google's Veo 3 API.
Supported models:
- veo-3.0-generate-001
- veo-3.0-fast-generate-001
This node extends the base Veo node with Veo 3 specific features including
audio generation and fixed 8-second duration.
"""
class Veo3VideoGenerationNode(IO.ComfyNode):
"""Generates videos from text prompts using Google's Veo 3 API."""
@classmethod
def define_schema(cls):
@ -279,6 +271,13 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
default="16:9",
tooltip="Aspect ratio of the output video",
),
IO.Combo.Input(
"resolution",
options=["720p", "1080p", "4k"],
default="720p",
tooltip="Output video resolution. 4K is not available for veo-3.1-lite and veo-3.0 models.",
optional=True,
),
IO.String.Input(
"negative_prompt",
multiline=True,
@ -289,11 +288,11 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
IO.Int.Input(
"duration_seconds",
default=8,
min=8,
min=4,
max=8,
step=1,
step=2,
display_mode=IO.NumberDisplay.number,
tooltip="Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
tooltip="Duration of the output video in seconds",
optional=True,
),
IO.Boolean.Input(
@ -332,10 +331,10 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
options=[
"veo-3.1-generate",
"veo-3.1-fast-generate",
"veo-3.1-lite",
"veo-3.0-generate-001",
"veo-3.0-fast-generate-001",
],
default="veo-3.0-generate-001",
tooltip="Veo 3 model to use for video generation",
optional=True,
),
@ -356,21 +355,111 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio"]),
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "resolution", "duration_seconds"]),
expr="""
(
$m := widgets.model;
$r := widgets.resolution;
$a := widgets.generate_audio;
($contains($m,"veo-3.0-fast-generate-001") or $contains($m,"veo-3.1-fast-generate"))
? {"type":"usd","usd": ($a ? 1.2 : 0.8)}
: ($contains($m,"veo-3.0-generate-001") or $contains($m,"veo-3.1-generate"))
? {"type":"usd","usd": ($a ? 3.2 : 1.6)}
: {"type":"range_usd","min_usd":0.8,"max_usd":3.2}
$seconds := widgets.duration_seconds;
$pps :=
$contains($m, "lite")
? ($r = "1080p" ? ($a ? 0.08 : 0.05) : ($a ? 0.05 : 0.03))
: $contains($m, "3.1-fast")
? ($r = "4k" ? ($a ? 0.30 : 0.25) : $r = "1080p" ? ($a ? 0.12 : 0.10) : ($a ? 0.10 : 0.08))
: $contains($m, "3.1-generate")
? ($r = "4k" ? ($a ? 0.60 : 0.40) : ($a ? 0.40 : 0.20))
: $contains($m, "3.0-fast")
? ($a ? 0.15 : 0.10)
: ($a ? 0.40 : 0.20);
{"type":"usd","usd": $pps * $seconds}
)
""",
),
)
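# Worked examples for the per-second rates in the badge expression above:
#     veo-3.1-generate,      1080p, audio on  -> $0.40/s, so an 8 s clip shows $3.20
#     veo-3.1-fast-generate, 4k,    audio off -> $0.25/s, so an 8 s clip shows $2.00
#     veo-3.1-lite,          720p,  audio off -> $0.03/s, so a 4 s clip shows $0.12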
@classmethod
async def execute(
cls,
prompt,
aspect_ratio="16:9",
resolution="720p",
negative_prompt="",
duration_seconds=8,
enhance_prompt=True,
person_generation="ALLOW",
seed=0,
image=None,
model="veo-3.0-generate-001",
generate_audio=False,
):
if "lite" in model and resolution == "4k":
raise Exception("4K resolution is not supported by the veo-3.1-lite model.")
model = MODELS_MAP[model]
instances = [{"prompt": prompt}]
if image is not None:
image_base64 = tensor_to_base64_string(image)
if image_base64:
instances[0]["image"] = {"bytesBase64Encoded": image_base64, "mimeType": "image/png"}
parameters = {
"aspectRatio": aspect_ratio,
"personGeneration": person_generation,
"durationSeconds": duration_seconds,
"enhancePrompt": True,
"generateAudio": generate_audio,
}
if negative_prompt:
parameters["negativePrompt"] = negative_prompt
if seed > 0:
parameters["seed"] = seed
if "veo-3.1" in model:
parameters["resolution"] = resolution
initial_response = await sync_op(
cls,
ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"),
response_model=VeoGenVidResponse,
data=VeoGenVidRequest(
instances=instances,
parameters=parameters,
),
)
poll_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
response_model=VeoGenVidPollResponse,
status_extractor=lambda r: "completed" if r.done else "pending",
data=VeoGenVidPollRequest(operationName=initial_response.name),
poll_interval=9.0,
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
)
if poll_response.error:
raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})")
response = poll_response.response
filtered_count = response.raiMediaFilteredCount
if filtered_count:
reasons = response.raiMediaFilteredReasons or []
reason_part = f": {reasons[0]}" if reasons else ""
raise Exception(
f"Content blocked by Google's Responsible AI filters{reason_part} "
f"({filtered_count} video{'s' if filtered_count != 1 else ''} filtered)."
)
if response.videos:
video = response.videos[0]
if video.bytesBase64Encoded:
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
if video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
raise Exception("Video returned but no data or URL was provided")
raise Exception("Video generation completed but no video was returned")
class Veo3FirstLastFrameNode(IO.ComfyNode):
@ -394,7 +483,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
default="",
tooltip="Negative text prompt to guide what to avoid in the video",
),
IO.Combo.Input("resolution", options=["720p", "1080p"]),
IO.Combo.Input("resolution", options=["720p", "1080p", "4k"]),
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "9:16"],
@ -424,8 +513,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
IO.Image.Input("last_frame", tooltip="End frame"),
IO.Combo.Input(
"model",
options=["veo-3.1-generate", "veo-3.1-fast-generate"],
default="veo-3.1-fast-generate",
options=["veo-3.1-generate", "veo-3.1-fast-generate", "veo-3.1-lite"],
),
IO.Boolean.Input(
"generate_audio",
@ -443,26 +531,20 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration"]),
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration", "resolution"]),
expr="""
(
$prices := {
"veo-3.1-fast-generate": { "audio": 0.15, "no_audio": 0.10 },
"veo-3.1-generate": { "audio": 0.40, "no_audio": 0.20 }
};
$m := widgets.model;
$ga := (widgets.generate_audio = "true");
$r := widgets.resolution;
$ga := widgets.generate_audio;
$seconds := widgets.duration;
$modelKey :=
$contains($m, "veo-3.1-fast-generate") ? "veo-3.1-fast-generate" :
$contains($m, "veo-3.1-generate") ? "veo-3.1-generate" :
"";
$audioKey := $ga ? "audio" : "no_audio";
$modelPrices := $lookup($prices, $modelKey);
$pps := $lookup($modelPrices, $audioKey);
($pps != null)
? {"type":"usd","usd": $pps * $seconds}
: {"type":"range_usd","min_usd": 0.4, "max_usd": 3.2}
$pps :=
$contains($m, "lite")
? ($r = "1080p" ? ($ga ? 0.08 : 0.05) : ($ga ? 0.05 : 0.03))
: $contains($m, "fast")
? ($r = "4k" ? ($ga ? 0.30 : 0.25) : $r = "1080p" ? ($ga ? 0.12 : 0.10) : ($ga ? 0.10 : 0.08))
: ($r = "4k" ? ($ga ? 0.60 : 0.40) : ($ga ? 0.40 : 0.20));
{"type":"usd","usd": $pps * $seconds}
)
""",
),
@ -482,6 +564,9 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
model: str,
generate_audio: bool,
):
if "lite" in model and resolution == "4k":
raise Exception("4K resolution is not supported by the veo-3.1-lite model.")
model = MODELS_MAP[model]
initial_response = await sync_op(
cls,
@ -519,7 +604,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
data=VeoGenVidPollRequest(
operationName=initial_response.name,
),
poll_interval=5.0,
poll_interval=9.0,
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
)

View File

@ -19,6 +19,7 @@ from .conversions import (
image_tensor_pair_to_batch,
pil_to_bytesio,
resize_mask_to_image,
resize_video_to_pixel_budget,
tensor_to_base64_string,
tensor_to_bytesio,
tensor_to_pil,
@ -90,6 +91,7 @@ __all__ = [
"image_tensor_pair_to_batch",
"pil_to_bytesio",
"resize_mask_to_image",
"resize_video_to_pixel_budget",
"tensor_to_base64_string",
"tensor_to_bytesio",
"tensor_to_pil",

View File

@ -129,22 +129,38 @@ def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
return img_byte_arr
def _compute_downscale_dims(src_w: int, src_h: int, total_pixels: int) -> tuple[int, int] | None:
"""Return downscaled (w, h) with even dims fitting ``total_pixels``, or None if already fits.
Source aspect ratio is preserved; output may drift by a fraction of a percent because both dimensions
are rounded down to even values (many codecs require divisible-by-2).
"""
pixels = src_w * src_h
if pixels <= total_pixels:
return None
scale = math.sqrt(total_pixels / pixels)
new_w = max(2, int(src_w * scale))
new_h = max(2, int(src_h * scale))
new_w -= new_w % 2
new_h -= new_h % 2
return new_w, new_h
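# Worked example (hypothetical source): a 1920x1080 frame against a 927,408 px
# budget has 2,073,600 px, so scale = sqrt(927408 / 2073600) ~= 0.6688, giving
# 1284x722 after rounding both dimensions down to even values;
# 1284 * 722 = 927,048, which stays within the budget.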
def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor:
"""Downscale input image tensor to roughly the specified total pixels."""
"""Downscale input image tensor to roughly the specified total pixels.
Output dimensions are rounded down to even values so that the result is guaranteed to fit within ``total_pixels``
and is compatible with codecs that require even dimensions (e.g. yuv420p).
"""
samples = image.movedim(-1, 1)
total = int(total_pixels)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
if scale_by >= 1:
dims = _compute_downscale_dims(samples.shape[3], samples.shape[2], int(total_pixels))
if dims is None:
return image
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = common_upscale(samples, width, height, "lanczos", "disabled")
s = s.movedim(1, -1)
return s
new_w, new_h = dims
return common_upscale(samples, new_w, new_h, "lanczos", "disabled").movedim(1, -1)
def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) -> torch.Tensor:
"""Downscale input image tensor so the largest dimension is at most max_side pixels."""
samples = image.movedim(-1, 1)
height, width = samples.shape[2], samples.shape[3]
@ -399,6 +415,72 @@ def trim_video(video: Input.Video, duration_sec: float) -> Input.Video:
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
def resize_video_to_pixel_budget(video: Input.Video, total_pixels: int) -> Input.Video:
"""Downscale a video to fit within ``total_pixels`` (w * h), preserving aspect ratio.
Returns the original video object untouched when it already fits. Preserves frame rate, duration, and audio.
Aspect ratio is preserved up to a fraction of a percent (even-dim rounding).
"""
src_w, src_h = video.get_dimensions()
scale_dims = _compute_downscale_dims(src_w, src_h, total_pixels)
if scale_dims is None:
return video
return _apply_video_scale(video, scale_dims)
def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input.Video:
"""Re-encode ``video`` scaled to ``scale_dims`` with a single decode/encode pass."""
out_w, out_h = scale_dims
output_buffer = BytesIO()
input_container = None
output_container = None
try:
input_source = video.get_stream_source()
input_container = av.open(input_source, mode="r")
output_container = av.open(output_buffer, mode="w", format="mp4")
video_stream = output_container.add_stream("h264", rate=video.get_frame_rate())
video_stream.width = out_w
video_stream.height = out_h
video_stream.pix_fmt = "yuv420p"
audio_stream = None
for stream in input_container.streams:
if isinstance(stream, av.AudioStream):
audio_stream = output_container.add_stream("aac", rate=stream.sample_rate)
audio_stream.sample_rate = stream.sample_rate
audio_stream.layout = stream.layout
break
for frame in input_container.decode(video=0):
frame = frame.reformat(width=out_w, height=out_h, format="yuv420p")
for packet in video_stream.encode(frame):
output_container.mux(packet)
for packet in video_stream.encode():
output_container.mux(packet)
if audio_stream is not None:
input_container.seek(0)
for audio_frame in input_container.decode(audio=0):
for packet in audio_stream.encode(audio_frame):
output_container.mux(packet)
for packet in audio_stream.encode():
output_container.mux(packet)
output_container.close()
input_container.close()
output_buffer.seek(0)
return InputImpl.VideoFromFile(output_buffer)
except Exception as e:
if input_container is not None:
input_container.close()
if output_container is not None:
output_container.close()
raise RuntimeError(f"Failed to resize video: {str(e)}") from e
def _f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
if wav.dtype.is_floating_point:

View File

@ -3,136 +3,136 @@ from typing_extensions import override
import comfy.model_management
import node_helpers
from comfy_api.latest import ComfyExtension, io
from comfy_api.latest import ComfyExtension, IO
class TextEncodeAceStepAudio(io.ComfyNode):
class TextEncodeAceStepAudio(IO.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
return IO.Schema(
node_id="TextEncodeAceStepAudio",
category="conditioning",
inputs=[
io.Clip.Input("clip"),
io.String.Input("tags", multiline=True, dynamic_prompts=True),
io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
io.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01),
IO.Clip.Input("clip"),
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
IO.String.Input("lyrics", multiline=True, dynamic_prompts=True),
IO.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01),
],
outputs=[io.Conditioning.Output()],
outputs=[IO.Conditioning.Output()],
)
@classmethod
def execute(cls, clip, tags, lyrics, lyrics_strength) -> io.NodeOutput:
def execute(cls, clip, tags, lyrics, lyrics_strength) -> IO.NodeOutput:
tokens = clip.tokenize(tags, lyrics=lyrics)
conditioning = clip.encode_from_tokens_scheduled(tokens)
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
return io.NodeOutput(conditioning)
return IO.NodeOutput(conditioning)
class TextEncodeAceStepAudio15(io.ComfyNode):
class TextEncodeAceStepAudio15(IO.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
return IO.Schema(
node_id="TextEncodeAceStepAudio1.5",
category="conditioning",
inputs=[
io.Clip.Input("clip"),
io.String.Input("tags", multiline=True, dynamic_prompts=True),
io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
io.Int.Input("bpm", default=120, min=10, max=300),
io.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
io.Combo.Input("timesignature", options=['2', '3', '4', '6']),
io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
io.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
io.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
IO.Clip.Input("clip"),
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
IO.String.Input("lyrics", multiline=True, dynamic_prompts=True),
IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
IO.Int.Input("bpm", default=120, min=10, max=300),
IO.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
IO.Combo.Input("timesignature", options=['2', '3', '4', '6']),
IO.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
IO.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
IO.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
IO.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
IO.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
IO.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
IO.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
IO.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
],
outputs=[io.Conditioning.Output()],
outputs=[IO.Conditioning.Output()],
)
@classmethod
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> IO.NodeOutput:
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
conditioning = clip.encode_from_tokens_scheduled(tokens)
return io.NodeOutput(conditioning)
return IO.NodeOutput(conditioning)
class EmptyAceStepLatentAudio(io.ComfyNode):
class EmptyAceStepLatentAudio(IO.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
return IO.Schema(
node_id="EmptyAceStepLatentAudio",
display_name="Empty Ace Step 1.0 Latent Audio",
category="latent/audio",
inputs=[
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
io.Int.Input(
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
IO.Int.Input(
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
),
],
outputs=[io.Latent.Output()],
outputs=[IO.Latent.Output()],
)
@classmethod
def execute(cls, seconds, batch_size) -> io.NodeOutput:
def execute(cls, seconds, batch_size) -> IO.NodeOutput:
length = int(seconds * 44100 / 512 / 8)
latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return io.NodeOutput({"samples": latent, "type": "audio"})
return IO.NodeOutput({"samples": latent, "type": "audio"})
class EmptyAceStep15LatentAudio(io.ComfyNode):
class EmptyAceStep15LatentAudio(IO.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
return IO.Schema(
node_id="EmptyAceStep1.5LatentAudio",
display_name="Empty Ace Step 1.5 Latent Audio",
category="latent/audio",
inputs=[
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
io.Int.Input(
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
IO.Int.Input(
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
),
],
outputs=[io.Latent.Output()],
outputs=[IO.Latent.Output()],
)
@classmethod
def execute(cls, seconds, batch_size) -> io.NodeOutput:
def execute(cls, seconds, batch_size) -> IO.NodeOutput:
length = round((seconds * 48000 / 1920))
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return io.NodeOutput({"samples": latent, "type": "audio"})
return IO.NodeOutput({"samples": latent, "type": "audio"})
class ReferenceAudio(io.ComfyNode):
class ReferenceAudio(IO.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
return IO.Schema(
node_id="ReferenceTimbreAudio",
display_name="Reference Audio",
category="advanced/conditioning/audio",
is_experimental=True,
description="This node sets the reference audio for ace step 1.5",
inputs=[
io.Conditioning.Input("conditioning"),
io.Latent.Input("latent", optional=True),
IO.Conditioning.Input("conditioning"),
IO.Latent.Input("latent", optional=True),
],
outputs=[
io.Conditioning.Output(),
IO.Conditioning.Output(),
]
)
@classmethod
def execute(cls, conditioning, latent=None) -> io.NodeOutput:
def execute(cls, conditioning, latent=None) -> IO.NodeOutput:
if latent is not None:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_audio_timbre_latents": [latent["samples"]]}, append=True)
return io.NodeOutput(conditioning)
return IO.NodeOutput(conditioning)
class AceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
TextEncodeAceStepAudio,
EmptyAceStepLatentAudio,

View File

@ -104,7 +104,7 @@ def vae_decode_audio(vae, samples, tile=None, overlap=None):
std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
vae_sample_rate = getattr(vae, "audio_sample_rate_output", getattr(vae, "audio_sample_rate", 44100))
return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}

View File

@ -3,9 +3,8 @@ import comfy.utils
import comfy.model_management
import torch
from comfy.ldm.lightricks.vae.audio_vae import AudioVAE
from comfy_api.latest import ComfyExtension, io
from comfy_extras.nodes_audio import VAEEncodeAudio
class LTXVAudioVAELoader(io.ComfyNode):
@classmethod
@ -28,10 +27,14 @@ class LTXVAudioVAELoader(io.ComfyNode):
def execute(cls, ckpt_name: str) -> io.NodeOutput:
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
return io.NodeOutput(AudioVAE(sd, metadata))
sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder.", "vocoder.": "vocoder."}, filter_keys=True)
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
vae.throw_exception_if_invalid()
return io.NodeOutput(vae)
class LTXVAudioVAEEncode(io.ComfyNode):
class LTXVAudioVAEEncode(VAEEncodeAudio):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
@ -50,15 +53,8 @@ class LTXVAudioVAEEncode(io.ComfyNode):
)
@classmethod
def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput:
audio_latents = audio_vae.encode(audio)
return io.NodeOutput(
{
"samples": audio_latents,
"sample_rate": int(audio_vae.sample_rate),
"type": "audio",
}
)
def execute(cls, audio, audio_vae) -> io.NodeOutput:
return super().execute(audio_vae, audio)
class LTXVAudioVAEDecode(io.ComfyNode):
@ -80,12 +76,12 @@ class LTXVAudioVAEDecode(io.ComfyNode):
)
@classmethod
def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput:
def execute(cls, samples, audio_vae) -> io.NodeOutput:
audio_latent = samples["samples"]
if audio_latent.is_nested:
audio_latent = audio_latent.unbind()[-1]
audio = audio_vae.decode(audio_latent).to(audio_latent.device)
output_audio_sample_rate = audio_vae.output_sample_rate
audio = audio_vae.decode(audio_latent).movedim(-1, 1).to(audio_latent.device)
output_audio_sample_rate = audio_vae.first_stage_model.output_sample_rate
return io.NodeOutput(
{
"waveform": audio,
@ -143,17 +139,17 @@ class LTXVEmptyLatentAudio(io.ComfyNode):
frames_number: int,
frame_rate: int,
batch_size: int,
audio_vae: AudioVAE,
audio_vae,
) -> io.NodeOutput:
"""Generate empty audio latents matching the reference pipeline structure."""
assert audio_vae is not None, "Audio VAE model is required"
z_channels = audio_vae.latent_channels
audio_freq = audio_vae.latent_frequency_bins
sampling_rate = int(audio_vae.sample_rate)
audio_freq = audio_vae.first_stage_model.latent_frequency_bins
sampling_rate = int(audio_vae.first_stage_model.sample_rate)
num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate)
num_audio_latents = audio_vae.first_stage_model.num_of_latents_from_frames(frames_number, frame_rate)
audio_latents = torch.zeros(
(batch_size, z_channels, num_audio_latents, audio_freq),

View File

@ -1,4 +1,4 @@
comfyui-frontend-package==1.42.11
comfyui-frontend-package==1.42.14
comfyui-workflow-templates==0.9.57
comfyui-embedded-docs==0.4.3
torch

View File

@ -0,0 +1,246 @@
import pytest
from comfy_api_nodes.nodes_openai import (
OpenAIGPTImage1,
OpenAIGPTImage2,
_GPT_IMAGE_2_SIZES,
_resolve_gpt_image_2_size,
calculate_tokens_price_image_1,
calculate_tokens_price_image_1_5,
calculate_tokens_price_image_2,
)
from comfy_api_nodes.apis.openai import OpenAIImageGenerationResponse, Usage
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_response(input_tokens: int, output_tokens: int) -> OpenAIImageGenerationResponse:
return OpenAIImageGenerationResponse(
data=[],
usage=Usage(input_tokens=input_tokens, output_tokens=output_tokens),
)
# ---------------------------------------------------------------------------
# Price extractor tests
# ---------------------------------------------------------------------------
def test_price_image_1_formula():
response = _make_response(input_tokens=1_000_000, output_tokens=1_000_000)
assert calculate_tokens_price_image_1(response) == pytest.approx(50.0)
def test_price_image_1_5_formula():
response = _make_response(input_tokens=1_000_000, output_tokens=1_000_000)
assert calculate_tokens_price_image_1_5(response) == pytest.approx(40.0)
def test_price_image_2_formula():
response = _make_response(input_tokens=1_000_000, output_tokens=1_000_000)
assert calculate_tokens_price_image_2(response) == pytest.approx(38.0)
def test_price_image_2_cheaper_than_1():
response = _make_response(input_tokens=500, output_tokens=196)
assert calculate_tokens_price_image_2(response) < calculate_tokens_price_image_1(response)
def test_price_image_2_cheaper_output_than_1_5():
# gpt-image-2 output rate ($30/1M) is lower than gpt-image-1.5 ($32/1M)
response = _make_response(input_tokens=0, output_tokens=1_000_000)
assert calculate_tokens_price_image_2(response) < calculate_tokens_price_image_1_5(response)
# ---------------------------------------------------------------------------
# _resolve_gpt_image_2_size tests
# ---------------------------------------------------------------------------
def test_resolve_preset_passthrough_when_custom_zero():
# 0/0 means "use size preset"
assert _resolve_gpt_image_2_size("1024x1024", 0, 0) == "1024x1024"
assert _resolve_gpt_image_2_size("auto", 0, 0) == "auto"
assert _resolve_gpt_image_2_size("3840x2160", 0, 0) == "3840x2160"
def test_resolve_preset_passthrough_when_only_one_dim_set():
# only one dimension set → still use preset
assert _resolve_gpt_image_2_size("auto", 1024, 0) == "auto"
assert _resolve_gpt_image_2_size("auto", 0, 1024) == "auto"
def test_resolve_custom_overrides_preset():
assert _resolve_gpt_image_2_size("auto", 1024, 1024) == "1024x1024"
assert _resolve_gpt_image_2_size("1024x1024", 2048, 1152) == "2048x1152"
assert _resolve_gpt_image_2_size("auto", 3840, 2160) == "3840x2160"
def test_resolve_custom_rejects_edge_too_large():
with pytest.raises(ValueError, match="3840"):
_resolve_gpt_image_2_size("auto", 4096, 1024)
def test_resolve_custom_rejects_non_multiple_of_16():
with pytest.raises(ValueError, match="multiple of 16"):
_resolve_gpt_image_2_size("auto", 1025, 1024)
def test_resolve_custom_rejects_bad_ratio():
with pytest.raises(ValueError, match="ratio"):
_resolve_gpt_image_2_size("auto", 3840, 1024) # 3.75:1 > 3:1
def test_resolve_custom_rejects_too_few_pixels():
with pytest.raises(ValueError, match="Total pixels"):
_resolve_gpt_image_2_size("auto", 16, 16)
def test_resolve_custom_rejects_too_many_pixels():
# 3840x2176 exceeds 8,294,400
with pytest.raises(ValueError, match="Total pixels"):
_resolve_gpt_image_2_size("auto", 3840, 2176)
# ---------------------------------------------------------------------------
# OpenAIGPTImage1 schema tests
# ---------------------------------------------------------------------------
class TestOpenAIGPTImage1Schema:
def setup_method(self):
self.schema = OpenAIGPTImage1.define_schema()
def test_node_id(self):
assert self.schema.node_id == "OpenAIGPTImage1"
def test_display_name(self):
assert self.schema.display_name == "OpenAI GPT Image 1 & 1.5"
def test_model_options_exclude_gpt_image_2(self):
model_input = next(i for i in self.schema.inputs if i.name == "model")
assert "gpt-image-2" not in model_input.options
def test_model_options_include_legacy_models(self):
model_input = next(i for i in self.schema.inputs if i.name == "model")
assert "gpt-image-1" in model_input.options
assert "gpt-image-1.5" in model_input.options
def test_has_background_with_transparent(self):
bg_input = next(i for i in self.schema.inputs if i.name == "background")
assert "transparent" in bg_input.options
# ---------------------------------------------------------------------------
# OpenAIGPTImage2 schema tests
# ---------------------------------------------------------------------------
class TestOpenAIGPTImage2Schema:
def setup_method(self):
self.schema = OpenAIGPTImage2.define_schema()
def test_node_id(self):
assert self.schema.node_id == "OpenAIGPTImage2"
def test_display_name(self):
assert self.schema.display_name == "OpenAI GPT Image 2"
def test_category(self):
assert "OpenAI" in self.schema.category
def test_no_transparent_background(self):
bg_input = next(i for i in self.schema.inputs if i.name == "background")
assert "transparent" not in bg_input.options
def test_background_options(self):
bg_input = next(i for i in self.schema.inputs if i.name == "background")
assert set(bg_input.options) == {"auto", "opaque"}
def test_quality_options(self):
quality_input = next(i for i in self.schema.inputs if i.name == "quality")
assert set(quality_input.options) == {"auto", "low", "medium", "high"}
def test_quality_default_is_auto(self):
quality_input = next(i for i in self.schema.inputs if i.name == "quality")
assert quality_input.default == "auto"
def test_all_popular_sizes_present(self):
size_input = next(i for i in self.schema.inputs if i.name == "size")
for size in ["1024x1024", "1536x1024", "1024x1536", "2048x2048", "2048x1152", "3840x2160", "2160x3840"]:
assert size in size_input.options, f"Missing size: {size}"
def test_no_custom_size_option(self):
size_input = next(i for i in self.schema.inputs if i.name == "size")
assert "custom" not in size_input.options
def test_size_default_is_auto(self):
size_input = next(i for i in self.schema.inputs if i.name == "size")
assert size_input.default == "auto"
def test_custom_width_and_height_inputs_exist(self):
input_names = [i.name for i in self.schema.inputs]
assert "custom_width" in input_names
assert "custom_height" in input_names
def test_custom_width_height_default_zero(self):
width_input = next(i for i in self.schema.inputs if i.name == "custom_width")
height_input = next(i for i in self.schema.inputs if i.name == "custom_height")
assert width_input.default == 0
assert height_input.default == 0
def test_custom_width_height_step_is_16(self):
width_input = next(i for i in self.schema.inputs if i.name == "custom_width")
height_input = next(i for i in self.schema.inputs if i.name == "custom_height")
assert width_input.step == 16
assert height_input.step == 16
def test_custom_width_height_max_is_3840(self):
width_input = next(i for i in self.schema.inputs if i.name == "custom_width")
height_input = next(i for i in self.schema.inputs if i.name == "custom_height")
assert width_input.max == 3840
assert height_input.max == 3840
def test_uses_num_images_not_n(self):
input_names = [i.name for i in self.schema.inputs]
assert "num_images" in input_names
assert "n" not in input_names
def test_model_input_shows_gpt_image_2(self):
model_input = next(i for i in self.schema.inputs if i.name == "model")
assert model_input.options == ["gpt-image-2"]
assert model_input.default == "gpt-image-2"
def test_has_image_and_mask_inputs(self):
input_names = [i.name for i in self.schema.inputs]
assert "image" in input_names
assert "mask" in input_names
def test_is_api_node(self):
assert self.schema.is_api_node is True
def test_sizes_match_constant(self):
size_input = next(i for i in self.schema.inputs if i.name == "size")
assert size_input.options == _GPT_IMAGE_2_SIZES
# ---------------------------------------------------------------------------
# OpenAIGPTImage2 execute validation tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_execute_raises_on_empty_prompt():
with pytest.raises(Exception):
await OpenAIGPTImage2.execute(prompt=" ")
@pytest.mark.asyncio
async def test_execute_raises_mask_without_image():
import torch
mask = torch.ones(1, 64, 64)
with pytest.raises(ValueError, match="mask without an input image"):
await OpenAIGPTImage2.execute(prompt="test", mask=mask)
@pytest.mark.asyncio
async def test_execute_raises_invalid_custom_size():
with pytest.raises(ValueError):
await OpenAIGPTImage2.execute(prompt="test", custom_width=4096, custom_height=1024)