Compare commits

..

4 Commits

Author SHA1 Message Date
9a98cdc389 feat: declarative input validation with opt-in runtime enforcement
- Add `minLength`/`maxLength` to `IO.String.Input`, mirroring existing `min`/`max` for `Int`/`Float`.
- Add `runtime_input_validation` to V3 `Schema` (and `RUNTIME_INPUT_VALIDATION` class attribute for V1 nodes). Default `False`

Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-14 20:56:09 +03:00
3d870ff51f chore: update workflow templates to v0.9.77 (#13895) 2026-05-15 01:25:18 +08:00
1f28908d6e Make audio processing nodes handle None -inputs (#13879) 2026-05-14 10:51:35 +08:00
fb51a988b6 Add test that each model has unique identifiers CORE-134 (#13654) 2026-05-14 10:41:25 +08:00
10 changed files with 416 additions and 548 deletions

View File

@ -327,11 +327,14 @@ class String(ComfyTypeIO):
'''String input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None,
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None,
min_length: int=None, max_length: int=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.multiline = multiline
self.placeholder = placeholder
self.dynamic_prompts = dynamic_prompts
self.min_length = min_length
self.max_length = max_length
self.default: str
def as_dict(self):
@ -339,6 +342,8 @@ class String(ComfyTypeIO):
"multiline": self.multiline,
"placeholder": self.placeholder,
"dynamicPrompts": self.dynamic_prompts,
"minLength": self.min_length,
"maxLength": self.max_length,
})
@comfytype(io_type="COMBO")
@ -1551,6 +1556,12 @@ class Schema:
Use this for nodes with interactive/operable UI regions that produce intermediate outputs
(e.g., Image Crop, Painter) rather than final outputs (e.g., Save Image).
"""
runtime_input_validation: bool = False
"""Opt this node into runtime validation of declared input bounds (STRING minLength/maxLength,
INT/FLOAT min/max, COMBO membership) against resolved values, including values that arrive via links.
When False, only direct widget values are validated pre-execution and linked values flow through unchecked.
"""
def validate(self):
'''Validate the schema:
@ -2006,6 +2017,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls.GET_SCHEMA()
return cls._ACCEPT_ALL_INPUTS
_RUNTIME_INPUT_VALIDATION = None
@final
@classproperty
def RUNTIME_INPUT_VALIDATION(cls): # noqa
if cls._RUNTIME_INPUT_VALIDATION is None:
cls.GET_SCHEMA()
return cls._RUNTIME_INPUT_VALIDATION
@final
@classmethod
def INPUT_TYPES(cls) -> dict[str, dict]:
@ -2050,6 +2069,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls._NOT_IDEMPOTENT = schema.not_idempotent
if cls._ACCEPT_ALL_INPUTS is None:
cls._ACCEPT_ALL_INPUTS = schema.accept_all_inputs
if cls._RUNTIME_INPUT_VALIDATION is None:
cls._RUNTIME_INPUT_VALIDATION = schema.runtime_input_validation
if cls._RETURN_TYPES is None:
output = []

View File

@ -82,6 +82,8 @@ class VAEEncodeAudio(IO.ComfyNode):
@classmethod
def execute(cls, vae, audio) -> IO.NodeOutput:
if audio is None:
raise ValueError("VAEEncodeAudio: input audio is None (source video may have no audio track).")
sample_rate = audio["sample_rate"]
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
if vae_sample_rate != sample_rate:
@ -171,6 +173,8 @@ class SaveAudio(IO.ComfyNode):
@classmethod
def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput:
if audio is None:
raise ValueError("SaveAudio: input audio is None (source video may have no audio track).")
return IO.NodeOutput(
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format)
)
@ -198,6 +202,8 @@ class SaveAudioMP3(IO.ComfyNode):
@classmethod
def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput:
if audio is None:
raise ValueError("SaveAudioMP3: input audio is None (source video may have no audio track).")
return IO.NodeOutput(
ui=UI.AudioSaveHelper.get_save_audio_ui(
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
@ -226,6 +232,8 @@ class SaveAudioOpus(IO.ComfyNode):
@classmethod
def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") -> IO.NodeOutput:
if audio is None:
raise ValueError("SaveAudioOpus: input audio is None (source video may have no audio track).")
return IO.NodeOutput(
ui=UI.AudioSaveHelper.get_save_audio_ui(
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
@ -252,6 +260,8 @@ class PreviewAudio(IO.ComfyNode):
@classmethod
def execute(cls, audio) -> IO.NodeOutput:
if audio is None:
raise ValueError("PreviewAudio: input audio is None (source video may have no audio track).")
return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls))
save_flac = execute # TODO: remove
@ -392,21 +402,26 @@ class TrimAudioDuration(IO.ComfyNode):
@classmethod
def execute(cls, audio, start_index, duration) -> IO.NodeOutput:
if audio is None:
return IO.NodeOutput(None)
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
audio_length = waveform.shape[-1]
if audio_length == 0:
return IO.NodeOutput(audio)
if start_index < 0:
start_frame = audio_length + int(round(start_index * sample_rate))
else:
start_frame = int(round(start_index * sample_rate))
start_frame = max(0, min(start_frame, audio_length - 1))
start_frame = max(0, min(start_frame, audio_length))
end_frame = start_frame + int(round(duration * sample_rate))
end_frame = max(0, min(end_frame, audio_length))
if start_frame >= end_frame:
raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
raise ValueError("TrimAudioDuration: Start time must be less than end time and be within the audio length.")
return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate})
@ -433,11 +448,13 @@ class SplitAudioChannels(IO.ComfyNode):
@classmethod
def execute(cls, audio) -> IO.NodeOutput:
if audio is None:
return IO.NodeOutput(None, None)
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
if waveform.shape[1] != 2:
raise ValueError("AudioSplit: Input audio has only one channel.")
raise ValueError(f"AudioSplit: Input audio must be stereo (2 channels), got {waveform.shape[1]} channel(s).")
left_channel = waveform[..., 0:1, :]
right_channel = waveform[..., 1:2, :]
@ -465,6 +482,12 @@ class JoinAudioChannels(IO.ComfyNode):
@classmethod
def execute(cls, audio_left, audio_right) -> IO.NodeOutput:
if audio_left is None and audio_right is None:
return IO.NodeOutput(None)
if audio_left is None:
return IO.NodeOutput(audio_right)
if audio_right is None:
return IO.NodeOutput(audio_left)
waveform_left = audio_left["waveform"]
sample_rate_left = audio_left["sample_rate"]
waveform_right = audio_right["waveform"]
@ -538,6 +561,12 @@ class AudioConcat(IO.ComfyNode):
@classmethod
def execute(cls, audio1, audio2, direction) -> IO.NodeOutput:
if audio1 is None and audio2 is None:
return IO.NodeOutput(None)
if audio1 is None:
return IO.NodeOutput(audio2)
if audio2 is None:
return IO.NodeOutput(audio1)
waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"]
@ -585,6 +614,12 @@ class AudioMerge(IO.ComfyNode):
@classmethod
def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput:
if audio1 is None and audio2 is None:
return IO.NodeOutput(None)
if audio1 is None:
return IO.NodeOutput(audio2)
if audio2 is None:
return IO.NodeOutput(audio1)
waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"]
@ -595,6 +630,9 @@ class AudioMerge(IO.ComfyNode):
length_1 = waveform_1.shape[-1]
length_2 = waveform_2.shape[-1]
if length_1 == 0 or length_2 == 0:
return IO.NodeOutput({"waveform": waveform_1, "sample_rate": output_sample_rate})
if length_2 > length_1:
logging.info(f"AudioMerge: Trimming audio2 from {length_2} to {length_1} samples to match audio1 length.")
waveform_2 = waveform_2[..., :length_1]
@ -646,6 +684,8 @@ class AudioAdjustVolume(IO.ComfyNode):
@classmethod
def execute(cls, audio, volume) -> IO.NodeOutput:
if audio is None:
return IO.NodeOutput(None)
if volume == 0:
return IO.NodeOutput(audio)
waveform = audio["waveform"]
@ -729,8 +769,14 @@ class AudioEqualizer3Band(IO.ComfyNode):
@classmethod
def execute(cls, audio, low_gain_dB, low_freq, mid_gain_dB, mid_freq, mid_q, high_gain_dB, high_freq) -> IO.NodeOutput:
if audio is None:
return IO.NodeOutput(None)
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
if waveform.shape[-1] == 0:
return IO.NodeOutput(audio)
eq_waveform = waveform.clone()
# 1. Apply Low Shelf (Bass)

View File

@ -1,249 +0,0 @@
"""
SaveImagePromotable: a pass-through SaveImage variant with accumulating previews
and a "promote/lock" feature.
Modes:
- Pass-through (default): saves incoming images, emits preview UI, returns the
input tensor as output. With `accumulate=True`, the frontend appends previews
to a gallery instead of replacing it.
- Locked: when `promoted_asset_ref` is a non-empty JSON ref to a saved asset,
the node skips saving, loads the referenced image, and outputs that image.
The frontend is expected to write the ref into the widget when the user
clicks the "lock" UI on a preview.
Caching: IS_CHANGED returns a stable key derived from the ref (+ file mtime)
when locked, so re-queues with the same lock are cache hits and upstream
ancestors are skipped. Unlocked, IS_CHANGED returns False to defer to normal
input-signature caching.
"""
from __future__ import annotations
import json
import os
import numpy as np
import torch
from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo
import folder_paths
import node_helpers
from comfy.cli_args import args
def _parse_promoted_ref(promoted_asset_ref: str) -> dict | None:
if not promoted_asset_ref:
return None
try:
ref = json.loads(promoted_asset_ref)
except (json.JSONDecodeError, TypeError):
return None
if not isinstance(ref, dict):
return None
filename = ref.get("filename")
if not isinstance(filename, str) or not filename:
return None
subfolder = ref.get("subfolder", "") or ""
asset_type = ref.get("type", "output") or "output"
if not isinstance(subfolder, str) or not isinstance(asset_type, str):
return None
# Reject anything that could escape the base directory.
if os.path.isabs(subfolder) or ".." in subfolder.split(os.sep):
return None
if os.path.isabs(filename) or ".." in filename.split(os.sep):
return None
return {"filename": filename, "subfolder": subfolder, "type": asset_type}
def _resolve_ref_path(ref: dict) -> str | None:
asset_type = ref["type"]
if asset_type == "output":
base = folder_paths.get_output_directory()
elif asset_type == "input":
base = folder_paths.get_input_directory()
elif asset_type == "temp":
base = folder_paths.get_temp_directory()
else:
return None
path = os.path.join(base, ref["subfolder"], ref["filename"])
# Defense-in-depth: ensure the resolved path stays inside the base dir.
base_real = os.path.realpath(base)
path_real = os.path.realpath(path)
if not path_real.startswith(base_real + os.sep) and path_real != base_real:
return None
if not os.path.isfile(path_real):
return None
return path_real
def _load_image_tensor(path: str) -> torch.Tensor:
img = node_helpers.pillow(Image.open, path)
output_images: list[torch.Tensor] = []
w: int | None = None
h: int | None = None
for frame in ImageSequence.Iterator(img):
frame = node_helpers.pillow(ImageOps.exif_transpose, frame)
image = frame.convert("RGB")
if not output_images:
w, h = image.size
if image.size != (w, h):
continue
arr = np.array(image).astype(np.float32) / 255.0
output_images.append(torch.from_numpy(arr)[None,])
if not output_images:
raise RuntimeError(f"Failed to decode any frames from {path}")
return torch.cat(output_images, dim=0)
class SaveImagePromotable:
"""Pass-through SaveImage with accumulating previews and promote/lock.
Inputs:
images: IMAGE tensor to save + pass through (ignored when locked).
filename_prefix: STRING prefix for saved files.
accumulate: BOOLEAN — when True, frontend appends previews to gallery.
promoted_asset_ref: STRING — JSON ref written by the frontend on lock.
Empty string means "not locked, normal pass-through".
Output:
IMAGE — input pass-through, or the loaded promoted image when locked.
"""
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.compress_level = 4
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": (
"IMAGE",
{
"tooltip": "Images to save and pass through. Ignored when a promoted asset is locked."
},
),
"filename_prefix": (
"STRING",
{"default": "ComfyUI", "tooltip": "Prefix for saved files."},
),
"accumulate": (
"BOOLEAN",
{
"default": False,
"tooltip": "When enabled, previews append to a per-node gallery instead of replacing it.",
},
),
},
"optional": {
"promoted_asset_ref": (
"STRING",
{
"default": "",
"multiline": False,
"tooltip": "JSON ref to a saved asset. Set by the UI; do not edit manually.",
},
),
},
"hidden": {
"prompt": "PROMPT",
"extra_pnginfo": "EXTRA_PNGINFO",
},
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("images",)
FUNCTION = "execute"
OUTPUT_NODE = True
CATEGORY = "image"
DESCRIPTION = "Saves images, shows accumulating previews, and passes the input through. A promoted (locked) preview overrides pass-through to output the chosen image."
def _save_images(self, images, filename_prefix, prompt, extra_pnginfo):
full_output_folder, filename, counter, subfolder, _ = (
folder_paths.get_save_image_path(
filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]
)
)
results: list[dict] = []
for batch_number, image in enumerate(images):
arr = 255.0 * image.cpu().numpy()
img = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))
metadata: PngInfo | None = None
if not args.disable_metadata:
metadata = PngInfo()
if prompt is not None:
metadata.add_text("prompt", json.dumps(prompt))
if extra_pnginfo is not None:
for key in extra_pnginfo:
metadata.add_text(key, json.dumps(extra_pnginfo[key]))
filename_with_batch = filename.replace("%batch_num%", str(batch_number))
out_name = f"{filename_with_batch}_{counter:05}_.png"
img.save(
os.path.join(full_output_folder, out_name),
pnginfo=metadata,
compress_level=self.compress_level,
)
results.append(
{"filename": out_name, "subfolder": subfolder, "type": self.type}
)
counter += 1
return results
def execute(
self,
images,
filename_prefix="ComfyUI",
accumulate=False, # noqa: ARG002
promoted_asset_ref="",
prompt=None,
extra_pnginfo=None,
):
ref = _parse_promoted_ref(promoted_asset_ref)
if ref is not None:
path = _resolve_ref_path(ref)
if path is not None:
tensor = _load_image_tensor(path)
tensor = tensor.to(device=images.device, dtype=images.dtype)
return {
"ui": {"images": [ref]},
"result": (tensor,),
}
# Ref is set but stale (file deleted / failed validation): fall
# through to pass-through so the user gets a working graph rather
# than an execution error.
saved = self._save_images(images, filename_prefix, prompt, extra_pnginfo)
return {"ui": {"images": saved}, "result": (images,)}
@classmethod
def IS_CHANGED(
cls,
images, # noqa: ARG003
filename_prefix="ComfyUI",
accumulate=False, # noqa: ARG003
promoted_asset_ref="",
prompt=None, # noqa: ARG003
extra_pnginfo=None, # noqa: ARG003
):
ref = _parse_promoted_ref(promoted_asset_ref)
if ref is None:
return False
path = _resolve_ref_path(ref)
if path is None:
return f"PROMOTED::MISSING::{promoted_asset_ref}"
try:
stat = os.stat(path)
sig = f"{stat.st_size}:{stat.st_mtime_ns}"
except OSError:
sig = "NOSTAT"
return f"PROMOTED::{promoted_asset_ref}::{sig}"
NODE_CLASS_MAPPINGS = {
"SaveImagePromotable": SaveImagePromotable,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"SaveImagePromotable": "Save Image (Promotable, PoC)",
}

