Compare commits

..

4 Commits

Author SHA1 Message Date
9a98cdc389 feat: declarative input validation with opt-in runtime enforcement
- Add `minLength`/`maxLength` to `IO.String.Input`, mirroring existing `min`/`max` for `Int`/`Float`.
- Add `runtime_input_validation` to V3 `Schema` (and `RUNTIME_INPUT_VALIDATION` class attribute for V1 nodes). Default `False`

Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-14 20:56:09 +03:00
3d870ff51f chore: update workflow templates to v0.9.77 (#13895) 2026-05-15 01:25:18 +08:00
1f28908d6e Make audio processing nodes handle None -inputs (#13879) 2026-05-14 10:51:35 +08:00
fb51a988b6 Add test that each model has unique identifiers CORE-134 (#13654) 2026-05-14 10:41:25 +08:00
7 changed files with 416 additions and 9 deletions

View File

@ -327,11 +327,14 @@ class String(ComfyTypeIO):
'''String input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None,
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None,
min_length: int=None, max_length: int=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.multiline = multiline
self.placeholder = placeholder
self.dynamic_prompts = dynamic_prompts
self.min_length = min_length
self.max_length = max_length
self.default: str
def as_dict(self):
@ -339,6 +342,8 @@ class String(ComfyTypeIO):
"multiline": self.multiline,
"placeholder": self.placeholder,
"dynamicPrompts": self.dynamic_prompts,
"minLength": self.min_length,
"maxLength": self.max_length,
})
@comfytype(io_type="COMBO")
@ -1551,6 +1556,12 @@ class Schema:
Use this for nodes with interactive/operable UI regions that produce intermediate outputs
(e.g., Image Crop, Painter) rather than final outputs (e.g., Save Image).
"""
runtime_input_validation: bool = False
"""Opt this node into runtime validation of declared input bounds (STRING minLength/maxLength,
INT/FLOAT min/max, COMBO membership) against resolved values, including values that arrive via links.
When False, only direct widget values are validated pre-execution and linked values flow through unchecked.
"""
def validate(self):
'''Validate the schema:
@ -2006,6 +2017,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls.GET_SCHEMA()
return cls._ACCEPT_ALL_INPUTS
_RUNTIME_INPUT_VALIDATION = None
@final
@classproperty
def RUNTIME_INPUT_VALIDATION(cls): # noqa
if cls._RUNTIME_INPUT_VALIDATION is None:
cls.GET_SCHEMA()
return cls._RUNTIME_INPUT_VALIDATION
@final
@classmethod
def INPUT_TYPES(cls) -> dict[str, dict]:
@ -2050,6 +2069,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls._NOT_IDEMPOTENT = schema.not_idempotent
if cls._ACCEPT_ALL_INPUTS is None:
cls._ACCEPT_ALL_INPUTS = schema.accept_all_inputs
if cls._RUNTIME_INPUT_VALIDATION is None:
cls._RUNTIME_INPUT_VALIDATION = schema.runtime_input_validation
if cls._RETURN_TYPES is None:
output = []

View File

@ -82,6 +82,8 @@ class VAEEncodeAudio(IO.ComfyNode):
@classmethod
def execute(cls, vae, audio) -> IO.NodeOutput:
if audio is None:
raise ValueError("VAEEncodeAudio: input audio is None (source video may have no audio track).")
sample_rate = audio["sample_rate"]
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
if vae_sample_rate != sample_rate:
@ -171,6 +173,8 @@ class SaveAudio(IO.ComfyNode):
@classmethod
def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput:
if audio is None:
raise ValueError("SaveAudio: input audio is None (source video may have no audio track).")
return IO.NodeOutput(
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format)
)
@ -198,6 +202,8 @@ class SaveAudioMP3(IO.ComfyNode):
@classmethod
def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput:
if audio is None:
raise ValueError("SaveAudioMP3: input audio is None (source video may have no audio track).")
return IO.NodeOutput(
ui=UI.AudioSaveHelper.get_save_audio_ui(
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
@ -226,6 +232,8 @@ class SaveAudioOpus(IO.ComfyNode):
@classmethod
def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") -> IO.NodeOutput:
if audio is None:
raise ValueError("SaveAudioOpus: input audio is None (source video may have no audio track).")
return IO.NodeOutput(
ui=UI.AudioSaveHelper.get_save_audio_ui(
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
@ -252,6 +260,8 @@ class PreviewAudio(IO.ComfyNode):
@classmethod
def execute(cls, audio) -> IO.NodeOutput:
if audio is None:
raise ValueError("PreviewAudio: input audio is None (source video may have no audio track).")
return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls))
save_flac = execute # TODO: remove
@ -392,21 +402,26 @@ class TrimAudioDuration(IO.ComfyNode):
@classmethod
def execute(cls, audio, start_index, duration) -> IO.NodeOutput:
if audio is None:
return IO.NodeOutput(None)
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
audio_length = waveform.shape[-1]
if audio_length == 0:
return IO.NodeOutput(audio)
if start_index < 0:
start_frame = audio_length + int(round(start_index * sample_rate))
else:
start_frame = int(round(start_index * sample_rate))
start_frame = max(0, min(start_frame, audio_length - 1))
start_frame = max(0, min(start_frame, audio_length))
end_frame = start_frame + int(round(duration * sample_rate))
end_frame = max(0, min(end_frame, audio_length))
if start_frame >= end_frame:
raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
raise ValueError("TrimAudioDuration: Start time must be less than end time and be within the audio length.")
return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate})
@ -433,11 +448,13 @@ class SplitAudioChannels(IO.ComfyNode):
@classmethod
def execute(cls, audio) -> IO.NodeOutput:
if audio is None:
return IO.NodeOutput(None, None)
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
if waveform.shape[1] != 2:
raise ValueError("AudioSplit: Input audio has only one channel.")
raise ValueError(f"AudioSplit: Input audio must be stereo (2 channels), got {waveform.shape[1]} channel(s).")
left_channel = waveform[..., 0:1, :]
right_channel = waveform[..., 1:2, :]
@ -465,6 +482,12 @@ class JoinAudioChannels(IO.ComfyNode):
@classmethod
def execute(cls, audio_left, audio_right) -> IO.NodeOutput:
if audio_left is None and audio_right is None:
return IO.NodeOutput(None)
if audio_left is None:
return IO.NodeOutput(audio_right)
if audio_right is None:
return IO.NodeOutput(audio_left)
waveform_left = audio_left["waveform"]
sample_rate_left = audio_left["sample_rate"]
waveform_right = audio_right["waveform"]
@ -538,6 +561,12 @@ class AudioConcat(IO.ComfyNode):
@classmethod
def execute(cls, audio1, audio2, direction) -> IO.NodeOutput:
if audio1 is None and audio2 is None:
return IO.NodeOutput(None)
if audio1 is None:
return IO.NodeOutput(audio2)
if audio2 is None:
return IO.NodeOutput(audio1)
waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"]
@ -585,6 +614,12 @@ class AudioMerge(IO.ComfyNode):
@classmethod
def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput:
if audio1 is None and audio2 is None:
return IO.NodeOutput(None)
if audio1 is None:
return IO.NodeOutput(audio2)
if audio2 is None:
return IO.NodeOutput(audio1)
waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"]
@ -595,6 +630,9 @@ class AudioMerge(IO.ComfyNode):
length_1 = waveform_1.shape[-1]
length_2 = waveform_2.shape[-1]
if length_1 == 0 or length_2 == 0:
return IO.NodeOutput({"waveform": waveform_1, "sample_rate": output_sample_rate})
if length_2 > length_1:
logging.info(f"AudioMerge: Trimming audio2 from {length_2} to {length_1} samples to match audio1 length.")
waveform_2 = waveform_2[..., :length_1]
@ -646,6 +684,8 @@ class AudioAdjustVolume(IO.ComfyNode):
@classmethod
def execute(cls, audio, volume) -> IO.NodeOutput:
if audio is None:
return IO.NodeOutput(None)
if volume == 0:
return IO.NodeOutput(audio)
waveform = audio["waveform"]
@ -729,8 +769,14 @@ class AudioEqualizer3Band(IO.ComfyNode):
@classmethod
def execute(cls, audio, low_gain_dB, low_freq, mid_gain_dB, mid_freq, mid_q, high_gain_dB, high_freq) -> IO.NodeOutput:
if audio is None:
return IO.NodeOutput(None)
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
if waveform.shape[-1] == 0:
return IO.NodeOutput(audio)
eq_waveform = waveform.clone()
# 1. Apply Low Shelf (Bass)

View File

@ -83,7 +83,7 @@ class IsChangedCache:
return self.is_changed[node_id]
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
input_data_all, _, v3_data, _ = get_input_data(node["inputs"], class_def, node_id, None)
try:
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data)
is_changed = await resolve_map_node_over_list_results(is_changed)
@ -215,7 +215,52 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
v3_data["hidden_inputs"] = hidden_inputs_v3
return input_data_all, missing_keys, v3_data
return input_data_all, missing_keys, v3_data, valid_inputs
def _check_resolved_input_bounds(name, val, input_type, extra_info):
"""Raise ValueError if a single resolved value violates declared bounds."""
if input_type == "STRING":
if not isinstance(val, str):
return
min_length = extra_info.get("minLength")
max_length = extra_info.get("maxLength")
if min_length is not None and len(val) < min_length:
raise ValueError(f"Input '{name}': string length {len(val)} is shorter than minLength of {min_length}")
if max_length is not None and len(val) > max_length:
raise ValueError(f"Input '{name}': string length {len(val)} is longer than maxLength of {max_length}")
elif input_type in ("INT", "FLOAT"):
if isinstance(val, bool) or not isinstance(val, (int, float)):
return
min_v = extra_info.get("min")
max_v = extra_info.get("max")
if min_v is not None and val < min_v:
raise ValueError(f"Input '{name}': value {val} is smaller than min of {min_v}")
if max_v is not None and val > max_v:
raise ValueError(f"Input '{name}': value {val} is bigger than max of {max_v}")
elif isinstance(input_type, list) or input_type == io.Combo.io_type:
combo_options = extra_info.get("options", []) if input_type == io.Combo.io_type else input_type
is_multiselect = extra_info.get("multiselect", False)
if is_multiselect and isinstance(val, list):
invalid_vals = [v for v in val if v not in combo_options]
else:
invalid_vals = [val] if val not in combo_options else []
if invalid_vals:
raise ValueError(f"Input '{name}': value(s) {invalid_vals} not in combo options")
def _validate_resolved_inputs(class_def, input_data_all, valid_inputs):
"""Enforce declared input bounds against resolved values, including values that arrive via links."""
if not getattr(class_def, "RUNTIME_INPUT_VALIDATION", False):
return
for x, values in input_data_all.items():
input_type, _, extra_info = get_input_info(class_def, x, valid_inputs)
if input_type is None or extra_info is None:
continue
for val in values:
if val is None:
continue
_check_resolved_input_bounds(x, val, input_type, extra_info)
map_node_over_list = None #Don't hook this please
@ -480,7 +525,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
has_subgraph = False
else:
get_progress_state().start_progress(unique_id)
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
input_data_all, missing_keys, v3_data, valid_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@ -509,6 +554,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
execution_list.make_input_strong_link(unique_id, i)
return (ExecutionResult.PENDING, None, None)
_validate_resolved_inputs(class_def, input_data_all, valid_inputs)
def execution_block_cb(block):
if block.message is not None:
mes = {
@ -1014,6 +1061,36 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
errors.append(error)
continue
if input_type == "STRING":
min_length = extra_info.get("minLength")
max_length = extra_info.get("maxLength")
if min_length is not None and len(val) < min_length:
error = {
"type": "value_shorter_than_min_length",
"message": f"Value length {len(val)} shorter than min length of {min_length}",
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if max_length is not None and len(val) > max_length:
error = {
"type": "value_longer_than_max_length",
"message": f"Value length {len(val)} longer than max length of {max_length}",
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if isinstance(input_type, list) or input_type == io.Combo.io_type:
if input_type == io.Combo.io_type:
combo_options = extra_info.get("options", [])
@ -1050,7 +1127,7 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
continue
if len(validate_function_inputs) > 0 or validate_has_kwargs:
input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id)
input_data_all, _, v3_data, _ = get_input_data(inputs, obj_class, unique_id)
input_filtered = {}
for x in input_data_all:
if x in validate_function_inputs or validate_has_kwargs:

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.43.18
comfyui-workflow-templates==0.9.75
comfyui-workflow-templates==0.9.77
comfyui-embedded-docs==0.5.0
torch
torchsde

View File

@ -1,9 +1,23 @@
from collections import defaultdict
import torch
from comfy.model_detection import detect_unet_config, model_config_from_unet_config
import comfy.supported_models
def _freeze(value):
"""Recursively convert a value to a hashable form so configs can be
compared/used as dict keys or set members."""
if isinstance(value, dict):
return frozenset((k, _freeze(v)) for k, v in value.items())
if isinstance(value, (list, tuple)):
return tuple(_freeze(v) for v in value)
if isinstance(value, set):
return frozenset(_freeze(v) for v in value)
return value
def _make_longcat_comfyui_sd():
"""Minimal ComfyUI-format state dict for pre-converted LongCat-Image weights."""
sd = {}
@ -110,3 +124,21 @@ class TestModelDetection:
model_config = model_config_from_unet_config(unet_config, sd)
assert model_config is not None
assert type(model_config).__name__ == "FluxSchnell"
def test_unet_config_and_required_keys_combination_is_unique(self):
"""Each model in the registry must have a unique combination of
``unet_config`` and ``required_keys``. If two models share the same
combination, ``BASE.matches`` cannot disambiguate between them and the
first one in the list will always win."""
models = comfy.supported_models.models
groups = defaultdict(list)
for model in models:
key = (_freeze(model.unet_config), _freeze(model.required_keys))
groups[key].append(model.__name__)
duplicates = {k: names for k, names in groups.items() if len(names) > 1}
assert not duplicates, (
"Found models sharing the same (unet_config, required_keys) "
"combination, which makes detection ambiguous: "
+ "; ".join(", ".join(names) for names in duplicates.values())
)

View File

@ -1011,3 +1011,124 @@ class TestExecution:
"""Test getting a non-existent job returns 404"""
job = client.get_job("nonexistent-job-id")
assert job is None, "Non-existent job should return None"
@pytest.mark.parametrize("text, expect_error", [
("hello", False), # 5 chars, within [3, 10]
("abc", False), # 3 chars, exact min boundary
("abcdefghij", False), # 10 chars, exact max boundary
("ab", True), # 2 chars, below min
("abcdefghijk", True), # 11 chars, above max
("", True), # 0 chars, below min
])
def test_string_length_widget_validation(self, text, expect_error, client: ComfyClient, builder: GraphBuilder):
"""Test minLength/maxLength validation for direct widget values (validate_inputs path)."""
g = builder
node = g.node("StubStringWithLength", text=text)
g.node("SaveImage", images=node.out(0))
if expect_error:
with pytest.raises(urllib.error.HTTPError) as exc_info:
client.run(g)
assert exc_info.value.code == 400
else:
client.run(g)
@pytest.mark.parametrize("text, expect_error", [
("hello", False), # within bounds
("ab", True), # below min
("abcdefghijk", True), # above max
])
def test_string_length_linked_validation(self, text, expect_error, client: ComfyClient, builder: GraphBuilder):
"""Test minLength/maxLength validation for linked inputs when node opts in via RUNTIME_INPUT_VALIDATION=True."""
g = builder
str_node = g.node("StubStringOutput", value=text)
node = g.node("StubStringWithLength", text=str_node.out(0))
g.node("SaveImage", images=node.out(0))
if expect_error:
try:
client.run(g)
assert False, "Should have raised an error"
except Exception as e:
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
else:
client.run(g)
@pytest.mark.parametrize("text", [
"ab", # below declared minLength
"abcdefghijk", # above declared maxLength
"", # empty
"hello", # within bounds
])
def test_string_length_linked_skipped_without_flag(self, text, client: ComfyClient, builder: GraphBuilder):
"""Without RUNTIME_INPUT_VALIDATION=True, declared bounds must NOT be enforced for linked values.
Preserves V1 behavior: many existing workflows rely on out-of-bounds values passing
through links. Adding declared bounds without the flag must not break them.
"""
g = builder
str_node = g.node("StubStringOutput", value=text)
node = g.node("StubStringWithLengthNoFlag", text=str_node.out(0))
g.node("SaveImage", images=node.out(0))
client.run(g)
@pytest.mark.parametrize("value, expect_error", [
(5, False), # within [1, 10]
(1, False), # exact min boundary
(10, False), # exact max boundary
(0, True), # below min
(11, True), # above max
(-7, True), # well below min
])
def test_int_bounds_linked_validation(self, value, expect_error, client: ComfyClient, builder: GraphBuilder):
"""min/max validation for linked INT inputs when node opts in via RUNTIME_INPUT_VALIDATION=True.
Direct widget INT values are already validated pre-execution. This test exercises the
symmetric runtime path for values arriving through a connection.
"""
g = builder
int_node = g.node("StubInt", value=value)
node = g.node("StubIntWithBounds", value=int_node.out(0))
g.node("SaveImage", images=node.out(0))
if expect_error:
try:
client.run(g)
assert False, "Should have raised an error"
except Exception as e:
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
else:
client.run(g)
@pytest.mark.parametrize("choice, expect_error", [
("RED", False),
("GREEN", False),
("BLUE", False),
("PURPLE", True),
("", True),
("red", True), # case-sensitive
])
def test_combo_membership_linked_validation(self, choice, expect_error, client: ComfyClient, builder: GraphBuilder):
"""COMBO option membership for linked values when node opts in via RUNTIME_INPUT_VALIDATION=True.
StubComboWithOptions declares ``input_types`` in VALIDATE_INPUTS to bypass the engine's
link-type compatibility check, so we can feed a STRING into a COMBO and verify the
runtime membership check fires.
"""
g = builder
str_node = g.node("StubStringOutput", value=choice)
node = g.node("StubComboWithOptions", choice=str_node.out(0))
g.node("SaveImage", images=node.out(0))
if expect_error:
try:
client.run(g)
assert False, "Should have raised an error"
except Exception as e:
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
else:
client.run(g)

View File

@ -113,12 +113,117 @@ class StubFloat:
def stub_float(self, value):
return (value,)
class StubStringOutput:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("STRING", {"default": ""}),
},
}
RETURN_TYPES = ("STRING",)
FUNCTION = "stub_string"
CATEGORY = "Testing/Stub Nodes"
def stub_string(self, value):
return (value,)
class StubStringWithLength:
"""STRING input with declared bounds AND opted in to runtime validation (RUNTIME_INPUT_VALIDATION = True)."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"text": ("STRING", {"default": "hello", "minLength": 3, "maxLength": 10}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stub_string_with_length"
RUNTIME_INPUT_VALIDATION = True
CATEGORY = "Testing/Stub Nodes"
def stub_string_with_length(self, text):
return (torch.zeros(1, 64, 64, 3),)
class StubStringWithLengthNoFlag:
"""Same bounds as StubStringWithLength but NOT opted in - linked values must flow through unchecked."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"text": ("STRING", {"default": "hello", "minLength": 3, "maxLength": 10}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stub_string_with_length_no_flag"
CATEGORY = "Testing/Stub Nodes"
def stub_string_with_length_no_flag(self, text):
return (torch.zeros(1, 64, 64, 3),)
class StubIntWithBounds:
"""INT input with min/max bounds AND opted in to runtime validation."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("INT", {"default": 5, "min": 1, "max": 10}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stub_int_with_bounds"
RUNTIME_INPUT_VALIDATION = True
CATEGORY = "Testing/Stub Nodes"
def stub_int_with_bounds(self, value):
return (torch.zeros(1, 64, 64, 3),)
class StubComboWithOptions:
"""COMBO input opted in to runtime validation.
Declares ``input_types`` in VALIDATE_INPUTS to bypass the engine's link-type compatibility
check, allowing tests to link a STRING into a COMBO and exercise the runtime membership check.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"choice": (["RED", "GREEN", "BLUE"],),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stub_combo"
RUNTIME_INPUT_VALIDATION = True
CATEGORY = "Testing/Stub Nodes"
@classmethod
def VALIDATE_INPUTS(cls, input_types):
return True
def stub_combo(self, choice):
return (torch.zeros(1, 64, 64, 3),)
TEST_STUB_NODE_CLASS_MAPPINGS = {
"StubImage": StubImage,
"StubConstantImage": StubConstantImage,
"StubMask": StubMask,
"StubInt": StubInt,
"StubFloat": StubFloat,
"StubStringOutput": StubStringOutput,
"StubStringWithLength": StubStringWithLength,
"StubStringWithLengthNoFlag": StubStringWithLengthNoFlag,
"StubIntWithBounds": StubIntWithBounds,
"StubComboWithOptions": StubComboWithOptions,
}
TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
"StubImage": "Stub Image",
@ -126,4 +231,9 @@ TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
"StubMask": "Stub Mask",
"StubInt": "Stub Int",
"StubFloat": "Stub Float",
"StubStringOutput": "Stub String Output",
"StubStringWithLength": "Stub String With Length",
"StubStringWithLengthNoFlag": "Stub String With Length (No Flag)",
"StubIntWithBounds": "Stub Int With Bounds",
"StubComboWithOptions": "Stub Combo With Options",
}