mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-27 03:15:59 +08:00
Compare commits
2 Commits
feat/strin
...
glary/save
| Author | SHA1 | Date | |
|---|---|---|---|
| be9fd3545e | |||
| da90bc93e4 |
@ -327,14 +327,11 @@ class String(ComfyTypeIO):
|
||||
'''String input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None,
|
||||
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None,
|
||||
min_length: int=None, max_length: int=None):
|
||||
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
|
||||
self.multiline = multiline
|
||||
self.placeholder = placeholder
|
||||
self.dynamic_prompts = dynamic_prompts
|
||||
self.min_length = min_length
|
||||
self.max_length = max_length
|
||||
self.default: str
|
||||
|
||||
def as_dict(self):
|
||||
@ -342,8 +339,6 @@ class String(ComfyTypeIO):
|
||||
"multiline": self.multiline,
|
||||
"placeholder": self.placeholder,
|
||||
"dynamicPrompts": self.dynamic_prompts,
|
||||
"minLength": self.min_length,
|
||||
"maxLength": self.max_length,
|
||||
})
|
||||
|
||||
@comfytype(io_type="COMBO")
|
||||
@ -1556,12 +1551,6 @@ class Schema:
|
||||
Use this for nodes with interactive/operable UI regions that produce intermediate outputs
|
||||
(e.g., Image Crop, Painter) rather than final outputs (e.g., Save Image).
|
||||
"""
|
||||
runtime_input_validation: bool = False
|
||||
"""Opt this node into runtime validation of declared input bounds (STRING minLength/maxLength,
|
||||
INT/FLOAT min/max, COMBO membership) against resolved values, including values that arrive via links.
|
||||
|
||||
When False, only direct widget values are validated pre-execution and linked values flow through unchecked.
|
||||
"""
|
||||
|
||||
def validate(self):
|
||||
'''Validate the schema:
|
||||
@ -2017,14 +2006,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
cls.GET_SCHEMA()
|
||||
return cls._ACCEPT_ALL_INPUTS
|
||||
|
||||
_RUNTIME_INPUT_VALIDATION = None
|
||||
@final
|
||||
@classproperty
|
||||
def RUNTIME_INPUT_VALIDATION(cls): # noqa
|
||||
if cls._RUNTIME_INPUT_VALIDATION is None:
|
||||
cls.GET_SCHEMA()
|
||||
return cls._RUNTIME_INPUT_VALIDATION
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> dict[str, dict]:
|
||||
@ -2069,8 +2050,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
cls._NOT_IDEMPOTENT = schema.not_idempotent
|
||||
if cls._ACCEPT_ALL_INPUTS is None:
|
||||
cls._ACCEPT_ALL_INPUTS = schema.accept_all_inputs
|
||||
if cls._RUNTIME_INPUT_VALIDATION is None:
|
||||
cls._RUNTIME_INPUT_VALIDATION = schema.runtime_input_validation
|
||||
|
||||
if cls._RETURN_TYPES is None:
|
||||
output = []
|
||||
|
||||
@ -82,8 +82,6 @@ class VAEEncodeAudio(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, vae, audio) -> IO.NodeOutput:
|
||||
if audio is None:
|
||||
raise ValueError("VAEEncodeAudio: input audio is None (source video may have no audio track).")
|
||||
sample_rate = audio["sample_rate"]
|
||||
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
|
||||
if vae_sample_rate != sample_rate:
|
||||
@ -173,8 +171,6 @@ class SaveAudio(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput:
|
||||
if audio is None:
|
||||
raise ValueError("SaveAudio: input audio is None (source video may have no audio track).")
|
||||
return IO.NodeOutput(
|
||||
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format)
|
||||
)
|
||||
@ -202,8 +198,6 @@ class SaveAudioMP3(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput:
|
||||
if audio is None:
|
||||
raise ValueError("SaveAudioMP3: input audio is None (source video may have no audio track).")
|
||||
return IO.NodeOutput(
|
||||
ui=UI.AudioSaveHelper.get_save_audio_ui(
|
||||
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
|
||||
@ -232,8 +226,6 @@ class SaveAudioOpus(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") -> IO.NodeOutput:
|
||||
if audio is None:
|
||||
raise ValueError("SaveAudioOpus: input audio is None (source video may have no audio track).")
|
||||
return IO.NodeOutput(
|
||||
ui=UI.AudioSaveHelper.get_save_audio_ui(
|
||||
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
|
||||
@ -260,8 +252,6 @@ class PreviewAudio(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio) -> IO.NodeOutput:
|
||||
if audio is None:
|
||||
raise ValueError("PreviewAudio: input audio is None (source video may have no audio track).")
|
||||
return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls))
|
||||
|
||||
save_flac = execute # TODO: remove
|
||||
@ -402,26 +392,21 @@ class TrimAudioDuration(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio, start_index, duration) -> IO.NodeOutput:
|
||||
if audio is None:
|
||||
return IO.NodeOutput(None)
|
||||
waveform = audio["waveform"]
|
||||
sample_rate = audio["sample_rate"]
|
||||
audio_length = waveform.shape[-1]
|
||||
|
||||
if audio_length == 0:
|
||||
return IO.NodeOutput(audio)
|
||||
|
||||
if start_index < 0:
|
||||
start_frame = audio_length + int(round(start_index * sample_rate))
|
||||
else:
|
||||
start_frame = int(round(start_index * sample_rate))
|
||||
start_frame = max(0, min(start_frame, audio_length))
|
||||
start_frame = max(0, min(start_frame, audio_length - 1))
|
||||
|
||||
end_frame = start_frame + int(round(duration * sample_rate))
|
||||
end_frame = max(0, min(end_frame, audio_length))
|
||||
|
||||
if start_frame >= end_frame:
|
||||
raise ValueError("TrimAudioDuration: Start time must be less than end time and be within the audio length.")
|
||||
raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
|
||||
|
||||
return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate})
|
||||
|
||||
@ -448,13 +433,11 @@ class SplitAudioChannels(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio) -> IO.NodeOutput:
|
||||
if audio is None:
|
||||
return IO.NodeOutput(None, None)
|
||||
waveform = audio["waveform"]
|
||||
sample_rate = audio["sample_rate"]
|
||||
|
||||
if waveform.shape[1] != 2:
|
||||
raise ValueError(f"AudioSplit: Input audio must be stereo (2 channels), got {waveform.shape[1]} channel(s).")
|
||||
raise ValueError("AudioSplit: Input audio has only one channel.")
|
||||
|
||||
left_channel = waveform[..., 0:1, :]
|
||||
right_channel = waveform[..., 1:2, :]
|
||||
@ -482,12 +465,6 @@ class JoinAudioChannels(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio_left, audio_right) -> IO.NodeOutput:
|
||||
if audio_left is None and audio_right is None:
|
||||
return IO.NodeOutput(None)
|
||||
if audio_left is None:
|
||||
return IO.NodeOutput(audio_right)
|
||||
if audio_right is None:
|
||||
return IO.NodeOutput(audio_left)
|
||||
waveform_left = audio_left["waveform"]
|
||||
sample_rate_left = audio_left["sample_rate"]
|
||||
waveform_right = audio_right["waveform"]
|
||||
@ -561,12 +538,6 @@ class AudioConcat(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio1, audio2, direction) -> IO.NodeOutput:
|
||||
if audio1 is None and audio2 is None:
|
||||
return IO.NodeOutput(None)
|
||||
if audio1 is None:
|
||||
return IO.NodeOutput(audio2)
|
||||
if audio2 is None:
|
||||
return IO.NodeOutput(audio1)
|
||||
waveform_1 = audio1["waveform"]
|
||||
waveform_2 = audio2["waveform"]
|
||||
sample_rate_1 = audio1["sample_rate"]
|
||||
@ -614,12 +585,6 @@ class AudioMerge(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput:
|
||||
if audio1 is None and audio2 is None:
|
||||
return IO.NodeOutput(None)
|
||||
if audio1 is None:
|
||||
return IO.NodeOutput(audio2)
|
||||
if audio2 is None:
|
||||
return IO.NodeOutput(audio1)
|
||||
waveform_1 = audio1["waveform"]
|
||||
waveform_2 = audio2["waveform"]
|
||||
sample_rate_1 = audio1["sample_rate"]
|
||||
@ -630,9 +595,6 @@ class AudioMerge(IO.ComfyNode):
|
||||
length_1 = waveform_1.shape[-1]
|
||||
length_2 = waveform_2.shape[-1]
|
||||
|
||||
if length_1 == 0 or length_2 == 0:
|
||||
return IO.NodeOutput({"waveform": waveform_1, "sample_rate": output_sample_rate})
|
||||
|
||||
if length_2 > length_1:
|
||||
logging.info(f"AudioMerge: Trimming audio2 from {length_2} to {length_1} samples to match audio1 length.")
|
||||
waveform_2 = waveform_2[..., :length_1]
|
||||
@ -684,8 +646,6 @@ class AudioAdjustVolume(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio, volume) -> IO.NodeOutput:
|
||||
if audio is None:
|
||||
return IO.NodeOutput(None)
|
||||
if volume == 0:
|
||||
return IO.NodeOutput(audio)
|
||||
waveform = audio["waveform"]
|
||||
@ -769,14 +729,8 @@ class AudioEqualizer3Band(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio, low_gain_dB, low_freq, mid_gain_dB, mid_freq, mid_q, high_gain_dB, high_freq) -> IO.NodeOutput:
|
||||
if audio is None:
|
||||
return IO.NodeOutput(None)
|
||||
waveform = audio["waveform"]
|
||||
sample_rate = audio["sample_rate"]
|
||||
|
||||
if waveform.shape[-1] == 0:
|
||||
return IO.NodeOutput(audio)
|
||||
|
||||
eq_waveform = waveform.clone()
|
||||
|
||||
# 1. Apply Low Shelf (Bass)
|
||||
|
||||
249
comfy_extras/nodes_save_image_promotable.py
Normal file
249
comfy_extras/nodes_save_image_promotable.py
Normal file
@ -0,0 +1,249 @@
|
||||
"""
|
||||
SaveImagePromotable: a pass-through SaveImage variant with accumulating previews
|
||||
and a "promote/lock" feature.
|
||||
|
||||
Modes:
|
||||
- Pass-through (default): saves incoming images, emits preview UI, returns the
|
||||
input tensor as output. With `accumulate=True`, the frontend appends previews
|
||||
to a gallery instead of replacing it.
|
||||
- Locked: when `promoted_asset_ref` is a non-empty JSON ref to a saved asset,
|
||||
the node skips saving, loads the referenced image, and outputs that image.
|
||||
The frontend is expected to write the ref into the widget when the user
|
||||
clicks the "lock" UI on a preview.
|
||||
|
||||
Caching: IS_CHANGED returns a stable key derived from the ref (+ file mtime)
|
||||
when locked, so re-queues with the same lock are cache hits and upstream
|
||||
ancestors are skipped. Unlocked, IS_CHANGED returns False to defer to normal
|
||||
input-signature caching.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageOps, ImageSequence
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
import folder_paths
|
||||
import node_helpers
|
||||
from comfy.cli_args import args
|
||||
|
||||
|
||||
def _parse_promoted_ref(promoted_asset_ref: str) -> dict | None:
|
||||
if not promoted_asset_ref:
|
||||
return None
|
||||
try:
|
||||
ref = json.loads(promoted_asset_ref)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
if not isinstance(ref, dict):
|
||||
return None
|
||||
filename = ref.get("filename")
|
||||
if not isinstance(filename, str) or not filename:
|
||||
return None
|
||||
subfolder = ref.get("subfolder", "") or ""
|
||||
asset_type = ref.get("type", "output") or "output"
|
||||
if not isinstance(subfolder, str) or not isinstance(asset_type, str):
|
||||
return None
|
||||
# Reject anything that could escape the base directory.
|
||||
if os.path.isabs(subfolder) or ".." in subfolder.split(os.sep):
|
||||
return None
|
||||
if os.path.isabs(filename) or ".." in filename.split(os.sep):
|
||||
return None
|
||||
return {"filename": filename, "subfolder": subfolder, "type": asset_type}
|
||||
|
||||
|
||||
def _resolve_ref_path(ref: dict) -> str | None:
|
||||
asset_type = ref["type"]
|
||||
if asset_type == "output":
|
||||
base = folder_paths.get_output_directory()
|
||||
elif asset_type == "input":
|
||||
base = folder_paths.get_input_directory()
|
||||
elif asset_type == "temp":
|
||||
base = folder_paths.get_temp_directory()
|
||||
else:
|
||||
return None
|
||||
path = os.path.join(base, ref["subfolder"], ref["filename"])
|
||||
# Defense-in-depth: ensure the resolved path stays inside the base dir.
|
||||
base_real = os.path.realpath(base)
|
||||
path_real = os.path.realpath(path)
|
||||
if not path_real.startswith(base_real + os.sep) and path_real != base_real:
|
||||
return None
|
||||
if not os.path.isfile(path_real):
|
||||
return None
|
||||
return path_real
|
||||
|
||||
|
||||
def _load_image_tensor(path: str) -> torch.Tensor:
|
||||
img = node_helpers.pillow(Image.open, path)
|
||||
output_images: list[torch.Tensor] = []
|
||||
w: int | None = None
|
||||
h: int | None = None
|
||||
for frame in ImageSequence.Iterator(img):
|
||||
frame = node_helpers.pillow(ImageOps.exif_transpose, frame)
|
||||
image = frame.convert("RGB")
|
||||
if not output_images:
|
||||
w, h = image.size
|
||||
if image.size != (w, h):
|
||||
continue
|
||||
arr = np.array(image).astype(np.float32) / 255.0
|
||||
output_images.append(torch.from_numpy(arr)[None,])
|
||||
if not output_images:
|
||||
raise RuntimeError(f"Failed to decode any frames from {path}")
|
||||
return torch.cat(output_images, dim=0)
|
||||
|
||||
|
||||
class SaveImagePromotable:
|
||||
"""Pass-through SaveImage with accumulating previews and promote/lock.
|
||||
|
||||
Inputs:
|
||||
images: IMAGE tensor to save + pass through (ignored when locked).
|
||||
filename_prefix: STRING prefix for saved files.
|
||||
accumulate: BOOLEAN — when True, frontend appends previews to gallery.
|
||||
promoted_asset_ref: STRING — JSON ref written by the frontend on lock.
|
||||
Empty string means "not locked, normal pass-through".
|
||||
Output:
|
||||
IMAGE — input pass-through, or the loaded promoted image when locked.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.compress_level = 4
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"images": (
|
||||
"IMAGE",
|
||||
{
|
||||
"tooltip": "Images to save and pass through. Ignored when a promoted asset is locked."
|
||||
},
|
||||
),
|
||||
"filename_prefix": (
|
||||
"STRING",
|
||||
{"default": "ComfyUI", "tooltip": "Prefix for saved files."},
|
||||
),
|
||||
"accumulate": (
|
||||
"BOOLEAN",
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "When enabled, previews append to a per-node gallery instead of replacing it.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"promoted_asset_ref": (
|
||||
"STRING",
|
||||
{
|
||||
"default": "",
|
||||
"multiline": False,
|
||||
"tooltip": "JSON ref to a saved asset. Set by the UI; do not edit manually.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"prompt": "PROMPT",
|
||||
"extra_pnginfo": "EXTRA_PNGINFO",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
RETURN_NAMES = ("images",)
|
||||
FUNCTION = "execute"
|
||||
OUTPUT_NODE = True
|
||||
CATEGORY = "image"
|
||||
DESCRIPTION = "Saves images, shows accumulating previews, and passes the input through. A promoted (locked) preview overrides pass-through to output the chosen image."
|
||||
|
||||
def _save_images(self, images, filename_prefix, prompt, extra_pnginfo):
|
||||
full_output_folder, filename, counter, subfolder, _ = (
|
||||
folder_paths.get_save_image_path(
|
||||
filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]
|
||||
)
|
||||
)
|
||||
results: list[dict] = []
|
||||
for batch_number, image in enumerate(images):
|
||||
arr = 255.0 * image.cpu().numpy()
|
||||
img = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))
|
||||
metadata: PngInfo | None = None
|
||||
if not args.disable_metadata:
|
||||
metadata = PngInfo()
|
||||
if prompt is not None:
|
||||
metadata.add_text("prompt", json.dumps(prompt))
|
||||
if extra_pnginfo is not None:
|
||||
for key in extra_pnginfo:
|
||||
metadata.add_text(key, json.dumps(extra_pnginfo[key]))
|
||||
filename_with_batch = filename.replace("%batch_num%", str(batch_number))
|
||||
out_name = f"{filename_with_batch}_{counter:05}_.png"
|
||||
img.save(
|
||||
os.path.join(full_output_folder, out_name),
|
||||
pnginfo=metadata,
|
||||
compress_level=self.compress_level,
|
||||
)
|
||||
results.append(
|
||||
{"filename": out_name, "subfolder": subfolder, "type": self.type}
|
||||
)
|
||||
counter += 1
|
||||
return results
|
||||
|
||||
def execute(
|
||||
self,
|
||||
images,
|
||||
filename_prefix="ComfyUI",
|
||||
accumulate=False, # noqa: ARG002
|
||||
promoted_asset_ref="",
|
||||
prompt=None,
|
||||
extra_pnginfo=None,
|
||||
):
|
||||
ref = _parse_promoted_ref(promoted_asset_ref)
|
||||
if ref is not None:
|
||||
path = _resolve_ref_path(ref)
|
||||
if path is not None:
|
||||
tensor = _load_image_tensor(path)
|
||||
tensor = tensor.to(device=images.device, dtype=images.dtype)
|
||||
return {
|
||||
"ui": {"images": [ref]},
|
||||
"result": (tensor,),
|
||||
}
|
||||
# Ref is set but stale (file deleted / failed validation): fall
|
||||
# through to pass-through so the user gets a working graph rather
|
||||
# than an execution error.
|
||||
|
||||
saved = self._save_images(images, filename_prefix, prompt, extra_pnginfo)
|
||||
return {"ui": {"images": saved}, "result": (images,)}
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(
|
||||
cls,
|
||||
images, # noqa: ARG003
|
||||
filename_prefix="ComfyUI",
|
||||
accumulate=False, # noqa: ARG003
|
||||
promoted_asset_ref="",
|
||||
prompt=None, # noqa: ARG003
|
||||
extra_pnginfo=None, # noqa: ARG003
|
||||
):
|
||||
ref = _parse_promoted_ref(promoted_asset_ref)
|
||||
if ref is None:
|
||||
return False
|
||||
path = _resolve_ref_path(ref)
|
||||
if path is None:
|
||||
return f"PROMOTED::MISSING::{promoted_asset_ref}"
|
||||
try:
|
||||
stat = os.stat(path)
|
||||
sig = f"{stat.st_size}:{stat.st_mtime_ns}"
|
||||
except OSError:
|
||||
sig = "NOSTAT"
|
||||
return f"PROMOTED::{promoted_asset_ref}::{sig}"
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SaveImagePromotable": SaveImagePromotable,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"SaveImagePromotable": "Save Image (Promotable, PoC)",
|
||||
}
|
||||
85
execution.py
85
execution.py
@ -83,7 +83,7 @@ class IsChangedCache:
|
||||
return self.is_changed[node_id]
|
||||
|
||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||
input_data_all, _, v3_data, _ = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
try:
|
||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data)
|
||||
is_changed = await resolve_map_node_over_list_results(is_changed)
|
||||
@ -215,52 +215,7 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
||||
if h[x] == "API_KEY_COMFY_ORG":
|
||||
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
||||
v3_data["hidden_inputs"] = hidden_inputs_v3
|
||||
return input_data_all, missing_keys, v3_data, valid_inputs
|
||||
|
||||
def _check_resolved_input_bounds(name, val, input_type, extra_info):
|
||||
"""Raise ValueError if a single resolved value violates declared bounds."""
|
||||
if input_type == "STRING":
|
||||
if not isinstance(val, str):
|
||||
return
|
||||
min_length = extra_info.get("minLength")
|
||||
max_length = extra_info.get("maxLength")
|
||||
if min_length is not None and len(val) < min_length:
|
||||
raise ValueError(f"Input '{name}': string length {len(val)} is shorter than minLength of {min_length}")
|
||||
if max_length is not None and len(val) > max_length:
|
||||
raise ValueError(f"Input '{name}': string length {len(val)} is longer than maxLength of {max_length}")
|
||||
elif input_type in ("INT", "FLOAT"):
|
||||
if isinstance(val, bool) or not isinstance(val, (int, float)):
|
||||
return
|
||||
min_v = extra_info.get("min")
|
||||
max_v = extra_info.get("max")
|
||||
if min_v is not None and val < min_v:
|
||||
raise ValueError(f"Input '{name}': value {val} is smaller than min of {min_v}")
|
||||
if max_v is not None and val > max_v:
|
||||
raise ValueError(f"Input '{name}': value {val} is bigger than max of {max_v}")
|
||||
elif isinstance(input_type, list) or input_type == io.Combo.io_type:
|
||||
combo_options = extra_info.get("options", []) if input_type == io.Combo.io_type else input_type
|
||||
is_multiselect = extra_info.get("multiselect", False)
|
||||
if is_multiselect and isinstance(val, list):
|
||||
invalid_vals = [v for v in val if v not in combo_options]
|
||||
else:
|
||||
invalid_vals = [val] if val not in combo_options else []
|
||||
if invalid_vals:
|
||||
raise ValueError(f"Input '{name}': value(s) {invalid_vals} not in combo options")
|
||||
|
||||
|
||||
def _validate_resolved_inputs(class_def, input_data_all, valid_inputs):
|
||||
"""Enforce declared input bounds against resolved values, including values that arrive via links."""
|
||||
if not getattr(class_def, "RUNTIME_INPUT_VALIDATION", False):
|
||||
return
|
||||
|
||||
for x, values in input_data_all.items():
|
||||
input_type, _, extra_info = get_input_info(class_def, x, valid_inputs)
|
||||
if input_type is None or extra_info is None:
|
||||
continue
|
||||
for val in values:
|
||||
if val is None:
|
||||
continue
|
||||
_check_resolved_input_bounds(x, val, input_type, extra_info)
|
||||
return input_data_all, missing_keys, v3_data
|
||||
|
||||
map_node_over_list = None #Don't hook this please
|
||||
|
||||
@ -525,7 +480,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
has_subgraph = False
|
||||
else:
|
||||
get_progress_state().start_progress(unique_id)
|
||||
input_data_all, missing_keys, v3_data, valid_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
|
||||
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = display_node_id
|
||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||
@ -554,8 +509,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
execution_list.make_input_strong_link(unique_id, i)
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
|
||||
_validate_resolved_inputs(class_def, input_data_all, valid_inputs)
|
||||
|
||||
def execution_block_cb(block):
|
||||
if block.message is not None:
|
||||
mes = {
|
||||
@ -1061,36 +1014,6 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if input_type == "STRING":
|
||||
min_length = extra_info.get("minLength")
|
||||
max_length = extra_info.get("maxLength")
|
||||
if min_length is not None and len(val) < min_length:
|
||||
error = {
|
||||
"type": "value_shorter_than_min_length",
|
||||
"message": f"Value length {len(val)} shorter than min length of {min_length}",
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
if max_length is not None and len(val) > max_length:
|
||||
error = {
|
||||
"type": "value_longer_than_max_length",
|
||||
"message": f"Value length {len(val)} longer than max length of {max_length}",
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if isinstance(input_type, list) or input_type == io.Combo.io_type:
|
||||
if input_type == io.Combo.io_type:
|
||||
combo_options = extra_info.get("options", [])
|
||||
@ -1127,7 +1050,7 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
|
||||
continue
|
||||
|
||||
if len(validate_function_inputs) > 0 or validate_has_kwargs:
|
||||
input_data_all, _, v3_data, _ = get_input_data(inputs, obj_class, unique_id)
|
||||
input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id)
|
||||
input_filtered = {}
|
||||
for x in input_data_all:
|
||||
if x in validate_function_inputs or validate_has_kwargs:
|
||||
|
||||
1
nodes.py
1
nodes.py
@ -2397,6 +2397,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_fresca.py",
|
||||
"nodes_apg.py",
|
||||
"nodes_preview_any.py",
|
||||
"nodes_save_image_promotable.py",
|
||||
"nodes_ace.py",
|
||||
"nodes_string.py",
|
||||
"nodes_camera_trajectory.py",
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.43.18
|
||||
comfyui-workflow-templates==0.9.77
|
||||
comfyui-workflow-templates==0.9.75
|
||||
comfyui-embedded-docs==0.5.0
|
||||
torch
|
||||
torchsde
|
||||
|
||||
289
tests-unit/comfy_extras_test/save_image_promotable_test.py
Normal file
289
tests-unit/comfy_extras_test/save_image_promotable_test.py
Normal file
@ -0,0 +1,289 @@
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
mock_nodes = MagicMock()
|
||||
mock_nodes.MAX_RESOLUTION = 16384
|
||||
mock_server = MagicMock()
|
||||
|
||||
with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}):
|
||||
from comfy_extras import nodes_save_image_promotable as mod
|
||||
|
||||
|
||||
def _make_image(width=8, height=4):
|
||||
return torch.rand(1, height, width, 3)
|
||||
|
||||
|
||||
def _write_png(path: str, width=8, height=4):
|
||||
arr = (np.random.rand(height, width, 3) * 255).astype(np.uint8)
|
||||
Image.fromarray(arr).save(path)
|
||||
|
||||
|
||||
class TestParseRef:
|
||||
def test_empty(self):
|
||||
assert mod._parse_promoted_ref("") is None
|
||||
|
||||
def test_invalid_json(self):
|
||||
assert mod._parse_promoted_ref("{not json") is None
|
||||
|
||||
def test_non_object(self):
|
||||
assert mod._parse_promoted_ref('"a string"') is None
|
||||
assert mod._parse_promoted_ref("[]") is None
|
||||
|
||||
def test_missing_filename(self):
|
||||
assert mod._parse_promoted_ref('{"subfolder":"x","type":"output"}') is None
|
||||
|
||||
def test_path_traversal_filename(self):
|
||||
ref = json.dumps(
|
||||
{"filename": "../etc/passwd", "subfolder": "", "type": "output"}
|
||||
)
|
||||
assert mod._parse_promoted_ref(ref) is None
|
||||
|
||||
def test_path_traversal_subfolder(self):
|
||||
ref = json.dumps({"filename": "x.png", "subfolder": "../..", "type": "output"})
|
||||
assert mod._parse_promoted_ref(ref) is None
|
||||
|
||||
def test_absolute_filename(self):
|
||||
ref = json.dumps({"filename": "/etc/passwd", "subfolder": "", "type": "output"})
|
||||
assert mod._parse_promoted_ref(ref) is None
|
||||
|
||||
def test_valid(self):
|
||||
ref = json.dumps({"filename": "x.png", "subfolder": "sub", "type": "output"})
|
||||
parsed = mod._parse_promoted_ref(ref)
|
||||
assert parsed == {"filename": "x.png", "subfolder": "sub", "type": "output"}
|
||||
|
||||
def test_defaults_applied(self):
|
||||
ref = json.dumps({"filename": "x.png"})
|
||||
parsed = mod._parse_promoted_ref(ref)
|
||||
assert parsed == {"filename": "x.png", "subfolder": "", "type": "output"}
|
||||
|
||||
|
||||
class TestResolveRefPath:
|
||||
def test_unknown_type(self, tmp_path):
|
||||
with patch.object(
|
||||
mod.folder_paths, "get_output_directory", return_value=str(tmp_path)
|
||||
):
|
||||
assert (
|
||||
mod._resolve_ref_path(
|
||||
{"filename": "x.png", "subfolder": "", "type": "garbage"}
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_missing_file(self, tmp_path):
|
||||
with patch.object(
|
||||
mod.folder_paths, "get_output_directory", return_value=str(tmp_path)
|
||||
):
|
||||
assert (
|
||||
mod._resolve_ref_path(
|
||||
{"filename": "missing.png", "subfolder": "", "type": "output"}
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_resolves_file(self, tmp_path):
|
||||
target = tmp_path / "img.png"
|
||||
_write_png(str(target))
|
||||
with patch.object(
|
||||
mod.folder_paths, "get_output_directory", return_value=str(tmp_path)
|
||||
):
|
||||
resolved = mod._resolve_ref_path(
|
||||
{"filename": "img.png", "subfolder": "", "type": "output"}
|
||||
)
|
||||
assert resolved is not None
|
||||
assert os.path.realpath(resolved) == os.path.realpath(str(target))
|
||||
|
||||
def test_resolves_file_in_subfolder(self, tmp_path):
|
||||
sub = tmp_path / "nested"
|
||||
sub.mkdir()
|
||||
target = sub / "img.png"
|
||||
_write_png(str(target))
|
||||
with patch.object(
|
||||
mod.folder_paths, "get_output_directory", return_value=str(tmp_path)
|
||||
):
|
||||
resolved = mod._resolve_ref_path(
|
||||
{"filename": "img.png", "subfolder": "nested", "type": "output"}
|
||||
)
|
||||
assert resolved is not None
|
||||
assert os.path.realpath(resolved) == os.path.realpath(str(target))
|
||||
|
||||
def test_symlink_escape_rejected(self, tmp_path):
|
||||
outside = tmp_path / "outside"
|
||||
outside.mkdir()
|
||||
secret = outside / "secret.png"
|
||||
_write_png(str(secret))
|
||||
base = tmp_path / "base"
|
||||
base.mkdir()
|
||||
link = base / "link.png"
|
||||
os.symlink(str(secret), str(link))
|
||||
with patch.object(
|
||||
mod.folder_paths, "get_output_directory", return_value=str(base)
|
||||
):
|
||||
resolved = mod._resolve_ref_path(
|
||||
{"filename": "link.png", "subfolder": "", "type": "output"}
|
||||
)
|
||||
assert resolved is None
|
||||
|
||||
|
||||
class TestNodeContract:
|
||||
def test_input_types_shape(self):
|
||||
inp = mod.SaveImagePromotable.INPUT_TYPES()
|
||||
assert set(inp["required"].keys()) == {
|
||||
"images",
|
||||
"filename_prefix",
|
||||
"accumulate",
|
||||
}
|
||||
assert set(inp["optional"].keys()) == {"promoted_asset_ref"}
|
||||
assert set(inp["hidden"].keys()) == {"prompt", "extra_pnginfo"}
|
||||
assert inp["required"]["accumulate"][0] == "BOOLEAN"
|
||||
assert inp["required"]["accumulate"][1]["default"] is False
|
||||
|
||||
def test_class_metadata(self):
|
||||
cls = mod.SaveImagePromotable
|
||||
assert cls.RETURN_TYPES == ("IMAGE",)
|
||||
assert cls.RETURN_NAMES == ("images",)
|
||||
assert cls.OUTPUT_NODE is True
|
||||
assert cls.FUNCTION == "execute"
|
||||
assert "SaveImagePromotable" in mod.NODE_CLASS_MAPPINGS
|
||||
assert mod.NODE_CLASS_MAPPINGS["SaveImagePromotable"] is cls
|
||||
|
||||
|
||||
class TestExecutePassthrough:
|
||||
def test_passthrough_saves_and_returns_input(self, tmp_path):
|
||||
node = mod.SaveImagePromotable()
|
||||
node.output_dir = str(tmp_path)
|
||||
images = _make_image()
|
||||
|
||||
with (
|
||||
patch.object(mod.args, "disable_metadata", True),
|
||||
patch.object(mod.folder_paths, "get_save_image_path") as get_path,
|
||||
):
|
||||
get_path.return_value = (str(tmp_path), "ComfyUI", 1, "", "ComfyUI")
|
||||
result = node.execute(
|
||||
images,
|
||||
filename_prefix="ComfyUI",
|
||||
accumulate=False,
|
||||
promoted_asset_ref="",
|
||||
)
|
||||
|
||||
assert "ui" in result
|
||||
assert "result" in result
|
||||
assert torch.equal(result["result"][0], images)
|
||||
assert len(result["ui"]["images"]) == 1
|
||||
saved_name = result["ui"]["images"][0]["filename"]
|
||||
assert os.path.isfile(os.path.join(str(tmp_path), saved_name))
|
||||
|
||||
def test_stale_ref_falls_through_to_passthrough(self, tmp_path):
|
||||
node = mod.SaveImagePromotable()
|
||||
node.output_dir = str(tmp_path)
|
||||
images = _make_image()
|
||||
ref = json.dumps(
|
||||
{"filename": "does_not_exist.png", "subfolder": "", "type": "output"}
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(mod.args, "disable_metadata", True),
|
||||
patch.object(mod.folder_paths, "get_save_image_path") as get_path,
|
||||
patch.object(
|
||||
mod.folder_paths, "get_output_directory", return_value=str(tmp_path)
|
||||
),
|
||||
):
|
||||
get_path.return_value = (str(tmp_path), "ComfyUI", 1, "", "ComfyUI")
|
||||
result = node.execute(images, promoted_asset_ref=ref)
|
||||
|
||||
assert torch.equal(result["result"][0], images)
|
||||
|
||||
|
||||
class TestExecuteLocked:
|
||||
def test_locked_outputs_loaded_image(self, tmp_path):
|
||||
target = tmp_path / "promoted.png"
|
||||
_write_png(str(target), width=8, height=4)
|
||||
ref = json.dumps(
|
||||
{"filename": "promoted.png", "subfolder": "", "type": "output"}
|
||||
)
|
||||
node = mod.SaveImagePromotable()
|
||||
node.output_dir = str(tmp_path)
|
||||
upstream = _make_image(width=8, height=4)
|
||||
|
||||
with patch.object(
|
||||
mod.folder_paths, "get_output_directory", return_value=str(tmp_path)
|
||||
):
|
||||
result = node.execute(upstream, promoted_asset_ref=ref)
|
||||
|
||||
assert result["ui"]["images"] == [
|
||||
{"filename": "promoted.png", "subfolder": "", "type": "output"}
|
||||
]
|
||||
out = result["result"][0]
|
||||
assert out.shape == upstream.shape
|
||||
assert out.dtype == upstream.dtype
|
||||
assert not torch.equal(out, upstream)
|
||||
|
||||
def test_locked_does_not_save(self, tmp_path):
|
||||
target = tmp_path / "promoted.png"
|
||||
_write_png(str(target))
|
||||
ref = json.dumps(
|
||||
{"filename": "promoted.png", "subfolder": "", "type": "output"}
|
||||
)
|
||||
node = mod.SaveImagePromotable()
|
||||
node.output_dir = str(tmp_path)
|
||||
images = _make_image()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
mod.folder_paths, "get_output_directory", return_value=str(tmp_path)
|
||||
),
|
||||
patch.object(node, "_save_images") as save_mock,
|
||||
):
|
||||
node.execute(images, promoted_asset_ref=ref)
|
||||
|
||||
save_mock.assert_not_called()
|
||||
|
||||
|
||||
class TestIsChanged:
|
||||
def test_unlocked_returns_false(self):
|
||||
assert (
|
||||
mod.SaveImagePromotable.IS_CHANGED(images=None, promoted_asset_ref="")
|
||||
is False
|
||||
)
|
||||
|
||||
def test_locked_missing_file(self, tmp_path):
|
||||
ref = json.dumps({"filename": "missing.png", "subfolder": "", "type": "output"})
|
||||
with patch.object(
|
||||
mod.folder_paths, "get_output_directory", return_value=str(tmp_path)
|
||||
):
|
||||
key = mod.SaveImagePromotable.IS_CHANGED(
|
||||
images=None, promoted_asset_ref=ref
|
||||
)
|
||||
assert isinstance(key, str)
|
||||
assert key.startswith("PROMOTED::MISSING::")
|
||||
|
||||
def test_locked_stable_key(self, tmp_path):
|
||||
target = tmp_path / "p.png"
|
||||
_write_png(str(target))
|
||||
ref = json.dumps({"filename": "p.png", "subfolder": "", "type": "output"})
|
||||
with patch.object(
|
||||
mod.folder_paths, "get_output_directory", return_value=str(tmp_path)
|
||||
):
|
||||
k1 = mod.SaveImagePromotable.IS_CHANGED(images=None, promoted_asset_ref=ref)
|
||||
k2 = mod.SaveImagePromotable.IS_CHANGED(images=None, promoted_asset_ref=ref)
|
||||
assert k1 == k2
|
||||
assert k1.startswith("PROMOTED::")
|
||||
|
||||
def test_locked_key_changes_when_file_changes(self, tmp_path):
|
||||
target = tmp_path / "p.png"
|
||||
_write_png(str(target), width=8, height=4)
|
||||
ref = json.dumps({"filename": "p.png", "subfolder": "", "type": "output"})
|
||||
with patch.object(
|
||||
mod.folder_paths, "get_output_directory", return_value=str(tmp_path)
|
||||
):
|
||||
k1 = mod.SaveImagePromotable.IS_CHANGED(images=None, promoted_asset_ref=ref)
|
||||
os.utime(str(target), (1234567890, 1234567890))
|
||||
with patch.object(
|
||||
mod.folder_paths, "get_output_directory", return_value=str(tmp_path)
|
||||
):
|
||||
k2 = mod.SaveImagePromotable.IS_CHANGED(images=None, promoted_asset_ref=ref)
|
||||
assert k1 != k2
|
||||
@ -1,23 +1,9 @@
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
|
||||
from comfy.model_detection import detect_unet_config, model_config_from_unet_config
|
||||
import comfy.supported_models
|
||||
|
||||
|
||||
def _freeze(value):
|
||||
"""Recursively convert a value to a hashable form so configs can be
|
||||
compared/used as dict keys or set members."""
|
||||
if isinstance(value, dict):
|
||||
return frozenset((k, _freeze(v)) for k, v in value.items())
|
||||
if isinstance(value, (list, tuple)):
|
||||
return tuple(_freeze(v) for v in value)
|
||||
if isinstance(value, set):
|
||||
return frozenset(_freeze(v) for v in value)
|
||||
return value
|
||||
|
||||
|
||||
def _make_longcat_comfyui_sd():
|
||||
"""Minimal ComfyUI-format state dict for pre-converted LongCat-Image weights."""
|
||||
sd = {}
|
||||
@ -124,21 +110,3 @@ class TestModelDetection:
|
||||
model_config = model_config_from_unet_config(unet_config, sd)
|
||||
assert model_config is not None
|
||||
assert type(model_config).__name__ == "FluxSchnell"
|
||||
|
||||
def test_unet_config_and_required_keys_combination_is_unique(self):
|
||||
"""Each model in the registry must have a unique combination of
|
||||
``unet_config`` and ``required_keys``. If two models share the same
|
||||
combination, ``BASE.matches`` cannot disambiguate between them and the
|
||||
first one in the list will always win."""
|
||||
models = comfy.supported_models.models
|
||||
groups = defaultdict(list)
|
||||
for model in models:
|
||||
key = (_freeze(model.unet_config), _freeze(model.required_keys))
|
||||
groups[key].append(model.__name__)
|
||||
|
||||
duplicates = {k: names for k, names in groups.items() if len(names) > 1}
|
||||
assert not duplicates, (
|
||||
"Found models sharing the same (unet_config, required_keys) "
|
||||
"combination, which makes detection ambiguous: "
|
||||
+ "; ".join(", ".join(names) for names in duplicates.values())
|
||||
)
|
||||
|
||||
@ -1011,124 +1011,3 @@ class TestExecution:
|
||||
"""Test getting a non-existent job returns 404"""
|
||||
job = client.get_job("nonexistent-job-id")
|
||||
assert job is None, "Non-existent job should return None"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("text, expect_error", [
|
||||
("hello", False), # 5 chars, within [3, 10]
|
||||
("abc", False), # 3 chars, exact min boundary
|
||||
("abcdefghij", False), # 10 chars, exact max boundary
|
||||
("ab", True), # 2 chars, below min
|
||||
("abcdefghijk", True), # 11 chars, above max
|
||||
("", True), # 0 chars, below min
|
||||
])
|
||||
def test_string_length_widget_validation(self, text, expect_error, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test minLength/maxLength validation for direct widget values (validate_inputs path)."""
|
||||
g = builder
|
||||
node = g.node("StubStringWithLength", text=text)
|
||||
g.node("SaveImage", images=node.out(0))
|
||||
if expect_error:
|
||||
with pytest.raises(urllib.error.HTTPError) as exc_info:
|
||||
client.run(g)
|
||||
assert exc_info.value.code == 400
|
||||
else:
|
||||
client.run(g)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("text, expect_error", [
|
||||
("hello", False), # within bounds
|
||||
("ab", True), # below min
|
||||
("abcdefghijk", True), # above max
|
||||
])
|
||||
def test_string_length_linked_validation(self, text, expect_error, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test minLength/maxLength validation for linked inputs when node opts in via RUNTIME_INPUT_VALIDATION=True."""
|
||||
g = builder
|
||||
str_node = g.node("StubStringOutput", value=text)
|
||||
node = g.node("StubStringWithLength", text=str_node.out(0))
|
||||
g.node("SaveImage", images=node.out(0))
|
||||
|
||||
if expect_error:
|
||||
try:
|
||||
client.run(g)
|
||||
assert False, "Should have raised an error"
|
||||
except Exception as e:
|
||||
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
|
||||
else:
|
||||
client.run(g)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("text", [
|
||||
"ab", # below declared minLength
|
||||
"abcdefghijk", # above declared maxLength
|
||||
"", # empty
|
||||
"hello", # within bounds
|
||||
])
|
||||
def test_string_length_linked_skipped_without_flag(self, text, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Without RUNTIME_INPUT_VALIDATION=True, declared bounds must NOT be enforced for linked values.
|
||||
|
||||
Preserves V1 behavior: many existing workflows rely on out-of-bounds values passing
|
||||
through links. Adding declared bounds without the flag must not break them.
|
||||
"""
|
||||
g = builder
|
||||
str_node = g.node("StubStringOutput", value=text)
|
||||
node = g.node("StubStringWithLengthNoFlag", text=str_node.out(0))
|
||||
g.node("SaveImage", images=node.out(0))
|
||||
client.run(g)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value, expect_error", [
|
||||
(5, False), # within [1, 10]
|
||||
(1, False), # exact min boundary
|
||||
(10, False), # exact max boundary
|
||||
(0, True), # below min
|
||||
(11, True), # above max
|
||||
(-7, True), # well below min
|
||||
])
|
||||
def test_int_bounds_linked_validation(self, value, expect_error, client: ComfyClient, builder: GraphBuilder):
|
||||
"""min/max validation for linked INT inputs when node opts in via RUNTIME_INPUT_VALIDATION=True.
|
||||
|
||||
Direct widget INT values are already validated pre-execution. This test exercises the
|
||||
symmetric runtime path for values arriving through a connection.
|
||||
"""
|
||||
g = builder
|
||||
int_node = g.node("StubInt", value=value)
|
||||
node = g.node("StubIntWithBounds", value=int_node.out(0))
|
||||
g.node("SaveImage", images=node.out(0))
|
||||
|
||||
if expect_error:
|
||||
try:
|
||||
client.run(g)
|
||||
assert False, "Should have raised an error"
|
||||
except Exception as e:
|
||||
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
|
||||
else:
|
||||
client.run(g)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("choice, expect_error", [
|
||||
("RED", False),
|
||||
("GREEN", False),
|
||||
("BLUE", False),
|
||||
("PURPLE", True),
|
||||
("", True),
|
||||
("red", True), # case-sensitive
|
||||
])
|
||||
def test_combo_membership_linked_validation(self, choice, expect_error, client: ComfyClient, builder: GraphBuilder):
|
||||
"""COMBO option membership for linked values when node opts in via RUNTIME_INPUT_VALIDATION=True.
|
||||
|
||||
StubComboWithOptions declares ``input_types`` in VALIDATE_INPUTS to bypass the engine's
|
||||
link-type compatibility check, so we can feed a STRING into a COMBO and verify the
|
||||
runtime membership check fires.
|
||||
"""
|
||||
g = builder
|
||||
str_node = g.node("StubStringOutput", value=choice)
|
||||
node = g.node("StubComboWithOptions", choice=str_node.out(0))
|
||||
g.node("SaveImage", images=node.out(0))
|
||||
|
||||
if expect_error:
|
||||
try:
|
||||
client.run(g)
|
||||
assert False, "Should have raised an error"
|
||||
except Exception as e:
|
||||
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
|
||||
else:
|
||||
client.run(g)
|
||||
|
||||
@ -113,117 +113,12 @@ class StubFloat:
|
||||
def stub_float(self, value):
|
||||
return (value,)
|
||||
|
||||
class StubStringOutput:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("STRING", {"default": ""}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
FUNCTION = "stub_string"
|
||||
|
||||
CATEGORY = "Testing/Stub Nodes"
|
||||
|
||||
def stub_string(self, value):
|
||||
return (value,)
|
||||
|
||||
class StubStringWithLength:
|
||||
"""STRING input with declared bounds AND opted in to runtime validation (RUNTIME_INPUT_VALIDATION = True)."""
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"text": ("STRING", {"default": "hello", "minLength": 3, "maxLength": 10}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "stub_string_with_length"
|
||||
RUNTIME_INPUT_VALIDATION = True
|
||||
|
||||
CATEGORY = "Testing/Stub Nodes"
|
||||
|
||||
def stub_string_with_length(self, text):
|
||||
return (torch.zeros(1, 64, 64, 3),)
|
||||
|
||||
class StubStringWithLengthNoFlag:
|
||||
"""Same bounds as StubStringWithLength but NOT opted in - linked values must flow through unchecked."""
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"text": ("STRING", {"default": "hello", "minLength": 3, "maxLength": 10}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "stub_string_with_length_no_flag"
|
||||
|
||||
CATEGORY = "Testing/Stub Nodes"
|
||||
|
||||
def stub_string_with_length_no_flag(self, text):
|
||||
return (torch.zeros(1, 64, 64, 3),)
|
||||
|
||||
class StubIntWithBounds:
|
||||
"""INT input with min/max bounds AND opted in to runtime validation."""
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("INT", {"default": 5, "min": 1, "max": 10}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "stub_int_with_bounds"
|
||||
RUNTIME_INPUT_VALIDATION = True
|
||||
|
||||
CATEGORY = "Testing/Stub Nodes"
|
||||
|
||||
def stub_int_with_bounds(self, value):
|
||||
return (torch.zeros(1, 64, 64, 3),)
|
||||
|
||||
class StubComboWithOptions:
|
||||
"""COMBO input opted in to runtime validation.
|
||||
|
||||
Declares ``input_types`` in VALIDATE_INPUTS to bypass the engine's link-type compatibility
|
||||
check, allowing tests to link a STRING into a COMBO and exercise the runtime membership check.
|
||||
"""
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"choice": (["RED", "GREEN", "BLUE"],),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "stub_combo"
|
||||
RUNTIME_INPUT_VALIDATION = True
|
||||
|
||||
CATEGORY = "Testing/Stub Nodes"
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, input_types):
|
||||
return True
|
||||
|
||||
def stub_combo(self, choice):
|
||||
return (torch.zeros(1, 64, 64, 3),)
|
||||
|
||||
TEST_STUB_NODE_CLASS_MAPPINGS = {
|
||||
"StubImage": StubImage,
|
||||
"StubConstantImage": StubConstantImage,
|
||||
"StubMask": StubMask,
|
||||
"StubInt": StubInt,
|
||||
"StubFloat": StubFloat,
|
||||
"StubStringOutput": StubStringOutput,
|
||||
"StubStringWithLength": StubStringWithLength,
|
||||
"StubStringWithLengthNoFlag": StubStringWithLengthNoFlag,
|
||||
"StubIntWithBounds": StubIntWithBounds,
|
||||
"StubComboWithOptions": StubComboWithOptions,
|
||||
}
|
||||
TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"StubImage": "Stub Image",
|
||||
@ -231,9 +126,4 @@ TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"StubMask": "Stub Mask",
|
||||
"StubInt": "Stub Int",
|
||||
"StubFloat": "Stub Float",
|
||||
"StubStringOutput": "Stub String Output",
|
||||
"StubStringWithLength": "Stub String With Length",
|
||||
"StubStringWithLengthNoFlag": "Stub String With Length (No Flag)",
|
||||
"StubIntWithBounds": "Stub Int With Bounds",
|
||||
"StubComboWithOptions": "Stub Combo With Options",
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user