View File

@ -83,7 +83,7 @@ class IsChangedCache:
return self.is_changed[node_id]
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
input_data_all, _, v3_data, _ = get_input_data(node["inputs"], class_def, node_id, None)
try:
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data)
is_changed = await resolve_map_node_over_list_results(is_changed)
@ -215,7 +215,52 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
v3_data["hidden_inputs"] = hidden_inputs_v3
return input_data_all, missing_keys, v3_data
return input_data_all, missing_keys, v3_data, valid_inputs
def _check_resolved_input_bounds(name, val, input_type, extra_info):
"""Raise ValueError if a single resolved value violates declared bounds."""
if input_type == "STRING":
if not isinstance(val, str):
return
min_length = extra_info.get("minLength")
max_length = extra_info.get("maxLength")
if min_length is not None and len(val) < min_length:
raise ValueError(f"Input '{name}': string length {len(val)} is shorter than minLength of {min_length}")
if max_length is not None and len(val) > max_length:
raise ValueError(f"Input '{name}': string length {len(val)} is longer than maxLength of {max_length}")
elif input_type in ("INT", "FLOAT"):
if isinstance(val, bool) or not isinstance(val, (int, float)):
return
min_v = extra_info.get("min")
max_v = extra_info.get("max")
if min_v is not None and val < min_v:
raise ValueError(f"Input '{name}': value {val} is smaller than min of {min_v}")
if max_v is not None and val > max_v:
raise ValueError(f"Input '{name}': value {val} is bigger than max of {max_v}")
elif isinstance(input_type, list) or input_type == io.Combo.io_type:
combo_options = extra_info.get("options", []) if input_type == io.Combo.io_type else input_type
is_multiselect = extra_info.get("multiselect", False)
if is_multiselect and isinstance(val, list):
invalid_vals = [v for v in val if v not in combo_options]
else:
invalid_vals = [val] if val not in combo_options else []
if invalid_vals:
raise ValueError(f"Input '{name}': value(s) {invalid_vals} not in combo options")
def _validate_resolved_inputs(class_def, input_data_all, valid_inputs):
"""Enforce declared input bounds against resolved values, including values that arrive via links."""
if not getattr(class_def, "RUNTIME_INPUT_VALIDATION", False):
return
for x, values in input_data_all.items():
input_type, _, extra_info = get_input_info(class_def, x, valid_inputs)
if input_type is None or extra_info is None:
continue
for val in values:
if val is None:
continue
_check_resolved_input_bounds(x, val, input_type, extra_info)
map_node_over_list = None #Don't hook this please
@ -480,7 +525,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
has_subgraph = False
else:
get_progress_state().start_progress(unique_id)
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
input_data_all, missing_keys, v3_data, valid_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@ -509,6 +554,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
execution_list.make_input_strong_link(unique_id, i)
return (ExecutionResult.PENDING, None, None)
_validate_resolved_inputs(class_def, input_data_all, valid_inputs)
def execution_block_cb(block):
if block.message is not None:
mes = {
@ -1014,6 +1061,36 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
errors.append(error)
continue
if input_type == "STRING":
min_length = extra_info.get("minLength")
max_length = extra_info.get("maxLength")
if min_length is not None and len(val) < min_length:
error = {
"type": "value_shorter_than_min_length",
"message": f"Value length {len(val)} shorter than min length of {min_length}",
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if max_length is not None and len(val) > max_length:
error = {
"type": "value_longer_than_max_length",
"message": f"Value length {len(val)} longer than max length of {max_length}",
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if isinstance(input_type, list) or input_type == io.Combo.io_type:
if input_type == io.Combo.io_type:
combo_options = extra_info.get("options", [])
@ -1050,7 +1127,7 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
continue
if len(validate_function_inputs) > 0 or validate_has_kwargs:
input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id)
input_data_all, _, v3_data, _ = get_input_data(inputs, obj_class, unique_id)
input_filtered = {}
for x in input_data_all:
if x in validate_function_inputs or validate_has_kwargs:

View File

@ -2397,7 +2397,6 @@ async def init_builtin_extra_nodes():
"nodes_fresca.py",
"nodes_apg.py",
"nodes_preview_any.py",
"nodes_save_image_promotable.py",
"nodes_ace.py",
"nodes_string.py",
"nodes_camera_trajectory.py",

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.43.18
comfyui-workflow-templates==0.9.75
comfyui-workflow-templates==0.9.77
comfyui-embedded-docs==0.5.0
torch
torchsde

View File

@ -1,289 +0,0 @@
import json
import os
from unittest.mock import MagicMock, patch
import numpy as np
import torch
from PIL import Image
mock_nodes = MagicMock()
mock_nodes.MAX_RESOLUTION = 16384
mock_server = MagicMock()
with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}):
from comfy_extras import nodes_save_image_promotable as mod
def _make_image(width=8, height=4):
return torch.rand(1, height, width, 3)
def _write_png(path: str, width=8, height=4):
arr = (np.random.rand(height, width, 3) * 255).astype(np.uint8)
Image.fromarray(arr).save(path)
class TestParseRef:
def test_empty(self):
assert mod._parse_promoted_ref("") is None
def test_invalid_json(self):
assert mod._parse_promoted_ref("{not json") is None
def test_non_object(self):
assert mod._parse_promoted_ref('"a string"') is None
assert mod._parse_promoted_ref("[]") is None
def test_missing_filename(self):
assert mod._parse_promoted_ref('{"subfolder":"x","type":"output"}') is None
def test_path_traversal_filename(self):
ref = json.dumps(
{"filename": "../etc/passwd", "subfolder": "", "type": "output"}
)
assert mod._parse_promoted_ref(ref) is None
def test_path_traversal_subfolder(self):
ref = json.dumps({"filename": "x.png", "subfolder": "../..", "type": "output"})
assert mod._parse_promoted_ref(ref) is None
def test_absolute_filename(self):
ref = json.dumps({"filename": "/etc/passwd", "subfolder": "", "type": "output"})
assert mod._parse_promoted_ref(ref) is None
def test_valid(self):
ref = json.dumps({"filename": "x.png", "subfolder": "sub", "type": "output"})
parsed = mod._parse_promoted_ref(ref)
assert parsed == {"filename": "x.png", "subfolder": "sub", "type": "output"}
def test_defaults_applied(self):
ref = json.dumps({"filename": "x.png"})
parsed = mod._parse_promoted_ref(ref)
assert parsed == {"filename": "x.png", "subfolder": "", "type": "output"}
class TestResolveRefPath:
def test_unknown_type(self, tmp_path):
with patch.object(
mod.folder_paths, "get_output_directory", return_value=str(tmp_path)
):
assert (
mod._resolve_ref_path(
{"filename": "x.png", "subfolder": "", "type": "garbage"}
)
is None
)
def test_missing_file(self, tmp_path):
with patch.object(
mod.folder_paths, "get_output_directory", return_value=str(tmp_path)
):
assert (
mod._resolve_ref_path(
{"filename": "missing.png", "subfolder": "", "type": "output"}
)
is None
)
def test_resolves_file(self, tmp_path):
target = tmp_path / "img.png"
_write_png(str(target))
with patch.object(
mod.folder_paths, "get_output_directory", return_value=str(tmp_path)
):
resolved = mod._resolve_ref_path(
{"filename": "img.png", "subfolder": "", "type": "output"}
)
assert resolved is not None
assert os.path.realpath(resolved) == os.path.realpath(str(target))
def test_resolves_file_in_subfolder(self, tmp_path):
sub = tmp_path / "nested"
sub.mkdir()
target = sub / "img.png"
_write_png(str(target))
with patch.object(
mod.folder_paths, "get_output_directory", return_value=str(tmp_path)
):
resolved = mod._resolve_ref_path(
{"filename": "img.png", "subfolder": "nested", "type": "output"}
)
assert resolved is not None
assert os.path.realpath(resolved) == os.path.realpath(str(target))
def test_symlink_escape_rejected(self, tmp_path):
outside = tmp_path / "outside"
outside.mkdir()
secret = outside / "secret.png"
_write_png(str(secret))
base = tmp_path / "base"
base.mkdir()
link = base / "link.png"
os.symlink(str(secret), str(link))
with patch.object(
mod.folder_paths, "get_output_directory", return_value=str(base)
):
resolved = mod._resolve_ref_path(
{"filename": "link.png", "subfolder": "", "type": "output"}
)
assert resolved is None
class TestNodeContract:
def test_input_types_shape(self):
inp = mod.SaveImagePromotable.INPUT_TYPES()
assert set(inp["required"].keys()) == {
"images",
"filename_prefix",
"accumulate",
}
assert set(inp["optional"].keys()) == {"promoted_asset_ref"}
assert set(inp["hidden"].keys()) == {"prompt", "extra_pnginfo"}
assert inp["required"]["accumulate"][0] == "BOOLEAN"
assert inp["required"]["accumulate"][1]["default"] is False
def test_class_metadata(self):
cls = mod.SaveImagePromotable
assert cls.RETURN_TYPES == ("IMAGE",)
assert cls.RETURN_NAMES == ("images",)
assert cls.OUTPUT_NODE is True
assert cls.FUNCTION == "execute"
assert "SaveImagePromotable" in mod.NODE_CLASS_MAPPINGS
assert mod.NODE_CLASS_MAPPINGS["SaveImagePromotable"] is cls
class TestExecutePassthrough:
def test_passthrough_saves_and_returns_input(self, tmp_path):
node = mod.SaveImagePromotable()
node.output_dir = str(tmp_path)
images = _make_image()
with (
patch.object(mod.args, "disable_metadata", True),
patch.object(mod.folder_paths, "get_save_image_path") as get_path,
):
get_path.return_value = (str(tmp_path), "ComfyUI", 1, "", "ComfyUI")
result = node.execute(
images,
filename_prefix="ComfyUI",
accumulate=False,
promoted_asset_ref="",
)
assert "ui" in result
assert "result" in result
assert torch.equal(result["result"][0], images)
assert len(result["ui"]["images"]) == 1
saved_name = result["ui"]["images"][0]["filename"]
assert os.path.isfile(os.path.join(str(tmp_path), saved_name))
def test_stale_ref_falls_through_to_passthrough(self, tmp_path):
node = mod.SaveImagePromotable()
node.output_dir = str(tmp_path)
images = _make_image()
ref = json.dumps(
{"filename": "does_not_exist.png", "subfolder": "", "type": "output"}
)
with (
patch.object(mod.args, "disable_metadata", True),
patch.object(mod.folder_paths, "get_save_image_path") as get_path,
patch.object(
mod.folder_paths, "get_output_directory", return_value=str(tmp_path)
),
):
get_path.return_value = (str(tmp_path), "ComfyUI", 1, "", "ComfyUI")
result = node.execute(images, promoted_asset_ref=ref)
assert torch.equal(result["result"][0], images)
class TestExecuteLocked:
def test_locked_outputs_loaded_image(self, tmp_path):
target = tmp_path / "promoted.png"
_write_png(str(target), width=8, height=4)
ref = json.dumps(
{"filename": "promoted.png", "subfolder": "", "type": "output"}
)
node = mod.SaveImagePromotable()
node.output_dir = str(tmp_path)
upstream = _make_image(width=8, height=4)
with patch.object(
mod.folder_paths, "get_output_directory", return_value=str(tmp_path)
):
result = node.execute(upstream, promoted_asset_ref=ref)
assert result["ui"]["images"] == [
{"filename": "promoted.png", "subfolder": "", "type": "output"}
]
out = result["result"][0]
assert out.shape == upstream.shape
assert out.dtype == upstream.dtype
assert not torch.equal(out, upstream)
def test_locked_does_not_save(self, tmp_path):
target = tmp_path / "promoted.png"
_write_png(str(target))
ref = json.dumps(
{"filename": "promoted.png", "subfolder": "", "type": "output"}
)
node = mod.SaveImagePromotable()
node.output_dir = str(tmp_path)
images = _make_image()
with (
patch.object(
mod.folder_paths, "get_output_directory", return_value=str(tmp_path)
),
patch.object(node, "_save_images") as save_mock,
):
node.execute(images, promoted_asset_ref=ref)
save_mock.assert_not_called()
class TestIsChanged:
def test_unlocked_returns_false(self):
assert (
mod.SaveImagePromotable.IS_CHANGED(images=None, promoted_asset_ref="")
is False
)
def test_locked_missing_file(self, tmp_path):
ref = json.dumps({"filename": "missing.png", "subfolder": "", "type": "output"})
with patch.object(
mod.folder_paths, "get_output_directory", return_value=str(tmp_path)
):
key = mod.SaveImagePromotable.IS_CHANGED(
images=None, promoted_asset_ref=ref
)
assert isinstance(key, str)
assert key.startswith("PROMOTED::MISSING::")
def test_locked_stable_key(self, tmp_path):
target = tmp_path / "p.png"
_write_png(str(target))
ref = json.dumps({"filename": "p.png", "subfolder": "", "type": "output"})
with patch.object(
mod.folder_paths, "get_output_directory", return_value=str(tmp_path)
):
k1 = mod.SaveImagePromotable.IS_CHANGED(images=None, promoted_asset_ref=ref)
k2 = mod.SaveImagePromotable.IS_CHANGED(images=None, promoted_asset_ref=ref)
assert k1 == k2
assert k1.startswith("PROMOTED::")
def test_locked_key_changes_when_file_changes(self, tmp_path):
target = tmp_path / "p.png"
_write_png(str(target), width=8, height=4)
ref = json.dumps({"filename": "p.png", "subfolder": "", "type": "output"})
with patch.object(
mod.folder_paths, "get_output_directory", return_value=str(tmp_path)
):
k1 = mod.SaveImagePromotable.IS_CHANGED(images=None, promoted_asset_ref=ref)
os.utime(str(target), (1234567890, 1234567890))
with patch.object(
mod.folder_paths, "get_output_directory", return_value=str(tmp_path)
):
k2 = mod.SaveImagePromotable.IS_CHANGED(images=None, promoted_asset_ref=ref)
assert k1 != k2

