mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-17 04:07:31 +08:00
Compare commits
1 Commits
cloud-open
...
ListInput
| Author | SHA1 | Date | |
|---|---|---|---|
| a3b9cf837d |
@ -364,7 +364,7 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
|
||||
| Flag | Description |
|
||||
|------|-------------|
|
||||
| `--enable-manager` | Enable ComfyUI-Manager |
|
||||
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (implies `--enable-manager`) |
|
||||
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
|
||||
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
|
||||
|
||||
|
||||
@ -382,7 +382,11 @@ For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 pyt
|
||||
|
||||
### AMD ROCm Tips
|
||||
|
||||
You can try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
|
||||
You can enable experimental memory efficient attention on recent pytorch in ComfyUI on some AMD GPUs using this command, it should already be enabled by default on RDNA3. If this improves speed for you on latest pytorch on your GPU please report it so that I can enable it by default.
|
||||
|
||||
```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
|
||||
|
||||
You can also try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
|
||||
|
||||
# Notes
|
||||
|
||||
|
||||
@ -115,7 +115,6 @@ cache_group.add_argument("--cache-ram", nargs='*', type=float, default=[], metav
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||
cache_group.add_argument("--high-ram", action="store_true", help="Can improve performance slightly on high RAM or on systems where pagefile use is preferred over model loading.")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
@ -134,7 +133,7 @@ upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disabl
|
||||
parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
|
||||
manager_group = parser.add_mutually_exclusive_group()
|
||||
manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
|
||||
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager. Implies --enable-manager.")
|
||||
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager")
|
||||
|
||||
|
||||
vram_group = parser.add_mutually_exclusive_group()
|
||||
@ -145,7 +144,6 @@ vram_group.add_argument("--novram", action="store_true", help="When lowvram isn'
|
||||
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
||||
|
||||
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
|
||||
parser.add_argument("--vram-headroom", type=float, default=0, help="Set the amount of vram in GB for DynamicVRAM to maintain as extra headroom above default. ComfyUI will try and keep this much VRAM completely free and unused, even counting VRAM from other apps.")
|
||||
|
||||
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||
@ -251,9 +249,6 @@ else:
|
||||
if args.cache_ram is not None and len(args.cache_ram) > 2:
|
||||
parser.error("--cache-ram accepts at most two values: active GB and inactive GB")
|
||||
|
||||
if args.high_ram:
|
||||
args.cache_classic = True
|
||||
|
||||
if args.windows_standalone_build:
|
||||
args.auto_launch = True
|
||||
|
||||
@ -263,10 +258,6 @@ if args.disable_auto_launch:
|
||||
if args.force_fp16:
|
||||
args.fp16_unet = True
|
||||
|
||||
# '--enable-manager-legacy-ui' is meaningless unless the manager is enabled, so imply '--enable-manager'.
|
||||
if args.enable_manager_legacy_ui:
|
||||
args.enable_manager = True
|
||||
|
||||
|
||||
# '--fast' is not provided, use an empty set
|
||||
if args.fast is None:
|
||||
|
||||
@ -106,11 +106,11 @@ class Ideogram4EmbedScalar(nn.Module):
|
||||
self.mlp_in = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
|
||||
self.mlp_out = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, dtype):
|
||||
def forward(self, x):
|
||||
x = x.to(torch.float32)
|
||||
scaled = 1e4 * (x - self.range_min) / (self.range_max - self.range_min)
|
||||
emb = _sinusoidal_embedding(scaled, self.dim)
|
||||
emb = emb.to(dtype)
|
||||
emb = emb.to(self.mlp_in.weight.dtype)
|
||||
emb = F.silu(self.mlp_in(emb))
|
||||
return self.mlp_out(emb)
|
||||
|
||||
@ -161,7 +161,7 @@ class Ideogram4Transformer(nn.Module):
|
||||
x = x * output_image_mask
|
||||
h = self.input_proj(x) * output_image_mask
|
||||
|
||||
t_cond = self.t_embedding(t, dtype=x.dtype)
|
||||
t_cond = self.t_embedding(t)
|
||||
if t.dim() == 1:
|
||||
t_cond = t_cond.unsqueeze(1)
|
||||
adaln_input = F.silu(self.adaln_proj(t_cond))
|
||||
|
||||
@ -8,7 +8,6 @@ import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from comfy.ldm.lightricks.model import Timesteps
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
@ -18,7 +17,9 @@ def apply_rotary_emb(x, freqs_cis):
|
||||
if x.shape[1] == 0:
|
||||
return x
|
||||
|
||||
return apply_rope1(x, freqs_cis)
|
||||
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
|
||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||
return t_out.reshape(*x.shape).to(dtype=x.dtype)
|
||||
|
||||
|
||||
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@ -643,8 +643,6 @@ def free_pins(size, evict_active=False):
|
||||
return freed_total
|
||||
|
||||
def ensure_pin_budget(size, evict_active=False):
|
||||
if args.high_ram:
|
||||
return True
|
||||
if args.fast_disk:
|
||||
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
||||
else:
|
||||
@ -1498,8 +1496,6 @@ if not args.disable_pinned_memory:
|
||||
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
|
||||
|
||||
def pinned_hostbuf_size(size):
|
||||
if args.high_ram:
|
||||
return max(0, int(size * 2))
|
||||
return max(0, int(min(size, MAX_PINNED_MEMORY) * 2))
|
||||
|
||||
def discard_cuda_async_error():
|
||||
|
||||
@ -180,7 +180,7 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
if pin is not None:
|
||||
cast_maybe_lowvram_patch([pin], dest, offload_stream)
|
||||
return
|
||||
if signature is None or args.high_ram:
|
||||
if signature is None:
|
||||
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
|
||||
pin = comfy.pinned_memory.get_pin(m, subset=subset)
|
||||
cast_maybe_lowvram_patch(source, pin, offload_stream, xfer_dest2=dest)
|
||||
|
||||
@ -27,13 +27,10 @@ class VideoInput(ABC):
|
||||
path: Union[str, IO[bytes]],
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None,
|
||||
bit_depth: int | None = None,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
"""
|
||||
Abstract method to save the video input to a file.
|
||||
|
||||
bit_depth selects the encoded bit depth; None keeps the video's native depth.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -86,14 +83,6 @@ class VideoInput(ABC):
|
||||
components = self.get_components()
|
||||
return components.images.shape[2], components.images.shape[1]
|
||||
|
||||
def get_bit_depth(self) -> int:
|
||||
"""
|
||||
Returns the bit depth of the video (e.g. 8 or 10).
|
||||
|
||||
Default implementation returns 8; subclasses report their real depth.
|
||||
"""
|
||||
return 8
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
|
||||
@ -52,12 +52,6 @@ def get_open_write_kwargs(
|
||||
return open_kwargs
|
||||
|
||||
|
||||
def video_stream_bit_depth(stream) -> int:
|
||||
if stream is None or stream.format is None or not stream.format.components:
|
||||
return 8
|
||||
return max(component.bits for component in stream.format.components)
|
||||
|
||||
|
||||
class VideoFromFile(VideoInput):
|
||||
"""
|
||||
Class representing video input from a file.
|
||||
@ -103,13 +97,6 @@ class VideoFromFile(VideoInput):
|
||||
return stream.width, stream.height
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def get_bit_depth(self) -> int:
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None
|
||||
return video_stream_bit_depth(video_stream)
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
@ -270,7 +257,6 @@ class VideoFromFile(VideoInput):
|
||||
|
||||
image_format = 'gbrpf32le'
|
||||
process_image_format = lambda a: a
|
||||
align_graph = None
|
||||
audio = None
|
||||
|
||||
streams = [video_stream]
|
||||
@ -324,28 +310,7 @@ class VideoFromFile(VideoInput):
|
||||
|
||||
checked_alpha = True
|
||||
|
||||
# Fix non-deterministic video decode when the video width is not a multiple of 32
|
||||
# For non-yuvj pixel formats: most H.264/H.265 video and static images (e.g. lossy WebP via LoadImage)
|
||||
# Pad both axes to a multiple of 32 and smear the border so the alignment padding never bleeds into the cropped edges
|
||||
if image_format in ('gbrpf32le', 'gbrapf32le') and frame.width % 32 != 0:
|
||||
if align_graph is None:
|
||||
pad_w = ((frame.width + 31) // 32) * 32
|
||||
pad_h = ((frame.height + 31) // 32) * 32
|
||||
g = av.filter.Graph()
|
||||
g_src = g.add_buffer(width=frame.width, height=frame.height,
|
||||
format=frame.format.name, time_base=video_stream.time_base)
|
||||
g_pad = g.add('pad', f'{pad_w}:{pad_h}:0:0')
|
||||
g_fill = g.add('fillborders', f'left=0:right={pad_w - frame.width}:top=0:bottom={pad_h - frame.height}:mode=smear')
|
||||
g_sink = g.add('buffersink')
|
||||
g_src.link_to(g_pad)
|
||||
g_pad.link_to(g_fill)
|
||||
g_fill.link_to(g_sink)
|
||||
g.configure()
|
||||
align_graph = (g, g_src, g_sink)
|
||||
align_graph[1].push(frame)
|
||||
img = np.ascontiguousarray(align_graph[2].pull().to_ndarray(format=image_format)[:frame.height, :frame.width])
|
||||
else:
|
||||
img = frame.to_ndarray(format=image_format)
|
||||
img = frame.to_ndarray(format=image_format) # shape: (H, W, 4)
|
||||
if frame.rotation != 0:
|
||||
k = int(round(frame.rotation // 90))
|
||||
img = np.rot90(img, k=k, axes=(0, 1)).copy()
|
||||
@ -412,32 +377,25 @@ class VideoFromFile(VideoInput):
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None,
|
||||
bit_depth: int | None = None,
|
||||
):
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
container_format = container.format.name
|
||||
video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None
|
||||
video_encoding = video_stream.codec.name if video_stream is not None else None
|
||||
source_bit_depth = video_stream_bit_depth(video_stream)
|
||||
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
|
||||
reuse_streams = True
|
||||
if format != VideoContainer.AUTO and format not in container_format.split(","):
|
||||
reuse_streams = False
|
||||
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
||||
reuse_streams = False
|
||||
if bit_depth is not None and video_encoding is not None and bit_depth != source_bit_depth:
|
||||
reuse_streams = False
|
||||
if self.__start_time or self.__duration:
|
||||
reuse_streams = False
|
||||
|
||||
if not reuse_streams:
|
||||
if bit_depth is None:
|
||||
bit_depth = source_bit_depth
|
||||
components = self.get_components_internal(container)
|
||||
video = VideoFromComponents(components)
|
||||
return video.save_to(
|
||||
path, format=format, codec=codec, metadata=metadata, bit_depth=bit_depth,
|
||||
path, format=format, codec=codec, metadata=metadata
|
||||
)
|
||||
|
||||
streams = container.streams
|
||||
@ -493,10 +451,8 @@ class VideoFromComponents(VideoInput):
|
||||
Class representing video input from tensors.
|
||||
"""
|
||||
|
||||
def __init__(self, components: VideoComponents, bit_depth: int = 8):
|
||||
def __init__(self, components: VideoComponents):
|
||||
self.__components = components
|
||||
# Tensor components have no inherent bit depth; this is the depth used when encoding.
|
||||
self.__bit_depth = bit_depth
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
return VideoComponents(
|
||||
@ -505,26 +461,18 @@ class VideoFromComponents(VideoInput):
|
||||
frame_rate=self.__components.frame_rate,
|
||||
)
|
||||
|
||||
def get_bit_depth(self) -> int:
|
||||
return self.__bit_depth
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None,
|
||||
bit_depth: int | None = None,
|
||||
):
|
||||
"""Save the video to a file path or BytesIO buffer."""
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||
raise ValueError("Only H264 codec is supported for now")
|
||||
# None means "use the depth this video was created with" (CreateVideo's choice).
|
||||
if bit_depth is None:
|
||||
bit_depth = self.__bit_depth
|
||||
is_10bit = bit_depth >= 10
|
||||
extra_kwargs = {}
|
||||
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
|
||||
extra_kwargs["format"] = format.value
|
||||
@ -540,11 +488,10 @@ class VideoFromComponents(VideoInput):
|
||||
|
||||
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
||||
# Create a video stream
|
||||
pix_fmt = "yuv420p10le" if is_10bit else "yuv420p"
|
||||
video_stream = output.add_stream('h264', rate=frame_rate)
|
||||
video_stream.width = self.__components.images.shape[2]
|
||||
video_stream.height = self.__components.images.shape[1]
|
||||
video_stream.pix_fmt = pix_fmt
|
||||
video_stream.pix_fmt = 'yuv420p'
|
||||
|
||||
# Create an audio stream
|
||||
audio_sample_rate = 1
|
||||
@ -558,14 +505,9 @@ class VideoFromComponents(VideoInput):
|
||||
|
||||
# Encode video
|
||||
for i, frame in enumerate(self.__components.images):
|
||||
if is_10bit:
|
||||
# 16-bit RGB keeps float precision through the conversion to 10-bit YUV.
|
||||
img = (frame.float() * 65535).clamp(0, 65535).cpu().numpy().astype(np.uint16) # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format="rgb48le")
|
||||
else:
|
||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame = frame.reformat(format=pix_fmt)
|
||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
|
||||
packet = video_stream.encode(frame)
|
||||
output.mux(packet)
|
||||
|
||||
|
||||
@ -1253,6 +1253,140 @@ class DynamicSlot(ComfyTypeI):
|
||||
out_dict[input_type][finalized_id] = value
|
||||
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
|
||||
|
||||
@comfytype(io_type="COMFY_LIST_V3")
|
||||
class List(ComfyTypeI):
|
||||
"""A repeatable group of widget inputs (e.g. lora_name + strength stacked into N rows).
|
||||
|
||||
At execution time the node receives a ``list[dict]`` where each element is a row.
|
||||
|
||||
Example::
|
||||
|
||||
io.List.Input(
|
||||
"loras",
|
||||
template=[
|
||||
io.Combo.Input("lora_name", options=folder_paths.get_filename_list("loras")),
|
||||
io.Float.Input("strength", default=1.0, min=-100, max=100, step=0.01),
|
||||
],
|
||||
min=0,
|
||||
max=50,
|
||||
)
|
||||
# execute receives: loras: list[dict] = [{"lora_name": "x.safetensors", "strength": 1.0}, ...]
|
||||
"""
|
||||
|
||||
Type = list[dict[str, Any]]
|
||||
_MaxRows = 100
|
||||
|
||||
class Input(DynamicInput):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
template: list["Input"],
|
||||
min: int = 0,
|
||||
max: int = 50,
|
||||
display_name: str = None,
|
||||
optional: bool = False,
|
||||
tooltip: str = None,
|
||||
lazy: bool = None,
|
||||
extra_dict=None,
|
||||
):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
# Validate template entries: only WidgetInput subclasses, no nesting
|
||||
assert len(template) > 0, "List template must have at least one field."
|
||||
for t in template:
|
||||
assert isinstance(t, WidgetInput), (
|
||||
f"List template field '{t.id}' must be a WidgetInput subclass "
|
||||
f"(Combo, Float, Int, String, Boolean, Color). Got {type(t).__name__}."
|
||||
)
|
||||
assert not isinstance(t, DynamicInput), (
|
||||
f"List template field '{t.id}' must not be a DynamicInput. "
|
||||
"Nesting dynamic inputs inside List is not supported."
|
||||
)
|
||||
# Enforce unique field ids within template
|
||||
field_ids = [t.id for t in template]
|
||||
assert len(field_ids) == len(set(field_ids)), (
|
||||
f"List template field ids must be unique within a row. Got: {field_ids}"
|
||||
)
|
||||
assert min >= 0, "List min must be >= 0."
|
||||
assert max >= 1, "List max must be >= 1."
|
||||
assert max <= List._MaxRows, f"List max must be <= {List._MaxRows}."
|
||||
assert min <= max, "List min must be <= max."
|
||||
self.template = template
|
||||
self.min = min
|
||||
self.max = max
|
||||
|
||||
def get_all(self) -> list["Input"]:
|
||||
return [self] + list(self.template)
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"template": create_input_dict_v1(self.template),
|
||||
"min": self.min,
|
||||
"max": self.max,
|
||||
})
|
||||
|
||||
def validate(self):
|
||||
for t in self.template:
|
||||
t.validate()
|
||||
|
||||
@staticmethod
|
||||
def _expand_schema_for_dynamic(
|
||||
out_dict: dict[str, Any],
|
||||
live_inputs: dict[str, Any],
|
||||
value: tuple[str, dict[str, Any]],
|
||||
input_type: str,
|
||||
curr_prefix: list[str] | None,
|
||||
):
|
||||
info = value[1]
|
||||
min_rows: int = info.get("min", 0)
|
||||
template: dict[str, Any] = info.get("template", {})
|
||||
|
||||
# Collect all template field specs across required/optional sections
|
||||
field_specs: list[tuple[str, tuple[str, dict[str, Any]], bool]] = []
|
||||
for field_required_key in ("required", "optional"):
|
||||
section = template.get(field_required_key, {})
|
||||
is_required_field = field_required_key == "required"
|
||||
for field_id, field_value in section.items():
|
||||
field_specs.append((field_id, field_value, is_required_field))
|
||||
|
||||
# Determine how many rows are currently present by scanning live_inputs
|
||||
finalized_prefix = finalize_prefix(curr_prefix)
|
||||
present_rows = 0
|
||||
for live_key in live_inputs:
|
||||
# Keys look like "<prefix>.<row>.<field_id>"
|
||||
if live_key.startswith(finalized_prefix + "."):
|
||||
remainder = live_key[len(finalized_prefix) + 1:]
|
||||
parts = remainder.split(".", 1)
|
||||
if len(parts) >= 1:
|
||||
try:
|
||||
row_idx = int(parts[0])
|
||||
present_rows = max(present_rows, row_idx + 1)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
row_count = max(min_rows, present_rows)
|
||||
|
||||
for row in range(row_count):
|
||||
for field_id, field_value, is_required_field in field_specs:
|
||||
slot_id = f"{finalized_prefix}.{row}.{field_id}"
|
||||
# The first `min_rows` rows are required if the field itself is required
|
||||
if row < min_rows and is_required_field:
|
||||
out_dict["required"][slot_id] = field_value
|
||||
else:
|
||||
out_dict["optional"][slot_id] = field_value
|
||||
# Register into dynamic_paths so build_nested_inputs places value at the right path
|
||||
out_dict["dynamic_paths"][slot_id] = slot_id
|
||||
|
||||
# Track the list root path so build_nested_inputs can convert the index dict to a list
|
||||
out_dict.setdefault("list_paths", set()).add(finalized_prefix)
|
||||
|
||||
# Handle the empty case (0 rows) – emit an empty-list default for the parent.
|
||||
# This must only fire when there are genuinely no rows; otherwise the parent
|
||||
# path would clobber the per-row dict built from the slot ids above.
|
||||
if row_count == 0:
|
||||
out_dict["dynamic_paths"][finalized_prefix] = finalized_prefix
|
||||
out_dict["dynamic_paths_default_value"][finalized_prefix] = DynamicPathsDefaultValue.EMPTY_LIST
|
||||
|
||||
|
||||
@comfytype(io_type="IMAGECOMPARE")
|
||||
class ImageCompare(ComfyTypeI):
|
||||
Type = dict
|
||||
@ -1383,6 +1517,8 @@ def setup_dynamic_input_funcs():
|
||||
register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic)
|
||||
# DynamicSlot.Input
|
||||
register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic)
|
||||
# List.Input
|
||||
register_dynamic_input_func(List.io_type, List._expand_schema_for_dynamic)
|
||||
|
||||
if len(DYNAMIC_INPUT_LOOKUP) == 0:
|
||||
setup_dynamic_input_funcs()
|
||||
@ -1394,14 +1530,15 @@ class V3Data(TypedDict):
|
||||
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
|
||||
dynamic_paths_default_value: dict[str, Any]
|
||||
'Dictionary where the keys are the input ids and the values are a string from DynamicPathsDefaultValue for the inputs if value is None.'
|
||||
list_paths: set[str]
|
||||
'Set of top-level keys whose index-keyed dict values should be converted to a sorted list[dict] after build_nested_inputs runs.'
|
||||
create_dynamic_tuple: bool
|
||||
'When True, the value of the dynamic input will be in the format (value, path_key).'
|
||||
|
||||
class HiddenHolder:
|
||||
def __init__(self, unique_id: str, prompt: Any,
|
||||
extra_pnginfo: Any, dynprompt: Any,
|
||||
auth_token_comfy_org: str, api_key_comfy_org: str,
|
||||
comfy_usage_source: str = None, **kwargs):
|
||||
auth_token_comfy_org: str, api_key_comfy_org: str, **kwargs):
|
||||
self.unique_id = unique_id
|
||||
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
||||
self.prompt = prompt
|
||||
@ -1414,8 +1551,6 @@ class HiddenHolder:
|
||||
"""AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend."""
|
||||
self.api_key_comfy_org = api_key_comfy_org
|
||||
"""API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
|
||||
self.comfy_usage_source = comfy_usage_source
|
||||
"""COMFY_USAGE_SOURCE identifies the client that submitted the prompt (e.g. comfyui-frontend, comfy-cli, comfyui-mcp); forwarded to API nodes' upstream requests via the Comfy-Usage-Source header."""
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
'''If hidden variable not found, return None.'''
|
||||
@ -1432,7 +1567,6 @@ class HiddenHolder:
|
||||
dynprompt=d.get(Hidden.dynprompt, None),
|
||||
auth_token_comfy_org=d.get(Hidden.auth_token_comfy_org, None),
|
||||
api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None),
|
||||
comfy_usage_source=d.get(Hidden.comfy_usage_source, None),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -1455,8 +1589,6 @@ class Hidden(str, Enum):
|
||||
"""AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend."""
|
||||
api_key_comfy_org = "API_KEY_COMFY_ORG"
|
||||
"""API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
|
||||
comfy_usage_source = "COMFY_USAGE_SOURCE"
|
||||
"""COMFY_USAGE_SOURCE identifies the client that submitted the prompt (e.g. comfyui-frontend, comfy-cli, comfyui-mcp); forwarded to API nodes' upstream requests via the Comfy-Usage-Source header."""
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -1660,8 +1792,6 @@ class Schema:
|
||||
self.hidden.append(Hidden.auth_token_comfy_org)
|
||||
if Hidden.api_key_comfy_org not in self.hidden:
|
||||
self.hidden.append(Hidden.api_key_comfy_org)
|
||||
if Hidden.comfy_usage_source not in self.hidden:
|
||||
self.hidden.append(Hidden.comfy_usage_source)
|
||||
# if is an output_node, will need prompt and extra_pnginfo
|
||||
if self.is_output_node:
|
||||
if Hidden.prompt not in self.hidden:
|
||||
@ -1735,6 +1865,7 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
|
||||
"optional": {},
|
||||
"dynamic_paths": {},
|
||||
"dynamic_paths_default_value": {},
|
||||
"list_paths": set(),
|
||||
}
|
||||
d = d.copy()
|
||||
# ignore hidden for parsing
|
||||
@ -1750,6 +1881,10 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
|
||||
dynamic_paths_default_value = out_dict.pop("dynamic_paths_default_value", None)
|
||||
if dynamic_paths_default_value is not None and len(dynamic_paths_default_value) > 0:
|
||||
v3_data["dynamic_paths_default_value"] = dynamic_paths_default_value
|
||||
# list_paths: keys whose nested dict should be post-converted to a sorted list[dict]
|
||||
list_paths = out_dict.pop("list_paths", None)
|
||||
if list_paths:
|
||||
v3_data["list_paths"] = list_paths
|
||||
return out_dict, hidden, v3_data
|
||||
|
||||
def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
|
||||
@ -1785,10 +1920,12 @@ def add_to_dict_v1(i: Input, d: dict):
|
||||
|
||||
class DynamicPathsDefaultValue:
|
||||
EMPTY_DICT = "empty_dict"
|
||||
EMPTY_LIST = "empty_list"
|
||||
|
||||
def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
paths = v3_data.get("dynamic_paths", None)
|
||||
default_value_dict = v3_data.get("dynamic_paths_default_value", {})
|
||||
list_paths: set[str] = v3_data.get("list_paths", set()) or set()
|
||||
if paths is None:
|
||||
return values
|
||||
values = values.copy()
|
||||
@ -1811,6 +1948,8 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
default_option = default_value_dict.get(key, None)
|
||||
if default_option == DynamicPathsDefaultValue.EMPTY_DICT:
|
||||
value = {}
|
||||
elif default_option == DynamicPathsDefaultValue.EMPTY_LIST:
|
||||
value = []
|
||||
if create_tuple:
|
||||
value = (value, key)
|
||||
current[p] = value
|
||||
@ -1818,6 +1957,34 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
current = current.setdefault(p, {})
|
||||
|
||||
values.update(result)
|
||||
|
||||
# Post-pass: convert index-keyed dicts to sorted lists for io.List fields
|
||||
for list_path in list_paths:
|
||||
parts = list_path.split(".")
|
||||
# Navigate to the parent container, then convert the leaf
|
||||
container = values
|
||||
for part in parts[:-1]:
|
||||
if not isinstance(container, dict) or part not in container:
|
||||
container = None
|
||||
break
|
||||
container = container[part]
|
||||
if container is None:
|
||||
continue
|
||||
leaf_key = parts[-1]
|
||||
leaf = container.get(leaf_key, None)
|
||||
if isinstance(leaf, dict):
|
||||
try:
|
||||
sorted_rows = [leaf[k] for k in sorted(leaf.keys(), key=int)]
|
||||
container[leaf_key] = sorted_rows
|
||||
except (ValueError, TypeError):
|
||||
# Keys are not all integers; leave as-is
|
||||
pass
|
||||
elif isinstance(leaf, list):
|
||||
# Already a list (e.g. the EMPTY_LIST default was applied above)
|
||||
pass
|
||||
elif leaf is None:
|
||||
container[leaf_key] = []
|
||||
|
||||
return values
|
||||
|
||||
|
||||
@ -2380,7 +2547,9 @@ __all__ = [
|
||||
# Dynamic Types
|
||||
"MatchType",
|
||||
"DynamicCombo",
|
||||
"DynamicSlot",
|
||||
"Autogrow",
|
||||
"List",
|
||||
# Other classes
|
||||
"HiddenHolder",
|
||||
"Hidden",
|
||||
|
||||
9
comfy_api_nodes/apis/__init__.py
generated
9
comfy_api_nodes/apis/__init__.py
generated
@ -1310,6 +1310,13 @@ class KlingTaskStatus(str, Enum):
|
||||
failed = 'failed'
|
||||
|
||||
|
||||
class KlingTextToVideoModelName(str, Enum):
|
||||
kling_v1 = 'kling-v1'
|
||||
kling_v1_6 = 'kling-v1-6'
|
||||
kling_v2_1_master = 'kling-v2-1-master'
|
||||
kling_v2_5_turbo = 'kling-v2-5-turbo'
|
||||
|
||||
|
||||
class KlingVideoGenAspectRatio(str, Enum):
|
||||
field_16_9 = '16:9'
|
||||
field_9_16 = '9:16'
|
||||
@ -5172,7 +5179,7 @@ class KlingText2VideoRequest(BaseModel):
|
||||
duration: Optional[KlingVideoGenDuration] = '5'
|
||||
external_task_id: Optional[str] = Field(None, description='Customized Task ID')
|
||||
mode: Optional[KlingVideoGenMode] = 'std'
|
||||
model_name: Optional[str] = 'kling-v1'
|
||||
model_name: Optional[KlingTextToVideoModelName] = 'kling-v1'
|
||||
negative_prompt: Optional[str] = Field(
|
||||
None, description='Negative text prompt', max_length=2500
|
||||
)
|
||||
|
||||
@ -67,6 +67,15 @@ class RunwayImageToVideoResponse(BaseModel):
|
||||
id: Optional[str] = Field(None, description='Task ID')
|
||||
|
||||
|
||||
class RunwayTaskStatusEnum(str, Enum):
|
||||
SUCCEEDED = 'SUCCEEDED'
|
||||
RUNNING = 'RUNNING'
|
||||
FAILED = 'FAILED'
|
||||
PENDING = 'PENDING'
|
||||
CANCELLED = 'CANCELLED'
|
||||
THROTTLED = 'THROTTLED'
|
||||
|
||||
|
||||
class RunwayTaskStatusResponse(BaseModel):
|
||||
createdAt: datetime = Field(..., description='Task creation timestamp')
|
||||
id: str = Field(..., description='Task ID')
|
||||
@ -77,7 +86,7 @@ class RunwayTaskStatusResponse(BaseModel):
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
status: str = Field(..., description="SUCCEEDED, RUNNING, FAILED, PENDING, CANCELLED or THROTTLED")
|
||||
status: RunwayTaskStatusEnum
|
||||
|
||||
|
||||
class Model4(str, Enum):
|
||||
@ -116,144 +125,3 @@ class RunwayTextToImageRequest(BaseModel):
|
||||
|
||||
class RunwayTextToImageResponse(BaseModel):
|
||||
id: Optional[str] = Field(None, description='Task ID')
|
||||
|
||||
|
||||
class RunwayAleph2IO:
|
||||
"""Custom socket types for chaining Aleph2 guidance images."""
|
||||
|
||||
KEYFRAME = "RUNWAY_ALEPH2_KEYFRAME"
|
||||
PROMPT_IMAGE = "RUNWAY_ALEPH2_PROMPT_IMAGE"
|
||||
|
||||
|
||||
# Keyframe timing modes (anchored to the INPUT video). Stored on the chain item and used to
|
||||
# choose the request model below. The values match the Aleph2 keyframe union field names.
|
||||
KEYFRAME_MODE_SECONDS = "seconds" # absolute time, in seconds, from the start of the input video
|
||||
KEYFRAME_MODE_AT = "at" # fraction [0.0, 1.0] of the input video duration
|
||||
|
||||
# Prompt-image position modes (anchored to the OUTPUT video). Values match the Aleph2 position `type`.
|
||||
PROMPT_IMAGE_MODE_TIMESTAMP = "timestamp" # absolute time, in seconds, from the start of the output video
|
||||
PROMPT_IMAGE_MODE_POSITION = "position" # fraction [0.0, 1.0] of the output video duration
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeItem:
|
||||
"""A guidance image anchored to a point of the INPUT video (one Aleph2 ``keyframe``)."""
|
||||
|
||||
def __init__(self, image, mode: str, value: float):
|
||||
self.image = image
|
||||
self.mode = mode # KEYFRAME_MODE_SECONDS | KEYFRAME_MODE_AT
|
||||
self.value = value
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeChain:
|
||||
"""An ordered collection of keyframes, built by chaining Runway Aleph2 Keyframe nodes."""
|
||||
|
||||
def __init__(self):
|
||||
self.items: list[RunwayAleph2KeyframeItem] = []
|
||||
|
||||
def add(self, item: RunwayAleph2KeyframeItem) -> None:
|
||||
self.items.append(item)
|
||||
|
||||
def clone(self) -> "RunwayAleph2KeyframeChain":
|
||||
c = RunwayAleph2KeyframeChain()
|
||||
c.items = list(self.items)
|
||||
return c
|
||||
|
||||
|
||||
class RunwayAleph2PromptImageItem:
|
||||
"""A guidance image anchored to a point of the OUTPUT video (one Aleph2 ``promptImage``)."""
|
||||
|
||||
def __init__(self, image, mode: str, value: float):
|
||||
self.image = image
|
||||
self.mode = mode # PROMPT_IMAGE_MODE_TIMESTAMP | PROMPT_IMAGE_MODE_POSITION
|
||||
self.value = value
|
||||
|
||||
|
||||
class RunwayAleph2PromptImageChain:
|
||||
"""An ordered collection of prompt images, built by chaining Runway Aleph2 Prompt Image nodes."""
|
||||
|
||||
def __init__(self):
|
||||
self.items: list[RunwayAleph2PromptImageItem] = []
|
||||
|
||||
def add(self, item: RunwayAleph2PromptImageItem) -> None:
|
||||
self.items.append(item)
|
||||
|
||||
def clone(self) -> "RunwayAleph2PromptImageChain":
|
||||
c = RunwayAleph2PromptImageChain()
|
||||
c.items = list(self.items)
|
||||
return c
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeSeconds(BaseModel):
|
||||
seconds: float = Field(
|
||||
...,
|
||||
description="Absolute timestamp in seconds from the start of the input video when this guidance image should apply.",
|
||||
ge=0.0,
|
||||
)
|
||||
uri: str = Field(...)
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeAt(BaseModel):
|
||||
at: float = Field(
|
||||
...,
|
||||
description="Position as a fraction [0.0, 1.0] of the input video duration.",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
uri: str = Field(...)
|
||||
|
||||
|
||||
class RunwayAleph2TimestampPosition(BaseModel):
|
||||
type: str = Field(default="timestamp")
|
||||
timestampSeconds: float = Field(
|
||||
...,
|
||||
description="Absolute timestamp in seconds from the start of the output video.",
|
||||
ge=0.0,
|
||||
)
|
||||
|
||||
|
||||
class RunwayAleph2RelativePosition(BaseModel):
|
||||
type: str = Field(default="position")
|
||||
positionPercentage: float = Field(
|
||||
...,
|
||||
description="Position as a fraction [0.0, 1.0] of the total output video duration.",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
|
||||
|
||||
class RunwayAleph2PromptImage(BaseModel):
|
||||
position: RunwayAleph2TimestampPosition | RunwayAleph2RelativePosition
|
||||
uri: str = Field(...)
|
||||
|
||||
|
||||
class RunwayAleph2ContentModeration(BaseModel):
|
||||
publicFigureThreshold: str = Field(
|
||||
...,
|
||||
description='When set to "low", the content moderation system is less strict about '
|
||||
'recognizable public figures. One of "auto" or "low".',
|
||||
)
|
||||
|
||||
|
||||
class RunwayAleph2Request(BaseModel):
|
||||
model: str = Field(default="aleph2")
|
||||
promptText: str = Field(
|
||||
...,
|
||||
description="A non-empty string describing what should appear in the output.",
|
||||
min_length=1,
|
||||
max_length=1000,
|
||||
)
|
||||
videoUri: str = Field(...)
|
||||
seed: int = Field(..., description="Random seed for generation", ge=0, le=4294967295)
|
||||
contentModeration: RunwayAleph2ContentModeration = Field(...)
|
||||
keyframes: list[RunwayAleph2KeyframeSeconds | RunwayAleph2KeyframeAt] | None = Field(
|
||||
None,
|
||||
description="Timed guidance images placed at specific points in the input video. Up to 5.",
|
||||
)
|
||||
promptImage: list[RunwayAleph2PromptImage] | None = Field(
|
||||
None,
|
||||
description="Up to 5 image keyframes for guiding the edit at specific points in the output video.",
|
||||
)
|
||||
|
||||
|
||||
class RunwayAleph2Response(BaseModel):
|
||||
id: str | None = Field(None, description="Task ID")
|
||||
|
||||
@ -208,10 +208,6 @@ class TripoMultiviewToModelRequest(BaseModel):
|
||||
quad: bool | None = Field(False, description="Whether to apply quad to the generated model")
|
||||
|
||||
|
||||
class TripoTexturePrompt(BaseModel):
|
||||
text: str | None = Field(None, description="Text guidance for texture generation")
|
||||
|
||||
|
||||
class TripoTextureModelRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description="Type of task")
|
||||
original_model_task_id: str = Field(..., description="The task ID of the original model")
|
||||
@ -223,11 +219,6 @@ class TripoTextureModelRequest(BaseModel):
|
||||
texture_alignment: TripoTextureAlignment | None = Field(
|
||||
TripoTextureAlignment.ORIGINAL_IMAGE, description="The texture alignment method"
|
||||
)
|
||||
texture_prompt: TripoTexturePrompt | None = Field(
|
||||
None,
|
||||
description="Optional guidance for texturing. Required in practice for imported models, "
|
||||
"which carry no source image to infer texture from.",
|
||||
)
|
||||
|
||||
|
||||
class TripoRefineModelRequest(BaseModel):
|
||||
@ -316,17 +307,6 @@ class TripoP1MultiviewToModelRequest(TripoP1CommonRequest):
|
||||
orientation: str | None = None
|
||||
|
||||
|
||||
class TripoImportModelRequest(BaseModel):
|
||||
"""Request for the comfy-api composite import endpoint (/proxy/tripo/v2/openapi/import).
|
||||
|
||||
The model file is uploaded to ComfyUI API storage first; the backend downloads it from
|
||||
`url`, re-uploads it to Tripo's storage and creates the import_model task server-side.
|
||||
"""
|
||||
|
||||
url: str = Field(..., description="ComfyUI API storage download URL of the model file")
|
||||
format: str = Field(..., description='File format: "glb", "fbx", "obj" or "stl"')
|
||||
|
||||
|
||||
class TripoTaskOutput(BaseModel):
|
||||
model: str | None = Field(None, description="URL to the model")
|
||||
base_model: str | None = Field(None, description="URL to the base model")
|
||||
|
||||
@ -289,7 +289,7 @@ class BriaRemoveVideoBackground(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
|
||||
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@ -357,7 +357,7 @@ class BriaVideoGreenScreen(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
|
||||
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@ -433,7 +433,7 @@ class BriaVideoReplaceBackground(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
|
||||
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@ -452,10 +452,7 @@ class BriaVideoReplaceBackground(IO.ComfyNode):
|
||||
validate_video_duration(background_video, max_duration=60.0)
|
||||
background_url = await upload_video_to_comfyapi(cls, background_video, wait_label="Uploading background")
|
||||
else:
|
||||
# Bria's replace_background 500s on RGBA, so drop the alpha channel before upload.
|
||||
background_url = await upload_image_to_comfyapi(
|
||||
cls, background_image[:, :, :, :3], wait_label="Uploading background"
|
||||
)
|
||||
background_url = await upload_image_to_comfyapi(cls, background_image, wait_label="Uploading background")
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bria/v2/video/edit/replace_background", method="POST"),
|
||||
@ -533,7 +530,7 @@ class BriaTransparentVideoBackground(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
|
||||
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@ -574,7 +571,7 @@ class BriaExtension(ComfyExtension):
|
||||
BriaRemoveImageBackground,
|
||||
BriaRemoveVideoBackground,
|
||||
BriaVideoGreenScreen,
|
||||
BriaVideoReplaceBackground,
|
||||
# BriaVideoReplaceBackground, # server returns Status 500 when we pass background video
|
||||
BriaTransparentVideoBackground,
|
||||
]
|
||||
|
||||
|
||||
@ -436,7 +436,7 @@ async def execute_text2video(
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
duration=KlingVideoGenDuration(duration),
|
||||
mode=KlingVideoGenMode(model_mode),
|
||||
model_name=model_name,
|
||||
model_name=KlingVideoGenModelName(model_name),
|
||||
cfg_scale=cfg_scale,
|
||||
aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio),
|
||||
camera_control=camera_control,
|
||||
|
||||
@ -30,33 +30,13 @@ from comfy_api_nodes.apis.runway import (
|
||||
Model4,
|
||||
ReferenceImage,
|
||||
RunwayTextToImageAspectRatioEnum,
|
||||
RunwayAleph2IO,
|
||||
RunwayAleph2KeyframeChain,
|
||||
RunwayAleph2KeyframeItem,
|
||||
RunwayAleph2PromptImageChain,
|
||||
RunwayAleph2PromptImageItem,
|
||||
RunwayAleph2Request,
|
||||
RunwayAleph2Response,
|
||||
RunwayAleph2KeyframeSeconds,
|
||||
RunwayAleph2KeyframeAt,
|
||||
RunwayAleph2PromptImage,
|
||||
RunwayAleph2TimestampPosition,
|
||||
RunwayAleph2RelativePosition,
|
||||
RunwayAleph2ContentModeration,
|
||||
KEYFRAME_MODE_SECONDS,
|
||||
KEYFRAME_MODE_AT,
|
||||
PROMPT_IMAGE_MODE_TIMESTAMP,
|
||||
PROMPT_IMAGE_MODE_POSITION,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
image_tensor_pair_to_batch,
|
||||
validate_string,
|
||||
validate_image_dimensions,
|
||||
validate_image_aspect_ratio,
|
||||
validate_video_duration,
|
||||
upload_images_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
download_url_to_video_output,
|
||||
download_url_to_image_tensor,
|
||||
ApiEndpoint,
|
||||
@ -65,7 +45,6 @@ from comfy_api_nodes.util import (
|
||||
)
|
||||
|
||||
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
||||
PATH_VIDEO_TO_VIDEO = "/proxy/runway/video_to_video"
|
||||
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
||||
PATH_GET_TASK_STATUS = "/proxy/runway/tasks"
|
||||
|
||||
@ -74,6 +53,12 @@ AVERAGE_DURATION_FLF_SECONDS = 256
|
||||
AVERAGE_DURATION_T2I_SECONDS = 41
|
||||
|
||||
|
||||
class RunwayApiError(Exception):
|
||||
"""Base exception for Runway API errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RunwayGen4TurboAspectRatio(str, Enum):
|
||||
"""Aspect ratios supported for Image to Video API when using gen4_turbo model."""
|
||||
|
||||
@ -99,6 +84,14 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def extract_progress_from_task_status(
|
||||
response: TaskStatusResponse,
|
||||
) -> float | None:
|
||||
if hasattr(response, "progress") and response.progress is not None:
|
||||
return response.progress * 100
|
||||
return None
|
||||
|
||||
|
||||
def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
"""Returns the image URL from the task status response if it exists."""
|
||||
if hasattr(response, "output") and len(response.output) > 0:
|
||||
@ -109,13 +102,14 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
async def get_response(
|
||||
cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
|
||||
) -> TaskStatusResponse:
|
||||
"""Poll the task status until it is finished then get the response."""
|
||||
return await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
status_extractor=lambda r: r.status.value,
|
||||
estimated_duration=estimated_duration,
|
||||
progress_extractor=lambda r: r.progress * 100 if r.progress is not None else None,
|
||||
progress_extractor=extract_progress_from_task_status,
|
||||
)
|
||||
|
||||
|
||||
@ -133,7 +127,7 @@ async def generate_video(
|
||||
|
||||
final_response = await get_response(cls, initial_response.id, estimated_duration)
|
||||
if not final_response.output:
|
||||
raise ValueError("Runway task succeeded but no video data found in response.")
|
||||
raise RunwayApiError("Runway task succeeded but no video data found in response.")
|
||||
|
||||
video_url = get_video_url_from_task_status(final_response)
|
||||
return await download_url_to_video_output(video_url)
|
||||
@ -416,7 +410,7 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
||||
mime_type="image/png",
|
||||
)
|
||||
if len(download_urls) != 2:
|
||||
raise ValueError("Failed to upload one or more images to comfy api.")
|
||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||
|
||||
return IO.NodeOutput(
|
||||
await generate_video(
|
||||
@ -520,321 +514,11 @@ class RunwayTextToImageNode(IO.ComfyNode):
|
||||
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
|
||||
)
|
||||
if not final_response.output:
|
||||
raise ValueError("Runway task succeeded but no image data found in response.")
|
||||
raise RunwayApiError("Runway task succeeded but no image data found in response.")
|
||||
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response)))
|
||||
|
||||
|
||||
_TIMING_ABSOLUTE = "Absolute time (seconds)"
|
||||
_TIMING_FRACTION = "Fraction of duration (0.0-1.0)"
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RunwayAleph2KeyframeNode",
|
||||
display_name="Runway Aleph2 Keyframe",
|
||||
category="partner/video/Runway",
|
||||
description="Anchor a guidance image to a moment of the input (source) video, so Aleph2 "
|
||||
"steers the edit at that point of your footage. Connect this to the 'keyframes' input of "
|
||||
"the Runway Aleph2 Video to Video node; chain several together (up to 5) via the optional "
|
||||
"'keyframes' input below.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="The guidance image to apply at the chosen moment of the input video.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"timing",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
_TIMING_ABSOLUTE,
|
||||
[
|
||||
IO.Float.Input(
|
||||
"seconds",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=30.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Time in seconds from start of the input video where this image applies.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
_TIMING_FRACTION,
|
||||
[
|
||||
IO.Float.Input(
|
||||
"fraction",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Where in the input video this image applies, "
|
||||
"as a fraction of its duration (0.0 = start, 1.0 = end).",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="How to place this image on the input video's timeline.",
|
||||
),
|
||||
IO.Custom(RunwayAleph2IO.KEYFRAME).Input(
|
||||
"keyframes",
|
||||
optional=True,
|
||||
tooltip="Optional earlier keyframes to chain with this one.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Custom(RunwayAleph2IO.KEYFRAME).Output(display_name="keyframes")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
timing: dict,
|
||||
keyframes: RunwayAleph2KeyframeChain | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
chain = keyframes.clone() if keyframes is not None else RunwayAleph2KeyframeChain()
|
||||
if timing["timing"] == _TIMING_ABSOLUTE:
|
||||
mode, value = KEYFRAME_MODE_SECONDS, float(timing["seconds"])
|
||||
else:
|
||||
mode, value = KEYFRAME_MODE_AT, float(timing["fraction"])
|
||||
chain.add(RunwayAleph2KeyframeItem(image=image, mode=mode, value=value))
|
||||
return IO.NodeOutput(chain)
|
||||
|
||||
|
||||
class RunwayAleph2PromptImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RunwayAleph2PromptImageNode",
|
||||
display_name="Runway Aleph2 Prompt Image",
|
||||
category="partner/video/Runway",
|
||||
description="Anchor a guidance image to a moment of the output (result) video, to guide what "
|
||||
"the edited video looks like at that point. Connect this to the 'prompt_images' input of the "
|
||||
"Runway Aleph2 Video to Video node; chain several together (up to 5) via the optional "
|
||||
"'prompt_images' input below.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="The guidance image to place at the chosen moment of the output video.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"position",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
_TIMING_ABSOLUTE,
|
||||
[
|
||||
IO.Float.Input(
|
||||
"seconds",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=30.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Time in seconds from start of the output video where this image applies.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
_TIMING_FRACTION,
|
||||
[
|
||||
IO.Float.Input(
|
||||
"fraction",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Where in the output video this image applies, "
|
||||
"as a fraction of its duration (0.0 = start, 1.0 = end).",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="How to place this image on the output video's timeline.",
|
||||
),
|
||||
IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Input(
|
||||
"prompt_images",
|
||||
optional=True,
|
||||
tooltip="Optional earlier prompt images to chain with this one.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Output(display_name="prompt_images")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
position: dict,
|
||||
prompt_images: RunwayAleph2PromptImageChain | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
chain = prompt_images.clone() if prompt_images is not None else RunwayAleph2PromptImageChain()
|
||||
if position["position"] == _TIMING_ABSOLUTE:
|
||||
mode, value = PROMPT_IMAGE_MODE_TIMESTAMP, float(position["seconds"])
|
||||
else:
|
||||
mode, value = PROMPT_IMAGE_MODE_POSITION, float(position["fraction"])
|
||||
chain.add(RunwayAleph2PromptImageItem(image=image, mode=mode, value=value))
|
||||
return IO.NodeOutput(chain)
|
||||
|
||||
|
||||
class RunwayAleph2VideoToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RunwayAleph2VideoToVideoNode",
|
||||
display_name="Runway Aleph2 Video to Video",
|
||||
category="partner/video/Runway",
|
||||
description="Edit a video with a text prompt using Runway's Aleph2 model. Aleph2 transforms "
|
||||
"your footage (restyle, relight, add or remove elements, change the viewpoint) while keeping "
|
||||
"the original motion and timing; the output resolution matches the input video, which must be "
|
||||
"2-30 seconds at 30 fps or lower. Optionally steer the edit with either keyframes (anchored to "
|
||||
"the input video) or prompt images (anchored to the output video) - use one or the other, not both.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Describes what should appear in the output (1-1000 characters).",
|
||||
),
|
||||
IO.Video.Input(
|
||||
"video",
|
||||
tooltip="Input video to edit. Must be 2-30 seconds at 30 fps or lower.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Random seed for generation",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"public_figure_threshold",
|
||||
options=["auto", "low"],
|
||||
default="low",
|
||||
tooltip="Content moderation for recognizable public figures.",
|
||||
),
|
||||
IO.Custom(RunwayAleph2IO.KEYFRAME).Input(
|
||||
"keyframes",
|
||||
optional=True,
|
||||
tooltip="Guidance images anchored to the input video, from Aleph2 Keyframe nodes (up to 5). "
|
||||
"Use keyframes or prompt images, not both.",
|
||||
),
|
||||
IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Input(
|
||||
"prompt_images",
|
||||
optional=True,
|
||||
tooltip="Guidance images anchored to the output video, from Aleph2 Prompt Image nodes (up to 5). "
|
||||
"Use keyframes or prompt images, not both.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd": 0.4004, "format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
video: Input.Video,
|
||||
seed: int,
|
||||
public_figure_threshold: str = "low",
|
||||
keyframes: RunwayAleph2KeyframeChain | None = None,
|
||||
prompt_images: RunwayAleph2PromptImageChain | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=1000)
|
||||
validate_video_duration(
|
||||
video,
|
||||
min_duration=2.0,
|
||||
max_duration=30.0,
|
||||
)
|
||||
try:
|
||||
fps = float(video.get_frame_rate())
|
||||
except Exception:
|
||||
fps = None
|
||||
if fps is not None and fps > 30.0 + 0.01:
|
||||
raise ValueError(f"Input video frame rate ({fps:.2f} fps) exceeds Aleph2's maximum of 30 fps.")
|
||||
|
||||
if (keyframes and keyframes.items) and (prompt_images and prompt_images.items):
|
||||
raise ValueError("Aleph2 accepts either keyframes or prompt images, not both.")
|
||||
|
||||
video_duration: float | None = None
|
||||
try:
|
||||
video_duration = video.get_duration()
|
||||
except Exception:
|
||||
video_duration = None
|
||||
|
||||
def _check_seconds(value: float, label: str) -> None:
|
||||
if video_duration is not None and value > video_duration + 0.0001:
|
||||
raise ValueError(f"{label} {value:.2f}s exceeds the input video duration ({video_duration:.2f}s).")
|
||||
|
||||
video_url = await upload_video_to_comfyapi(cls, video)
|
||||
|
||||
keyframe_models: list[RunwayAleph2KeyframeSeconds | RunwayAleph2KeyframeAt] = []
|
||||
if keyframes is not None:
|
||||
if len(keyframes.items) > 5:
|
||||
raise ValueError("Aleph2 supports at most 5 keyframes.")
|
||||
for item in keyframes.items:
|
||||
image_url = await upload_image_to_comfyapi(cls, item.image, mime_type="image/png")
|
||||
if item.mode == KEYFRAME_MODE_SECONDS:
|
||||
_check_seconds(item.value, "Keyframe timestamp")
|
||||
keyframe_models.append(RunwayAleph2KeyframeSeconds(seconds=item.value, uri=image_url))
|
||||
else:
|
||||
keyframe_models.append(RunwayAleph2KeyframeAt(at=item.value, uri=image_url))
|
||||
|
||||
prompt_image_models: list[RunwayAleph2PromptImage] = []
|
||||
if prompt_images is not None:
|
||||
if len(prompt_images.items) > 5:
|
||||
raise ValueError("Aleph2 supports at most 5 prompt images.")
|
||||
for item in prompt_images.items:
|
||||
image_url = await upload_image_to_comfyapi(cls, item.image, mime_type="image/png")
|
||||
position: RunwayAleph2TimestampPosition | RunwayAleph2RelativePosition
|
||||
if item.mode == PROMPT_IMAGE_MODE_TIMESTAMP:
|
||||
_check_seconds(item.value, "Prompt image timestamp")
|
||||
position = RunwayAleph2TimestampPosition(timestampSeconds=item.value)
|
||||
else:
|
||||
position = RunwayAleph2RelativePosition(positionPercentage=item.value)
|
||||
prompt_image_models.append(RunwayAleph2PromptImage(position=position, uri=image_url))
|
||||
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=PATH_VIDEO_TO_VIDEO, method="POST"),
|
||||
response_model=RunwayAleph2Response,
|
||||
data=RunwayAleph2Request(
|
||||
promptText=prompt,
|
||||
videoUri=video_url,
|
||||
seed=seed,
|
||||
contentModeration=RunwayAleph2ContentModeration(publicFigureThreshold=public_figure_threshold),
|
||||
keyframes=keyframe_models or None,
|
||||
promptImage=prompt_image_models or None,
|
||||
),
|
||||
)
|
||||
|
||||
final_response = await get_response(cls, initial_response.id)
|
||||
if not final_response.output:
|
||||
raise ValueError("Runway task succeeded but no video data found in response.")
|
||||
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(final_response)))
|
||||
|
||||
|
||||
class RunwayExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -843,9 +527,6 @@ class RunwayExtension(ComfyExtension):
|
||||
RunwayImageToVideoNodeGen3a,
|
||||
RunwayImageToVideoNodeGen4,
|
||||
RunwayTextToImageNode,
|
||||
RunwayAleph2VideoToVideoNode,
|
||||
RunwayAleph2KeyframeNode,
|
||||
RunwayAleph2PromptImageNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@ from comfy_api_nodes.util import (
|
||||
)
|
||||
from comfy_api_nodes.util._helpers import (
|
||||
default_base_url,
|
||||
get_comfy_api_headers,
|
||||
get_auth_header,
|
||||
get_node_id,
|
||||
is_processing_interrupted,
|
||||
)
|
||||
@ -111,10 +111,11 @@ class SoniloTextToMusic(IO.ComfyNode):
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=30,
|
||||
min=1,
|
||||
default=0,
|
||||
min=0,
|
||||
max=360,
|
||||
tooltip="Target duration in seconds. Maximum: 6 minutes.",
|
||||
tooltip="Target duration in seconds. Set to 0 to let the model "
|
||||
"infer the duration from the prompt. Maximum: 6 minutes.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
@ -149,13 +150,14 @@ class SoniloTextToMusic(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
duration: int = 1,
|
||||
duration: int = 0,
|
||||
seed: int = 0,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1, max_length=1000)
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("prompt", prompt)
|
||||
form.add_field("duration", str(duration))
|
||||
if duration > 0:
|
||||
form.add_field("duration", str(duration))
|
||||
audio_bytes = await _stream_sonilo_music(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/sonilo/t2m/generate", method="POST"),
|
||||
@ -172,7 +174,8 @@ async def _stream_sonilo_music(
|
||||
"""POST ``form`` to Sonilo, read the NDJSON stream, and return the first stream's audio bytes."""
|
||||
url = urljoin(default_base_url().rstrip("/") + "/", endpoint.path.lstrip("/"))
|
||||
|
||||
headers = get_comfy_api_headers(cls)
|
||||
headers: dict[str, str] = {}
|
||||
headers.update(get_auth_header(cls))
|
||||
headers.update(endpoint.headers)
|
||||
|
||||
node_id = get_node_id(cls)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.tripo import (
|
||||
TripoAnimateRetargetRequest,
|
||||
TripoAnimateRigRequest,
|
||||
@ -8,7 +8,6 @@ from comfy_api_nodes.apis.tripo import (
|
||||
TripoFileEmptyReference,
|
||||
TripoFileReference,
|
||||
TripoImageToModelRequest,
|
||||
TripoImportModelRequest,
|
||||
TripoModelVersion,
|
||||
TripoMultiviewToModelRequest,
|
||||
TripoOrientation,
|
||||
@ -22,7 +21,6 @@ from comfy_api_nodes.apis.tripo import (
|
||||
TripoTaskType,
|
||||
TripoTextToModelRequest,
|
||||
TripoTextureModelRequest,
|
||||
TripoTexturePrompt,
|
||||
TripoUrlReference,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
@ -30,7 +28,6 @@ from comfy_api_nodes.util import (
|
||||
download_url_to_file_3d,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_3d_model_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
)
|
||||
|
||||
@ -541,14 +538,6 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"texture_prompt",
|
||||
default="",
|
||||
multiline=True,
|
||||
optional=True,
|
||||
tooltip="Optional text guidance for texturing. Required in practice for imported "
|
||||
"models (Tripo: Import Model), which carry no source image to infer colors from.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
@ -582,7 +571,6 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
texture_seed: int | None = None,
|
||||
texture_quality: str | None = None,
|
||||
texture_alignment: str | None = None,
|
||||
texture_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
response = await sync_op(
|
||||
cls,
|
||||
@ -595,7 +583,6 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
texture_seed=texture_seed,
|
||||
texture_quality=texture_quality,
|
||||
texture_alignment=texture_alignment,
|
||||
texture_prompt=TripoTexturePrompt(text=texture_prompt.strip()) if texture_prompt.strip() else None,
|
||||
),
|
||||
)
|
||||
return await poll_until_finished(cls, response, average_duration=80)
|
||||
@ -928,90 +915,6 @@ class TripoConversionNode(IO.ComfyNode):
|
||||
return await poll_until_finished(cls, response, average_duration=30)
|
||||
|
||||
|
||||
class TripoImportModelNode(IO.ComfyNode):
|
||||
"""Imports an external 3D model into Tripo, producing a MODEL_TASK_ID for post-processing nodes."""
|
||||
|
||||
SUPPORTED_FORMATS = ("glb", "fbx", "obj", "stl")
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TripoImportModelNode",
|
||||
display_name="Tripo: Import Model",
|
||||
category="partner/3d/Tripo",
|
||||
description="Import an external 3D model (e.g. from Rodin, Hunyuan3D or a local file) into Tripo "
|
||||
"to use it with Tripo's post-processing nodes: Texture, Rig, Convert. "
|
||||
"GLB is recommended: textures survive import only when embedded in the file. "
|
||||
"Note that texturing an imported model requires a texture prompt.",
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
"model_3d",
|
||||
types=[IO.File3DGLB, IO.File3DFBX, IO.File3DOBJ, IO.File3DSTL, IO.File3DAny],
|
||||
tooltip="3D model to import (GLB / FBX / OBJ / STL, up to 150 MB). "
|
||||
"OBJ and STL files carry no embedded textures.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"text","text":"Free"}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(cls, model_3d: Types.File3D) -> IO.NodeOutput:
|
||||
file_format = (model_3d.format or "").lstrip(".").lower()
|
||||
if file_format == "gltf":
|
||||
raise ValueError(
|
||||
"GLTF (.gltf) references external files and cannot be imported. Export a single-file GLB instead."
|
||||
)
|
||||
if file_format not in cls.SUPPORTED_FORMATS:
|
||||
raise ValueError(
|
||||
f"Unsupported 3D format '{file_format or 'unknown'}'. "
|
||||
f"Tripo import supports: {', '.join(f.upper() for f in cls.SUPPORTED_FORMATS)}."
|
||||
)
|
||||
size = len(model_3d.get_bytes())
|
||||
if size > 150 * 1024 * 1024:
|
||||
raise ValueError(f"Model file is {size / (1024 * 1024):.1f} MB; Tripo import allows up to 150 MB.")
|
||||
|
||||
url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/import", method="POST"),
|
||||
response_model=TripoTaskResponse,
|
||||
data=TripoImportModelRequest(url=url, format=file_format),
|
||||
)
|
||||
if response.code != 0:
|
||||
raise RuntimeError(f"Failed to import model: {response.error}")
|
||||
|
||||
task_id = response.data.task_id
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
poll_endpoint=ApiEndpoint(path=f"/proxy/tripo/v2/openapi/task/{task_id}"),
|
||||
response_model=TripoTaskResponse,
|
||||
failed_statuses=[
|
||||
TripoTaskStatus.FAILED,
|
||||
TripoTaskStatus.CANCELLED,
|
||||
TripoTaskStatus.UNKNOWN,
|
||||
TripoTaskStatus.BANNED,
|
||||
TripoTaskStatus.EXPIRED,
|
||||
],
|
||||
status_extractor=lambda x: x.data.status,
|
||||
progress_extractor=lambda x: x.data.progress,
|
||||
estimated_duration=10,
|
||||
)
|
||||
if response_poll.data.status != TripoTaskStatus.SUCCESS:
|
||||
raise RuntimeError(f"Failed to import model: {response_poll}")
|
||||
return IO.NodeOutput(task_id)
|
||||
|
||||
|
||||
def _p1_price_expr(*, geometry_credits: int, textured_credits: int, detailed_credits: int) -> str:
|
||||
return (
|
||||
"("
|
||||
@ -1389,7 +1292,6 @@ class TripoExtension(ComfyExtension):
|
||||
TripoP1TextToModelNode,
|
||||
TripoP1ImageToModelNode,
|
||||
TripoP1MultiviewToModelNode,
|
||||
TripoImportModelNode,
|
||||
TripoTextureNode,
|
||||
TripoRefineNode,
|
||||
TripoRigNode,
|
||||
|
||||
@ -9,7 +9,6 @@ from io import BytesIO
|
||||
from yarl import URL
|
||||
|
||||
from comfy.cli_args import args
|
||||
from comfy.deploy_environment import get_deploy_environment
|
||||
from comfy.model_management import processing_interrupted
|
||||
from comfy_api.latest import IO
|
||||
|
||||
@ -36,30 +35,6 @@ def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
|
||||
return {}
|
||||
|
||||
|
||||
def get_usage_source(node_cls: type[IO.ComfyNode]) -> str:
|
||||
"""Source of the prompt that triggered this API node.
|
||||
|
||||
Defaults to "comfyui-api" when the submitting client didn't identify itself,
|
||||
i.e. a direct API call to this server.
|
||||
"""
|
||||
return node_cls.hidden.comfy_usage_source or "comfyui-api"
|
||||
|
||||
|
||||
def get_comfy_api_headers(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
|
||||
"""Common headers (auth, deploy environment, usage source) for Comfy API requests.
|
||||
|
||||
Centralizes the shared header set so every Comfy API request sends a consistent
|
||||
set and new shared headers only need to be added in one place. Intended for
|
||||
relative/cloud URLs resolved against ``default_base_url()``; because the result
|
||||
includes auth, callers must not attach it to arbitrary absolute/presigned URLs.
|
||||
"""
|
||||
return {
|
||||
**get_auth_header(node_cls),
|
||||
"Comfy-Env": get_deploy_environment(),
|
||||
"Comfy-Usage-Source": get_usage_source(node_cls),
|
||||
}
|
||||
|
||||
|
||||
def default_base_url() -> str:
|
||||
return getattr(args, "comfy_api_base", "https://api.comfy.org")
|
||||
|
||||
|
||||
@ -19,10 +19,12 @@ from comfy import utils
|
||||
from comfy_api.latest import IO
|
||||
from server import PromptServer
|
||||
|
||||
from comfy.deploy_environment import get_deploy_environment
|
||||
|
||||
from . import request_logger
|
||||
from ._helpers import (
|
||||
default_base_url,
|
||||
get_comfy_api_headers,
|
||||
get_auth_header,
|
||||
get_node_id,
|
||||
is_processing_interrupted,
|
||||
sleep_with_interrupt,
|
||||
@ -643,7 +645,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
|
||||
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||
payload_headers.update(get_comfy_api_headers(cfg.node_cls))
|
||||
payload_headers.update(get_auth_header(cfg.node_cls))
|
||||
payload_headers["Comfy-Env"] = get_deploy_environment()
|
||||
if cfg.endpoint.headers:
|
||||
payload_headers.update(cfg.endpoint.headers)
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ from folder_paths import get_output_directory
|
||||
from . import request_logger
|
||||
from ._helpers import (
|
||||
default_base_url,
|
||||
get_comfy_api_headers,
|
||||
get_auth_header,
|
||||
is_processing_interrupted,
|
||||
sleep_with_interrupt,
|
||||
to_aiohttp_url,
|
||||
@ -64,7 +64,7 @@ async def download_url_to_bytesio(
|
||||
if cls is None:
|
||||
raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.")
|
||||
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
|
||||
headers = get_comfy_api_headers(cls)
|
||||
headers = get_auth_header(cls)
|
||||
|
||||
while True:
|
||||
attempt += 1
|
||||
|
||||
@ -245,11 +245,6 @@ class KV_Attn_Input:
|
||||
cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
|
||||
if cache_key in self.cache:
|
||||
kk, vv = self.cache[cache_key]
|
||||
|
||||
# Fix batch size changing.
|
||||
kk = comfy.utils.repeat_to_batch_size(kk, k.shape[0])
|
||||
vv = comfy.utils.repeat_to_batch_size(vv, v.shape[0])
|
||||
|
||||
self.set_cache = False
|
||||
return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ class RTDETR_detect(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RTDETR_detect",
|
||||
display_name="Run Real-Time Detection (RT-DETR)",
|
||||
display_name="RT-DETR Detect",
|
||||
category="image/detection",
|
||||
search_aliases=["bbox", "bounding box", "object detection", "coco"],
|
||||
inputs=[
|
||||
|
||||
@ -264,7 +264,7 @@ class SAM3_VideoTrack(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SAM3_VideoTrack",
|
||||
display_name="Run SAM3 Video Track",
|
||||
display_name="SAM3 Video Track",
|
||||
category="image/detection",
|
||||
search_aliases=["sam3", "video", "track", "propagate"],
|
||||
inputs=[
|
||||
|
||||
@ -134,17 +134,6 @@ class CreateVideo(io.ComfyNode):
|
||||
io.Image.Input("images", tooltip="The images to create a video from."),
|
||||
io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
|
||||
io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
|
||||
io.Int.Input(
|
||||
"bit_depth",
|
||||
min=8,
|
||||
max=10,
|
||||
default=8,
|
||||
step=2,
|
||||
tooltip="Bit depth of the created video. 10-bit keeps smoother gradients with less"
|
||||
" banding, but some players and downstream nodes may not support it.",
|
||||
optional=True,
|
||||
display_mode=io.NumberDisplay.number,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Video.Output(),
|
||||
@ -152,14 +141,9 @@ class CreateVideo(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None, bit_depth: int = 8,
|
||||
) -> io.NodeOutput:
|
||||
def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput:
|
||||
return io.NodeOutput(
|
||||
InputImpl.VideoFromComponents(
|
||||
Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)),
|
||||
bit_depth=bit_depth,
|
||||
)
|
||||
InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
|
||||
)
|
||||
|
||||
class GetVideoComponents(io.ComfyNode):
|
||||
@ -170,7 +154,7 @@ class GetVideoComponents(io.ComfyNode):
|
||||
search_aliases=["extract frames", "split video", "video to images", "demux"],
|
||||
display_name="Get Video Components",
|
||||
category="video",
|
||||
description="Extracts all components from a video: frames, audio, framerate, and bit depth.",
|
||||
description="Extracts all components from a video: frames, audio, and framerate.",
|
||||
inputs=[
|
||||
io.Video.Input("video", tooltip="The video to extract components from."),
|
||||
],
|
||||
@ -178,14 +162,13 @@ class GetVideoComponents(io.ComfyNode):
|
||||
io.Image.Output(display_name="images"),
|
||||
io.Audio.Output(display_name="audio"),
|
||||
io.Float.Output(display_name="fps"),
|
||||
io.Int.Output(display_name="bit_depth"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video: Input.Video) -> io.NodeOutput:
|
||||
components = video.get_components()
|
||||
return io.NodeOutput(components.images, components.audio, float(components.frame_rate), video.get_bit_depth())
|
||||
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
|
||||
|
||||
|
||||
class LoadVideo(io.ComfyNode):
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.25.0"
|
||||
__version__ = "0.24.0"
|
||||
|
||||
@ -200,8 +200,6 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
||||
hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
|
||||
if io.Hidden.api_key_comfy_org.name in hidden:
|
||||
hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
|
||||
if io.Hidden.comfy_usage_source.name in hidden:
|
||||
hidden_inputs_v3[io.Hidden.comfy_usage_source] = extra_data.get("comfy_usage_source", None)
|
||||
else:
|
||||
if "hidden" in valid_inputs:
|
||||
h = valid_inputs["hidden"]
|
||||
@ -218,8 +216,6 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
||||
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
|
||||
if h[x] == "API_KEY_COMFY_ORG":
|
||||
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
||||
if h[x] == "COMFY_USAGE_SOURCE":
|
||||
input_data_all[x] = [extra_data.get("comfy_usage_source", None)]
|
||||
v3_data["hidden_inputs"] = hidden_inputs_v3
|
||||
return input_data_all, missing_keys, v3_data
|
||||
|
||||
|
||||
45
main.py
45
main.py
@ -55,11 +55,7 @@ if __name__ == "__main__" and args.debug_hang:
|
||||
import comfy_aimdo.control
|
||||
|
||||
if enables_dynamic_vram():
|
||||
try:
|
||||
comfy_aimdo.control.init(simple_vram_headroom=None if args.reserve_vram is None else int(args.reserve_vram * 1024 ** 3))
|
||||
except TypeError:
|
||||
# comfy-aimdo 0.4.9 protocol.
|
||||
comfy_aimdo.control.init()
|
||||
comfy_aimdo.control.init()
|
||||
|
||||
if os.name == "nt":
|
||||
os.environ['MIMALLOC_PURGE_DELAY'] = '0'
|
||||
@ -235,30 +231,23 @@ import comfy.model_patcher
|
||||
if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()):
|
||||
if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)):
|
||||
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||
elif comfy_aimdo.control.init_devices(d.index for d in comfy.model_management.get_all_torch_devices()):
|
||||
if args.verbose == 'DEBUG':
|
||||
comfy_aimdo.control.set_log_debug()
|
||||
elif args.verbose == 'CRITICAL':
|
||||
comfy_aimdo.control.set_log_critical()
|
||||
elif args.verbose == 'ERROR':
|
||||
comfy_aimdo.control.set_log_error()
|
||||
elif args.verbose == 'WARNING':
|
||||
comfy_aimdo.control.set_log_warning()
|
||||
else: #INFO
|
||||
comfy_aimdo.control.set_log_info()
|
||||
|
||||
comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic
|
||||
comfy.memory_management.aimdo_enabled = True
|
||||
logging.info("DynamicVRAM support detected and enabled")
|
||||
else:
|
||||
try:
|
||||
aimdo_initialized = comfy_aimdo.control.init_devices((d.index, int(args.vram_headroom * 1024 ** 3)) for d in comfy.model_management.get_all_torch_devices())
|
||||
except TypeError:
|
||||
# comfy-aimdo 0.4.9 protocol.
|
||||
aimdo_initialized = comfy_aimdo.control.init_devices(d.index for d in comfy.model_management.get_all_torch_devices())
|
||||
|
||||
if aimdo_initialized:
|
||||
if args.verbose == 'DEBUG':
|
||||
comfy_aimdo.control.set_log_debug()
|
||||
elif args.verbose == 'CRITICAL':
|
||||
comfy_aimdo.control.set_log_critical()
|
||||
elif args.verbose == 'ERROR':
|
||||
comfy_aimdo.control.set_log_error()
|
||||
elif args.verbose == 'WARNING':
|
||||
comfy_aimdo.control.set_log_warning()
|
||||
else: #INFO
|
||||
comfy_aimdo.control.set_log_info()
|
||||
|
||||
comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic
|
||||
comfy.memory_management.aimdo_enabled = True
|
||||
logging.info("DynamicVRAM support detected and enabled")
|
||||
else:
|
||||
logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||
logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||
|
||||
|
||||
def cuda_malloc_warning():
|
||||
|
||||
@ -1 +1 @@
|
||||
comfyui_manager==4.2.2
|
||||
comfyui_manager==4.2.1
|
||||
|
||||
20
openapi.yaml
20
openapi.yaml
@ -896,6 +896,11 @@ components:
|
||||
additionalProperties: true
|
||||
description: The workflow graph to execute
|
||||
type: object
|
||||
prompt_id:
|
||||
description: Optional client-supplied job id. Must be a UUID in canonical lowercase hyphenated form; it is echoed back in the response. Omitted or null means the server generates one.
|
||||
format: uuid
|
||||
nullable: true
|
||||
type: string
|
||||
workflow_id:
|
||||
description: UUID identifying the cloud workflow entity to associate with this job
|
||||
type: string
|
||||
@ -1795,9 +1800,7 @@ paths:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: |
|
||||
Invalid request — no fields provided, or `preview_id` is the zero UUID
|
||||
(`INVALID_PREVIEW_ID`).
|
||||
description: Invalid request (no fields provided)
|
||||
"401":
|
||||
content:
|
||||
application/json:
|
||||
@ -1809,10 +1812,7 @@ paths:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: |
|
||||
Asset not found — returned both when the asset being updated does
|
||||
not exist and when `preview_id` does not reference an asset
|
||||
accessible to the caller.
|
||||
description: Asset not found
|
||||
"500":
|
||||
content:
|
||||
application/json:
|
||||
@ -3050,12 +3050,6 @@ paths:
|
||||
schema:
|
||||
$ref: '#/components/schemas/PromptErrorResponse'
|
||||
description: Payment required - Insufficient credits
|
||||
"413":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/PromptErrorResponse'
|
||||
description: Workflow JSON too large
|
||||
"429":
|
||||
content:
|
||||
application/json:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.25.0"
|
||||
version = "0.24.0"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
comfyui-frontend-package==1.45.15
|
||||
comfyui-workflow-templates==0.10.0
|
||||
comfyui-embedded-docs==0.5.4
|
||||
comfyui-workflow-templates==0.9.98
|
||||
comfyui-embedded-docs==0.5.3
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0
|
||||
filelock
|
||||
av>=16.0.0
|
||||
comfy-kitchen==0.2.10
|
||||
comfy-aimdo==0.4.10
|
||||
comfy-aimdo==0.4.9
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
blake3
|
||||
|
||||
@ -27,7 +27,6 @@ import logging
|
||||
|
||||
import mimetypes
|
||||
from comfy.cli_args import args
|
||||
from comfy.deploy_environment import get_deploy_environment
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
from comfy_api import feature_flags
|
||||
@ -691,7 +690,6 @@ class PromptServer():
|
||||
"python_version": sys.version,
|
||||
"pytorch_version": comfy.model_management.torch_version,
|
||||
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
|
||||
"deploy_environment": get_deploy_environment(),
|
||||
"argv": sys.argv
|
||||
},
|
||||
"devices": device_entries
|
||||
@ -973,11 +971,6 @@ class PromptServer():
|
||||
|
||||
if "client_id" in json_data:
|
||||
extra_data["client_id"] = json_data["client_id"]
|
||||
|
||||
if "comfy_usage_source" not in extra_data:
|
||||
usage_source = request.headers.get("Comfy-Usage-Source")
|
||||
if usage_source:
|
||||
extra_data["comfy_usage_source"] = usage_source
|
||||
if valid[0]:
|
||||
outputs_to_execute = valid[2]
|
||||
sensitive = {}
|
||||
|
||||
204
tests-unit/comfy_api_test/io_list_test.py
Normal file
204
tests-unit/comfy_api_test/io_list_test.py
Normal file
@ -0,0 +1,204 @@
|
||||
"""Unit tests for io.List: expansion/reconstruction (0-row and N-row cases)."""
|
||||
import sys
|
||||
import types
|
||||
import pytest
|
||||
|
||||
# Stub torch (type-hint only in _io.py; real torch not available in unit-test env)
|
||||
if "torch" not in sys.modules:
|
||||
_torch_stub = types.ModuleType("torch")
|
||||
_torch_stub.Tensor = object # type: ignore[attr-defined]
|
||||
sys.modules["torch"] = _torch_stub
|
||||
|
||||
from comfy_api.latest._io import ( # noqa: E402
|
||||
List,
|
||||
Float,
|
||||
Int,
|
||||
String,
|
||||
Boolean,
|
||||
get_finalized_class_inputs,
|
||||
build_nested_inputs,
|
||||
create_input_dict_v1,
|
||||
setup_dynamic_input_funcs,
|
||||
)
|
||||
|
||||
# Make sure dynamic input funcs are registered (may already be done at import time)
|
||||
setup_dynamic_input_funcs()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_class_inputs(list_input: List.Input) -> dict:
|
||||
"""Wrap a List.Input into the required/optional dict structure."""
|
||||
return create_input_dict_v1([list_input])
|
||||
|
||||
|
||||
def _run(list_input: List.Input, live_values: dict) -> dict:
|
||||
"""End-to-end helper: expand schema + reconstruct values.
|
||||
|
||||
Mirrors the production split in execution.py:
|
||||
1. get_finalized_class_inputs (schema expansion, line 162)
|
||||
2. build_nested_inputs (value reconstruction, line 281)
|
||||
|
||||
The two steps are separate in production because the engine resolves
|
||||
linked node outputs between them, but in tests we supply values directly.
|
||||
"""
|
||||
class_inputs = _make_class_inputs(list_input)
|
||||
_, _, v3_data = get_finalized_class_inputs(class_inputs, live_values)
|
||||
return build_nested_inputs(dict(live_values), v3_data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema construction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListInputConstruction:
|
||||
def test_basic_construction(self):
|
||||
inp = List.Input(
|
||||
"loras",
|
||||
template=[
|
||||
Float.Input("strength", default=1.0),
|
||||
String.Input("name"),
|
||||
],
|
||||
min=0,
|
||||
max=10,
|
||||
)
|
||||
assert inp.id == "loras"
|
||||
assert inp.min == 0
|
||||
assert inp.max == 10
|
||||
assert len(inp.template) == 2
|
||||
|
||||
def test_get_all_includes_self_and_template(self):
|
||||
inp = List.Input(
|
||||
"items",
|
||||
template=[Float.Input("value")],
|
||||
)
|
||||
all_inputs = inp.get_all()
|
||||
assert all_inputs[0] is inp
|
||||
assert all_inputs[1].id == "value"
|
||||
|
||||
def test_as_dict_has_template_min_max(self):
|
||||
inp = List.Input(
|
||||
"items",
|
||||
template=[Float.Input("val", default=0.5)],
|
||||
min=1,
|
||||
max=5,
|
||||
)
|
||||
d = inp.as_dict()
|
||||
assert "template" in d
|
||||
assert d["min"] == 1
|
||||
assert d["max"] == 5
|
||||
|
||||
def test_duplicate_field_ids_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
List.Input(
|
||||
"bad",
|
||||
template=[Float.Input("x"), Float.Input("x")],
|
||||
)
|
||||
|
||||
def test_empty_template_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
List.Input("bad", template=[])
|
||||
|
||||
def test_min_gt_max_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
List.Input("bad", template=[Float.Input("x")], min=5, max=3)
|
||||
|
||||
def test_max_exceeds_limit_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
List.Input("bad", template=[Float.Input("x")], max=101)
|
||||
|
||||
def test_dynamic_input_in_template_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
List.Input(
|
||||
"bad",
|
||||
template=[List.Input("nested", template=[Float.Input("x")])],
|
||||
)
|
||||
|
||||
def test_validate_calls_through(self):
|
||||
inp = List.Input("items", template=[Float.Input("val", min=-1.0, max=1.0)])
|
||||
inp.validate() # should not raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 0-row case
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestZeroRows:
|
||||
def test_empty_live_inputs_produces_empty_list(self):
|
||||
"""With min=0 and no live values, the result should be an empty list."""
|
||||
inp = List.Input("loras", template=[Float.Input("strength", default=1.0)], min=0, max=10)
|
||||
assert _run(inp, {}).get("loras") == []
|
||||
|
||||
def test_min_zero_with_values(self):
|
||||
"""min=0 but 2 rows of live data."""
|
||||
inp = List.Input("loras", template=[Float.Input("strength", default=1.0)], min=0, max=10)
|
||||
result = _run(inp, {"loras.0.strength": 0.8, "loras.1.strength": 0.5})
|
||||
assert result["loras"] == [{"strength": 0.8}, {"strength": 0.5}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# N-row case
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNRows:
|
||||
def test_two_rows_two_fields(self):
|
||||
"""Two rows with two fields each produce a list[dict]."""
|
||||
inp = List.Input(
|
||||
"loras",
|
||||
template=[String.Input("lora_name"), Float.Input("strength", default=1.0)],
|
||||
min=0, max=50,
|
||||
)
|
||||
result = _run(inp, {
|
||||
"loras.0.lora_name": "model_a.safetensors", "loras.0.strength": 0.9,
|
||||
"loras.1.lora_name": "model_b.safetensors", "loras.1.strength": 0.4,
|
||||
})
|
||||
assert result["loras"] == [
|
||||
{"lora_name": "model_a.safetensors", "strength": 0.9},
|
||||
{"lora_name": "model_b.safetensors", "strength": 0.4},
|
||||
]
|
||||
|
||||
def test_rows_are_sorted_by_index(self):
|
||||
"""Rows must be in ascending index order even if dict iteration is unordered."""
|
||||
inp = List.Input("items", template=[Int.Input("v", default=0)], min=0, max=10)
|
||||
result = _run(inp, {"items.0.v": 10, "items.2.v": 30, "items.1.v": 20})
|
||||
assert [row["v"] for row in result["items"]] == [10, 20, 30]
|
||||
|
||||
def test_min_rows_schema_slots(self):
|
||||
"""With min=2 and no live data, 2 slots must appear in the expanded schema."""
|
||||
inp = List.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
|
||||
out, _, _ = get_finalized_class_inputs(_make_class_inputs(inp), {})
|
||||
all_slots = {**out.get("required", {}), **out.get("optional", {})}
|
||||
assert "items.0.val" in all_slots
|
||||
assert "items.1.val" in all_slots
|
||||
|
||||
def test_min_rows_reconstructs_when_no_values(self):
|
||||
"""min=2 with NO live values must still yield a 2-element list,
|
||||
not collapse to [] (regression: parent-path clobber)."""
|
||||
inp = List.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
|
||||
result = _run(inp, {})
|
||||
assert len(result["items"]) == 2
|
||||
assert all("val" in row for row in result["items"])
|
||||
|
||||
def test_min_rows_reconstructs_with_partial_values(self):
|
||||
"""min=2 with only the first row's value present still yields 2 rows."""
|
||||
inp = List.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
|
||||
result = _run(inp, {"items.0.val": 0.7})
|
||||
assert len(result["items"]) == 2
|
||||
assert result["items"][0]["val"] == 0.7
|
||||
assert result["items"][1]["val"] is None
|
||||
|
||||
def test_list_paths_in_v3_data(self):
|
||||
"""list_paths must contain the list id so build_nested_inputs knows to convert."""
|
||||
inp = List.Input("things", template=[Boolean.Input("flag")], min=0, max=5)
|
||||
_, _, v3_data = get_finalized_class_inputs(_make_class_inputs(inp), {})
|
||||
assert "things" in v3_data.get("list_paths", set())
|
||||
|
||||
def test_no_leftover_flat_keys(self):
|
||||
"""Flat keys must be consumed; only the reconstructed list remains."""
|
||||
inp = List.Input("rows", template=[Float.Input("x", default=0.0)], min=0, max=5)
|
||||
result = _run(inp, {"rows.0.x": 1.0, "rows.1.x": 2.0})
|
||||
assert "rows.0.x" not in result
|
||||
assert "rows.1.x" not in result
|
||||
assert isinstance(result["rows"], list)
|
||||
@ -1,93 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
import av
|
||||
import numpy as np
|
||||
from fractions import Fraction
|
||||
from comfy_api.latest._input_impl.video_types import VideoFromFile, VideoFromComponents
|
||||
from comfy_api.latest._util.video_types import VideoComponents
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def gradient_components():
|
||||
"""Narrow horizontal ramp (0.25..0.30) that needs more than 8 bits to stay smooth"""
|
||||
width, height, frames = 64, 64, 3
|
||||
ramp = torch.linspace(0.25, 0.30, width).view(1, 1, width, 1).expand(frames, height, width, 3)
|
||||
return VideoComponents(images=ramp.contiguous(), frame_rate=Fraction(30))
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def src8(gradient_components, tmp_path_factory):
|
||||
"""8-bit h264 mp4 (Create Video default)"""
|
||||
path = str(tmp_path_factory.mktemp("video") / "src8.mp4")
|
||||
VideoFromComponents(gradient_components).save_to(path)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def src10(gradient_components, tmp_path_factory):
|
||||
"""10-bit h264 mp4 (Create Video with bit_depth=10)"""
|
||||
path = str(tmp_path_factory.mktemp("video") / "src10.mp4")
|
||||
VideoFromComponents(gradient_components, bit_depth=10).save_to(path)
|
||||
return path
|
||||
|
||||
|
||||
def probe(path):
|
||||
"""(codec, pix_fmt, bit_depth) of the first video stream"""
|
||||
with av.open(path) as container:
|
||||
stream = container.streams.video[0]
|
||||
return (stream.codec.name, stream.format.name, max(c.bits for c in stream.format.components))
|
||||
|
||||
|
||||
def decoded_levels(path):
|
||||
"""Unique tonal levels in the first decoded frame (banding measure)"""
|
||||
with av.open(path) as container:
|
||||
frame = next(container.decode(container.streams.video[0]))
|
||||
return len(np.unique(frame.to_ndarray(format="gbrpf32le")[..., 0]))
|
||||
|
||||
|
||||
def video_packet_bytes(path):
|
||||
"""Raw video packet payloads; identical to the source's only for a true remux"""
|
||||
with av.open(path) as container:
|
||||
return [bytes(p) for p in container.demux(container.streams.video[0]) if p.size]
|
||||
|
||||
|
||||
def test_create_video_bit_depth(src8, src10):
|
||||
"""Create Video's bit_depth picks the encoded depth (default 8-bit); 10-bit reduces banding"""
|
||||
assert probe(src8) == ("h264", "yuv420p", 8)
|
||||
assert probe(src10) == ("h264", "yuv420p10le", 10)
|
||||
assert decoded_levels(src10) > 2 * decoded_levels(src8)
|
||||
|
||||
|
||||
def test_save_auto_keeps_source_depth(src8, src10, tmp_path):
|
||||
"""Save Video (no bit_depth = auto) stream-copies the source, preserving its depth byte-for-byte"""
|
||||
for name, src in [("p8", src8), ("p10", src10)]:
|
||||
path = str(tmp_path / f"{name}.mp4")
|
||||
VideoFromFile(src).save_to(path)
|
||||
assert probe(path) == probe(src)
|
||||
assert video_packet_bytes(path) == video_packet_bytes(src)
|
||||
|
||||
|
||||
def test_save_explicit_depth_reencodes(src8, src10, tmp_path):
|
||||
"""An explicit bit_depth different from the source forces a re-encode to that depth"""
|
||||
down = str(tmp_path / "down8.mp4")
|
||||
VideoFromFile(src10).save_to(down, bit_depth=8)
|
||||
assert probe(down) == ("h264", "yuv420p", 8)
|
||||
|
||||
up = str(tmp_path / "up10.mp4")
|
||||
VideoFromFile(src8).save_to(up, bit_depth=10)
|
||||
assert probe(up) == ("h264", "yuv420p10le", 10)
|
||||
|
||||
|
||||
def test_trim_keeps_source_depth(src10, tmp_path):
|
||||
"""Video Slice re-encodes (trim) but preserves the source's 10-bit depth"""
|
||||
path = str(tmp_path / "trim.mp4")
|
||||
VideoFromFile(src10).as_trimmed(start_time=0, duration=1 / 30, strict_duration=False).save_to(path)
|
||||
assert probe(path) == ("h264", "yuv420p10le", 10)
|
||||
|
||||
|
||||
def test_get_bit_depth(gradient_components, src8, src10):
|
||||
"""get_bit_depth reports a video's depth (backs the Get Video Components output)"""
|
||||
assert VideoFromFile(src8).get_bit_depth() == 8
|
||||
assert VideoFromFile(src10).get_bit_depth() == 10
|
||||
assert VideoFromComponents(gradient_components, bit_depth=10).get_bit_depth() == 10
|
||||
assert VideoFromComponents(gradient_components).get_bit_depth() == 8
|
||||
Reference in New Issue
Block a user