View File

@ -1,9 +1,23 @@
from collections import defaultdict
import torch
from comfy.model_detection import detect_unet_config, model_config_from_unet_config
import comfy.supported_models
def _freeze(value):
"""Recursively convert a value to a hashable form so configs can be
compared/used as dict keys or set members."""
if isinstance(value, dict):
return frozenset((k, _freeze(v)) for k, v in value.items())
if isinstance(value, (list, tuple)):
return tuple(_freeze(v) for v in value)
if isinstance(value, set):
return frozenset(_freeze(v) for v in value)
return value
def _make_longcat_comfyui_sd():
"""Minimal ComfyUI-format state dict for pre-converted LongCat-Image weights."""
sd = {}
@ -110,3 +124,21 @@ class TestModelDetection:
model_config = model_config_from_unet_config(unet_config, sd)
assert model_config is not None
assert type(model_config).__name__ == "FluxSchnell"
def test_unet_config_and_required_keys_combination_is_unique(self):
"""Each model in the registry must have a unique combination of
``unet_config`` and ``required_keys``. If two models share the same
combination, ``BASE.matches`` cannot disambiguate between them and the
first one in the list will always win."""
models = comfy.supported_models.models
groups = defaultdict(list)
for model in models:
key = (_freeze(model.unet_config), _freeze(model.required_keys))
groups[key].append(model.__name__)
duplicates = {k: names for k, names in groups.items() if len(names) > 1}
assert not duplicates, (
"Found models sharing the same (unet_config, required_keys) "
"combination, which makes detection ambiguous: "
+ "; ".join(", ".join(names) for names in duplicates.values())
)

View File

@ -1011,3 +1011,124 @@ class TestExecution:
"""Test getting a non-existent job returns 404"""
job = client.get_job("nonexistent-job-id")
assert job is None, "Non-existent job should return None"
@pytest.mark.parametrize("text, expect_error", [
("hello", False), # 5 chars, within [3, 10]
("abc", False), # 3 chars, exact min boundary
("abcdefghij", False), # 10 chars, exact max boundary
("ab", True), # 2 chars, below min
("abcdefghijk", True), # 11 chars, above max
("", True), # 0 chars, below min
])
def test_string_length_widget_validation(self, text, expect_error, client: ComfyClient, builder: GraphBuilder):
"""Test minLength/maxLength validation for direct widget values (validate_inputs path)."""
g = builder
node = g.node("StubStringWithLength", text=text)
g.node("SaveImage", images=node.out(0))
if expect_error:
with pytest.raises(urllib.error.HTTPError) as exc_info:
client.run(g)
assert exc_info.value.code == 400
else:
client.run(g)
@pytest.mark.parametrize("text, expect_error", [
("hello", False), # within bounds
("ab", True), # below min
("abcdefghijk", True), # above max
])
def test_string_length_linked_validation(self, text, expect_error, client: ComfyClient, builder: GraphBuilder):
"""Test minLength/maxLength validation for linked inputs when node opts in via RUNTIME_INPUT_VALIDATION=True."""
g = builder
str_node = g.node("StubStringOutput", value=text)
node = g.node("StubStringWithLength", text=str_node.out(0))
g.node("SaveImage", images=node.out(0))
if expect_error:
try:
client.run(g)
assert False, "Should have raised an error"
except Exception as e:
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
else:
client.run(g)
@pytest.mark.parametrize("text", [
"ab", # below declared minLength
"abcdefghijk", # above declared maxLength
"", # empty
"hello", # within bounds
])
def test_string_length_linked_skipped_without_flag(self, text, client: ComfyClient, builder: GraphBuilder):
"""Without RUNTIME_INPUT_VALIDATION=True, declared bounds must NOT be enforced for linked values.
Preserves V1 behavior: many existing workflows rely on out-of-bounds values passing
through links. Adding declared bounds without the flag must not break them.
"""
g = builder
str_node = g.node("StubStringOutput", value=text)
node = g.node("StubStringWithLengthNoFlag", text=str_node.out(0))
g.node("SaveImage", images=node.out(0))
client.run(g)
@pytest.mark.parametrize("value, expect_error", [
(5, False), # within [1, 10]
(1, False), # exact min boundary
(10, False), # exact max boundary
(0, True), # below min
(11, True), # above max
(-7, True), # well below min
])
def test_int_bounds_linked_validation(self, value, expect_error, client: ComfyClient, builder: GraphBuilder):
"""min/max validation for linked INT inputs when node opts in via RUNTIME_INPUT_VALIDATION=True.
Direct widget INT values are already validated pre-execution. This test exercises the
symmetric runtime path for values arriving through a connection.
"""
g = builder
int_node = g.node("StubInt", value=value)
node = g.node("StubIntWithBounds", value=int_node.out(0))
g.node("SaveImage", images=node.out(0))
if expect_error:
try:
client.run(g)
assert False, "Should have raised an error"
except Exception as e:
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
else:
client.run(g)
@pytest.mark.parametrize("choice, expect_error", [
("RED", False),
("GREEN", False),
("BLUE", False),
("PURPLE", True),
("", True),
("red", True), # case-sensitive
])
def test_combo_membership_linked_validation(self, choice, expect_error, client: ComfyClient, builder: GraphBuilder):
"""COMBO option membership for linked values when node opts in via RUNTIME_INPUT_VALIDATION=True.
StubComboWithOptions declares ``input_types`` in VALIDATE_INPUTS to bypass the engine's
link-type compatibility check, so we can feed a STRING into a COMBO and verify the
runtime membership check fires.
"""
g = builder
str_node = g.node("StubStringOutput", value=choice)
node = g.node("StubComboWithOptions", choice=str_node.out(0))
g.node("SaveImage", images=node.out(0))
if expect_error:
try:
client.run(g)
assert False, "Should have raised an error"
except Exception as e:
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
else:
client.run(g)

View File

@ -113,12 +113,117 @@ class StubFloat:
def stub_float(self, value):
return (value,)
class StubStringOutput:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("STRING", {"default": ""}),
},
}
RETURN_TYPES = ("STRING",)
FUNCTION = "stub_string"
CATEGORY = "Testing/Stub Nodes"
def stub_string(self, value):
return (value,)
class StubStringWithLength:
"""STRING input with declared bounds AND opted in to runtime validation (RUNTIME_INPUT_VALIDATION = True)."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"text": ("STRING", {"default": "hello", "minLength": 3, "maxLength": 10}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stub_string_with_length"
RUNTIME_INPUT_VALIDATION = True
CATEGORY = "Testing/Stub Nodes"
def stub_string_with_length(self, text):
return (torch.zeros(1, 64, 64, 3),)
class StubStringWithLengthNoFlag:
"""Same bounds as StubStringWithLength but NOT opted in - linked values must flow through unchecked."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"text": ("STRING", {"default": "hello", "minLength": 3, "maxLength": 10}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stub_string_with_length_no_flag"
CATEGORY = "Testing/Stub Nodes"
def stub_string_with_length_no_flag(self, text):
return (torch.zeros(1, 64, 64, 3),)
class StubIntWithBounds:
"""INT input with min/max bounds AND opted in to runtime validation."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("INT", {"default": 5, "min": 1, "max": 10}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stub_int_with_bounds"
RUNTIME_INPUT_VALIDATION = True
CATEGORY = "Testing/Stub Nodes"
def stub_int_with_bounds(self, value):
return (torch.zeros(1, 64, 64, 3),)
class StubComboWithOptions:
"""COMBO input opted in to runtime validation.
Declares ``input_types`` in VALIDATE_INPUTS to bypass the engine's link-type compatibility
check, allowing tests to link a STRING into a COMBO and exercise the runtime membership check.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"choice": (["RED", "GREEN", "BLUE"],),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stub_combo"
RUNTIME_INPUT_VALIDATION = True
CATEGORY = "Testing/Stub Nodes"
@classmethod
def VALIDATE_INPUTS(cls, input_types):
return True
def stub_combo(self, choice):
return (torch.zeros(1, 64, 64, 3),)
TEST_STUB_NODE_CLASS_MAPPINGS = {
"StubImage": StubImage,
"StubConstantImage": StubConstantImage,
"StubMask": StubMask,
"StubInt": StubInt,
"StubFloat": StubFloat,
"StubStringOutput": StubStringOutput,
"StubStringWithLength": StubStringWithLength,
"StubStringWithLengthNoFlag": StubStringWithLengthNoFlag,
"StubIntWithBounds": StubIntWithBounds,
"StubComboWithOptions": StubComboWithOptions,
}
TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
"StubImage": "Stub Image",
@ -126,4 +231,9 @@ TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
"StubMask": "Stub Mask",
"StubInt": "Stub Int",
"StubFloat": "Stub Float",
"StubStringOutput": "Stub String Output",
"StubStringWithLength": "Stub String With Length",
"StubStringWithLengthNoFlag": "Stub String With Length (No Flag)",
"StubIntWithBounds": "Stub Int With Bounds",
"StubComboWithOptions": "Stub Combo With Options",
}