Compare commits

..

8 Commits

9 changed files with 772 additions and 547 deletions

View File

@ -256,7 +256,7 @@ def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, w
if (want_requant and len(fns) == 0 or update_weight): if (want_requant and len(fns) == 0 or update_weight):
seed = comfy.utils.string_to_seed(s.seed_key) seed = comfy.utils.string_to_seed(s.seed_key)
if isinstance(orig, QuantizedTensor): if isinstance(orig, QuantizedTensor):
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed) y = orig.requantize_from_float(x, scale="recalculate", stochastic_rounding=seed)
else: else:
y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed) y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
if want_requant and len(fns) == 0: if want_requant and len(fns) == 0:
@ -1306,8 +1306,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None: if getattr(self, 'layout_type', None) is not None:
# dtype is now implicit in the layout class weight = self.weight.requantize_from_float(weight, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
else: else:
weight = weight.to(self.weight.dtype) weight = weight.to(self.weight.dtype)
if return_weight: if return_weight:

View File

@ -1261,6 +1261,158 @@ class DynamicSlot(ComfyTypeI):
out_dict[input_type][finalized_id] = value out_dict[input_type][finalized_id] = value
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1]) out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
@comfytype(io_type="COMFY_DYNAMICGROUP_V3")
class DynamicGroup(ComfyTypeI):
"""A repeatable group of widget inputs (e.g. lora_name + strength stacked into N rows).
At execution time the node receives a ``list[dict]`` where each element is a row.
Example::
io.DynamicGroup.Input(
"loras",
template=[
io.Combo.Input("lora_name", options=folder_paths.get_filename_list("loras")),
io.Float.Input("strength", default=1.0, min=-100, max=100, step=0.01),
],
min=0,
max=50,
)
# execute receives: loras: list[dict] = [{"lora_name": "x.safetensors", "strength": 1.0}, ...]
"""
Type = list[dict[str, Any]]
_MaxRows = 100
class Input(DynamicInput):
def __init__(
self,
id: str,
template: list["Input"],
min: int = 0,
max: int = 50,
display_name: str = None,
optional: bool = False,
tooltip: str = None,
lazy: bool = None,
extra_dict=None,
group_name: str = "Group",
):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
# Validate template entries: only WidgetInput subclasses, no nesting
assert len(template) > 0, "DynamicGroup template must have at least one field."
for t in template:
assert isinstance(t, WidgetInput), (
f"DynamicGroup template field '{t.id}' must be a WidgetInput subclass "
f"(Combo, Float, Int, String, Boolean, Color). Got {type(t).__name__}."
)
assert not isinstance(t, DynamicInput), (
f"DynamicGroup template field '{t.id}' must not be a DynamicInput. "
"Nesting dynamic inputs inside DynamicGroup is not supported."
)
# Enforce unique field ids within template
field_ids = [t.id for t in template]
assert len(field_ids) == len(set(field_ids)), (
f"DynamicGroup template field ids must be unique within a row. Got: {field_ids}"
)
# Reject "." in group id and template field ids: slot_id encoding uses "." as a
# delimiter (<group_id>.<row>.<field_id>), so any "." in these names would cause
# path.split(".") to produce the wrong number of segments during decoding.
assert "." not in id, (
f"DynamicGroup id must not contain '.'. Got: '{id}'"
)
for t in template:
assert "." not in t.id, (
f"DynamicGroup template field id must not contain '.'. Got: '{t.id}'"
)
assert min >= 0, "DynamicGroup min must be >= 0."
assert max >= 1, "DynamicGroup max must be >= 1."
assert max <= DynamicGroup._MaxRows, f"DynamicGroup max must be <= {DynamicGroup._MaxRows}."
assert min <= max, "DynamicGroup min must be <= max."
self.template = template
self.min = min
self.max = max
self.group_name = group_name
def get_all(self) -> list["Input"]:
return [self] + list(self.template)
def as_dict(self):
return super().as_dict() | prune_dict({
"template": create_input_dict_v1(self.template),
"min": self.min,
"max": self.max,
"group_name": self.group_name,
})
def validate(self):
for t in self.template:
t.validate()
@staticmethod
def _expand_schema_for_dynamic(
out_dict: dict[str, Any],
live_inputs: dict[str, Any],
value: tuple[str, dict[str, Any]],
input_type: str,
curr_prefix: list[str] | None,
):
info = value[1]
min_rows: int = info.get("min", 0)
max_rows: int = info.get("max", DynamicGroup._MaxRows)
template: dict[str, Any] = info.get("template", {})
# Collect all template field specs across required/optional sections
field_specs: list[tuple[str, tuple[str, dict[str, Any]], bool]] = []
for field_required_key in ("required", "optional"):
section = template.get(field_required_key, {})
is_required_field = field_required_key == "required"
for field_id, field_value in section.items():
field_specs.append((field_id, field_value, is_required_field))
# Determine how many rows are currently present by scanning live_inputs
finalized_prefix = finalize_prefix(curr_prefix)
present_rows = 0
for live_key in live_inputs:
# Keys look like "<prefix>.<row>.<field_id>"
if live_key.startswith(finalized_prefix + "."):
remainder = live_key[len(finalized_prefix) + 1:]
parts = remainder.split(".", 1)
if len(parts) >= 1:
try:
row_idx = int(parts[0])
present_rows = max(present_rows, row_idx + 1)
except ValueError:
pass
if present_rows > max_rows:
raise ValueError(
f"DynamicGroup input '{finalized_prefix}' received {present_rows} rows but max is {max_rows}."
)
row_count = max(min_rows, present_rows)
for row in range(row_count):
for field_id, field_value, is_required_field in field_specs:
slot_id = f"{finalized_prefix}.{row}.{field_id}"
# The first `min_rows` rows are required if the field itself is required
if row < min_rows and is_required_field:
out_dict["required"][slot_id] = field_value
else:
out_dict["optional"][slot_id] = field_value
# Register into dynamic_paths so build_nested_inputs places value at the right path
out_dict["dynamic_paths"][slot_id] = slot_id
# Track the list root path so build_nested_inputs can convert the index dict to a list
out_dict.setdefault("list_paths", set()).add(finalized_prefix)
# Handle the empty case (0 rows) emit an empty-list default for the parent.
# This must only fire when there are genuinely no rows; otherwise the parent
# path would clobber the per-row dict built from the slot ids above.
if row_count == 0:
out_dict["dynamic_paths"][finalized_prefix] = finalized_prefix
out_dict["dynamic_paths_default_value"][finalized_prefix] = DynamicPathsDefaultValue.EMPTY_LIST
@comfytype(io_type="IMAGECOMPARE") @comfytype(io_type="IMAGECOMPARE")
class ImageCompare(ComfyTypeI): class ImageCompare(ComfyTypeI):
Type = dict Type = dict
@ -1418,6 +1570,8 @@ def setup_dynamic_input_funcs():
register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic) register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic)
# DynamicSlot.Input # DynamicSlot.Input
register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic) register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic)
# DynamicGroup.Input
register_dynamic_input_func(DynamicGroup.io_type, DynamicGroup._expand_schema_for_dynamic)
if len(DYNAMIC_INPUT_LOOKUP) == 0: if len(DYNAMIC_INPUT_LOOKUP) == 0:
setup_dynamic_input_funcs() setup_dynamic_input_funcs()
@ -1429,6 +1583,8 @@ class V3Data(TypedDict):
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.' 'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
dynamic_paths_default_value: dict[str, Any] dynamic_paths_default_value: dict[str, Any]
'Dictionary where the keys are the input ids and the values are a string from DynamicPathsDefaultValue for the inputs if value is None.' 'Dictionary where the keys are the input ids and the values are a string from DynamicPathsDefaultValue for the inputs if value is None.'
list_paths: set[str]
'Set of top-level keys whose index-keyed dict values should be converted to a sorted list[dict] after build_nested_inputs runs.'
create_dynamic_tuple: bool create_dynamic_tuple: bool
'When True, the value of the dynamic input will be in the format (value, path_key).' 'When True, the value of the dynamic input will be in the format (value, path_key).'
@ -1770,6 +1926,7 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
"optional": {}, "optional": {},
"dynamic_paths": {}, "dynamic_paths": {},
"dynamic_paths_default_value": {}, "dynamic_paths_default_value": {},
"list_paths": set(),
} }
d = d.copy() d = d.copy()
# ignore hidden for parsing # ignore hidden for parsing
@ -1785,6 +1942,10 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
dynamic_paths_default_value = out_dict.pop("dynamic_paths_default_value", None) dynamic_paths_default_value = out_dict.pop("dynamic_paths_default_value", None)
if dynamic_paths_default_value is not None and len(dynamic_paths_default_value) > 0: if dynamic_paths_default_value is not None and len(dynamic_paths_default_value) > 0:
v3_data["dynamic_paths_default_value"] = dynamic_paths_default_value v3_data["dynamic_paths_default_value"] = dynamic_paths_default_value
# list_paths: keys whose nested dict should be post-converted to a sorted list[dict]
list_paths = out_dict.pop("list_paths", None)
if list_paths:
v3_data["list_paths"] = list_paths
return out_dict, hidden, v3_data return out_dict, hidden, v3_data
def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None: def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
@ -1820,10 +1981,12 @@ def add_to_dict_v1(i: Input, d: dict):
class DynamicPathsDefaultValue: class DynamicPathsDefaultValue:
EMPTY_DICT = "empty_dict" EMPTY_DICT = "empty_dict"
EMPTY_LIST = "empty_list"
def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
paths = v3_data.get("dynamic_paths", None) paths = v3_data.get("dynamic_paths", None)
default_value_dict = v3_data.get("dynamic_paths_default_value", {}) default_value_dict = v3_data.get("dynamic_paths_default_value", {})
list_paths: set[str] = v3_data.get("list_paths", set()) or set()
if paths is None: if paths is None:
return values return values
values = values.copy() values = values.copy()
@ -1846,6 +2009,8 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
default_option = default_value_dict.get(key, None) default_option = default_value_dict.get(key, None)
if default_option == DynamicPathsDefaultValue.EMPTY_DICT: if default_option == DynamicPathsDefaultValue.EMPTY_DICT:
value = {} value = {}
elif default_option == DynamicPathsDefaultValue.EMPTY_LIST:
value = []
if create_tuple: if create_tuple:
value = (value, key) value = (value, key)
current[p] = value current[p] = value
@ -1853,6 +2018,34 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
current = current.setdefault(p, {}) current = current.setdefault(p, {})
values.update(result) values.update(result)
# Post-pass: convert index-keyed dicts to sorted lists for io.DynamicGroup fields
for list_path in list_paths:
parts = list_path.split(".")
# Navigate to the parent container, then convert the leaf
container = values
for part in parts[:-1]:
if not isinstance(container, dict) or part not in container:
container = None
break
container = container[part]
if container is None:
continue
leaf_key = parts[-1]
leaf = container.get(leaf_key, None)
if isinstance(leaf, dict):
try:
sorted_rows = [leaf[k] for k in sorted(leaf.keys(), key=int)]
container[leaf_key] = sorted_rows
except (ValueError, TypeError):
# Keys are not all integers; leave as-is
pass
elif isinstance(leaf, list):
# Already a list (e.g. the EMPTY_LIST default was applied above)
pass
elif leaf is None:
container[leaf_key] = []
return values return values
@ -2417,7 +2610,9 @@ __all__ = [
# Dynamic Types # Dynamic Types
"MatchType", "MatchType",
"DynamicCombo", "DynamicCombo",
"DynamicSlot",
"Autogrow", "Autogrow",
"DynamicGroup",
# Other classes # Other classes
"HiddenHolder", "HiddenHolder",
"Hidden", "Hidden",

View File

@ -1,85 +1,68 @@
import os import os
import sys import sys
import re import re
import ctypes
import logging import logging
import ctypes.util
import importlib.util
from typing import TypedDict from typing import TypedDict
import numpy as np import numpy as np
import torch import torch
import nodes import nodes
import comfy_angle
from comfy_api.latest import ComfyExtension, io, ui from comfy_api.latest import ComfyExtension, io, ui
from typing_extensions import override from typing_extensions import override
from utils.install_util import get_missing_requirements_message
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _check_opengl_availability(): def _preload_angle():
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work.""" egl_path = comfy_angle.get_egl_path()
logger.debug("_check_opengl_availability: starting") gles_path = comfy_angle.get_glesv2_path()
missing = []
# Check Python packages (using find_spec to avoid importing) if sys.platform == "win32":
logger.debug("_check_opengl_availability: checking for glfw package") angle_dir = comfy_angle.get_lib_dir()
if importlib.util.find_spec("glfw") is None: os.add_dll_directory(angle_dir)
missing.append("glfw") os.environ["PATH"] = angle_dir + os.pathsep + os.environ.get("PATH", "")
logger.debug("_check_opengl_availability: checking for OpenGL package") mode = 0 if sys.platform == "win32" else ctypes.RTLD_GLOBAL
if importlib.util.find_spec("OpenGL") is None: ctypes.CDLL(str(egl_path), mode=mode)
missing.append("PyOpenGL") ctypes.CDLL(str(gles_path), mode=mode)
if missing:
raise RuntimeError(
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
)
# On Linux without display, check if headless backends are available
logger.debug(f"_check_opengl_availability: platform={sys.platform}")
if sys.platform.startswith("linux"):
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
if not has_display:
# Check for EGL or OSMesa libraries
logger.debug("_check_opengl_availability: checking for EGL library")
has_egl = ctypes.util.find_library("EGL")
logger.debug("_check_opengl_availability: checking for OSMesa library")
has_osmesa = ctypes.util.find_library("OSMesa")
# Error disabled for CI as it fails this check
# if not has_egl and not has_osmesa:
# raise RuntimeError(
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
# "See error below for installation instructions."
# )
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
logger.debug("_check_opengl_availability: completed")
# Run early check at import time # Pre-load ANGLE *before* any PyOpenGL import so that the EGL platform
logger.debug("nodes_glsl: running _check_opengl_availability at import time") # plugin picks up ANGLE's libEGL / libGLESv2 instead of system libs.
_check_opengl_availability() _preload_angle()
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
# OpenGL modules - initialized lazily when context is created
gl = None
glfw = None
EGL = None
def _import_opengl(): import OpenGL
"""Import OpenGL module. Called after context is created.""" OpenGL.USE_ACCELERATE = False
global gl
if gl is None:
logger.debug("_import_opengl: importing OpenGL.GL")
import OpenGL.GL as _gl
gl = _gl
logger.debug("_import_opengl: import completed")
return gl
def _patch_find_library():
"""PyOpenGL's EGL platform looks for 'EGL' and 'GLESv2' by short name
via ctypes.util.find_library, but ANGLE ships as 'libEGL' and
'libGLESv2'. Patch find_library to return the full ANGLE paths so
PyOpenGL loads the same libraries we pre-loaded."""
if sys.platform == "linux":
return
import ctypes.util
_orig = ctypes.util.find_library
def _patched(name):
if name == 'EGL':
return comfy_angle.get_egl_path()
if name == 'GLESv2':
return comfy_angle.get_glesv2_path()
return _orig(name)
ctypes.util.find_library = _patched
_patch_find_library()
from OpenGL import EGL
from OpenGL import GLES3 as gl
class SizeModeInput(TypedDict): class SizeModeInput(TypedDict):
size_mode: str size_mode: str
width: int width: int
@ -102,7 +85,7 @@ MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
# (-1,-1)---(3,-1) # (-1,-1)---(3,-1)
# #
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1) # v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
VERTEX_SHADER = """#version 330 core VERTEX_SHADER = """#version 300 es
out vec2 v_texCoord; out vec2 v_texCoord;
void main() { void main() {
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3)); vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
@ -126,14 +109,99 @@ void main() {
""" """
def _convert_es_to_desktop(source: str) -> str:
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core.""" def _egl_attribs(*values):
# Remove any existing #version directive """Build an EGL_NONE-terminated EGLint attribute array."""
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE) vals = list(values) + [EGL.EGL_NONE]
# Remove precision qualifiers (not needed in desktop GLSL) return (ctypes.c_int32 * len(vals))(*vals)
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
# Prepend desktop GLSL version
return "#version 330 core\n" + source # EGL platform extension constants
EGL_PLATFORM_ANGLE_ANGLE = 0x3202
EGL_PLATFORM_ANGLE_TYPE_ANGLE = 0x3203
EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE = 0x3450
EGL_MESA_PLATFORM_SURFACELESS = 0x31DD
_eglGetPlatformDisplayEXT = None
def _get_egl_platform_display_ext(platform, native_display, attribs):
"""Call eglGetPlatformDisplayEXT via ctypes (extension, not in PyOpenGL)."""
global _eglGetPlatformDisplayEXT
if _eglGetPlatformDisplayEXT is None:
from OpenGL import platform as _plat
egl_lib = _plat.PLATFORM.EGL
_get_proc = egl_lib.eglGetProcAddress
_get_proc.restype = ctypes.c_void_p
_get_proc.argtypes = [ctypes.c_char_p]
ptr = _get_proc(b"eglGetPlatformDisplayEXT")
if not ptr:
return None
func_type = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_uint32, ctypes.c_void_p, ctypes.c_void_p)
_eglGetPlatformDisplayEXT = func_type(ptr)
raw = _eglGetPlatformDisplayEXT(platform, native_display, attribs)
if not raw:
return None
return ctypes.cast(raw, EGL.EGLDisplay)
def _get_egl_display():
"""Get an EGL display, trying the default first then ANGLE's Vulkan
platform for headless environments without a display server."""
failures = []
# Try the default display first (works when X11/Wayland is available)
display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY)
if display:
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
try:
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
return display, major.value, minor.value
except Exception as e:
failures.append(f"default: {e}")
logger.info("Default EGL display unavailable, trying headless fallbacks")
# Headless fallback strategies, tried in order:
headless_strategies = [
("surfaceless", EGL_MESA_PLATFORM_SURFACELESS, None, None),
("ANGLE Vulkan", EGL_PLATFORM_ANGLE_ANGLE, None,
_egl_attribs(EGL_PLATFORM_ANGLE_TYPE_ANGLE, EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE)),
]
for name, platform, native_display, attribs in headless_strategies:
display = _get_egl_platform_display_ext(platform, native_display, attribs)
if not display:
failures.append(f"{name}: eglGetPlatformDisplayEXT returned no display")
continue
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
try:
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
logger.info(f"Using EGL {name} platform (headless)")
return display, major.value, minor.value
failures.append(f"{name}: eglInitialize returned false")
except Exception as e:
failures.append(f"{name}: {e}")
continue
details = "\n".join(f" - {f}" for f in failures)
raise RuntimeError(
"Failed to initialize EGL display.\n"
"No display server and no headless EGL platform available.\n"
f"Tried:\n{details}\n"
"Ensure GPU drivers are installed or set DISPLAY for a virtual framebuffer."
)
def _gl_str(name):
"""Get an OpenGL string parameter."""
v = gl.glGetString(name)
if not v:
return "Unknown"
if isinstance(v, bytes):
return v.decode(errors="replace")
return ctypes.string_at(v).decode(errors="replace")
def _detect_output_count(source: str) -> int: def _detect_output_count(source: str) -> int:
@ -159,163 +227,8 @@ def _detect_pass_count(source: str) -> int:
return 1 return 1
def _init_glfw():
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
logger.debug("_init_glfw: starting")
# On macOS, glfw.init() must be called from main thread or it hangs forever
if sys.platform == "darwin":
logger.debug("_init_glfw: skipping on macOS")
raise RuntimeError("GLFW backend not supported on macOS")
logger.debug("_init_glfw: importing glfw module")
import glfw as _glfw
logger.debug("_init_glfw: calling glfw.init()")
if not _glfw.init():
raise RuntimeError("glfw.init() failed")
try:
logger.debug("_init_glfw: setting window hints")
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
logger.debug("_init_glfw: calling create_window()")
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
if not window:
raise RuntimeError("glfw.create_window() failed")
logger.debug("_init_glfw: calling make_context_current()")
_glfw.make_context_current(window)
logger.debug("_init_glfw: completed successfully")
return window, _glfw
except Exception:
logger.debug("_init_glfw: failed, terminating glfw")
_glfw.terminate()
raise
def _init_egl():
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
logger.debug("_init_egl: starting")
from OpenGL import EGL as _EGL
from OpenGL.EGL import (
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
eglTerminate, eglDestroyContext, eglDestroySurface,
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
)
logger.debug("_init_egl: imports completed")
display = None
context = None
surface = None
try:
logger.debug("_init_egl: calling eglGetDisplay()")
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
if display == _EGL.EGL_NO_DISPLAY:
raise RuntimeError("eglGetDisplay() failed")
logger.debug("_init_egl: calling eglInitialize()")
major, minor = _EGL.EGLint(), _EGL.EGLint()
if not eglInitialize(display, major, minor):
display = None # Not initialized, don't terminate
raise RuntimeError("eglInitialize() failed")
logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
config_attribs = [
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
EGL_DEPTH_SIZE, 0, EGL_NONE
]
configs = (_EGL.EGLConfig * 1)()
num_configs = _EGL.EGLint()
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
raise RuntimeError("eglChooseConfig() failed")
config = configs[0]
logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
if not eglBindAPI(EGL_OPENGL_API):
raise RuntimeError("eglBindAPI() failed")
logger.debug("_init_egl: calling eglCreateContext()")
context_attribs = [
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
EGL_NONE
]
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
if context == EGL_NO_CONTEXT:
raise RuntimeError("eglCreateContext() failed")
logger.debug("_init_egl: calling eglCreatePbufferSurface()")
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
if surface == _EGL.EGL_NO_SURFACE:
raise RuntimeError("eglCreatePbufferSurface() failed")
logger.debug("_init_egl: calling eglMakeCurrent()")
if not eglMakeCurrent(display, surface, surface, context):
raise RuntimeError("eglMakeCurrent() failed")
logger.debug("_init_egl: completed successfully")
return display, context, surface, _EGL
except Exception:
logger.debug("_init_egl: failed, cleaning up")
# Clean up any resources on failure
if surface is not None:
eglDestroySurface(display, surface)
if context is not None:
eglDestroyContext(display, context)
if display is not None:
eglTerminate(display)
raise
def _init_osmesa():
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
import ctypes
logger.debug("_init_osmesa: starting")
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
logger.debug("_init_osmesa: importing OpenGL.osmesa")
from OpenGL import GL as _gl
from OpenGL.osmesa import (
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
OSMESA_RGBA,
)
logger.debug("_init_osmesa: imports completed")
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
if not ctx:
raise RuntimeError("OSMesaCreateContextExt() failed")
width, height = 64, 64
buffer = (ctypes.c_ubyte * (width * height * 4))()
logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
OSMesaDestroyContext(ctx)
raise RuntimeError("OSMesaMakeCurrent() failed")
logger.debug("_init_osmesa: completed successfully")
return ctx, buffer
class GLContext: class GLContext:
"""Manages OpenGL context and resources for shader execution. """Manages an OpenGL ES 3.0 context via EGL/ANGLE (singleton)."""
Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software).
"""
_instance = None _instance = None
_initialized = False _initialized = False
@ -327,131 +240,105 @@ class GLContext:
def __init__(self): def __init__(self):
if GLContext._initialized: if GLContext._initialized:
logger.debug("GLContext.__init__: already initialized, skipping")
return return
logger.debug("GLContext.__init__: starting initialization")
global glfw, EGL
import time import time
start = time.perf_counter() start = time.perf_counter()
self._backend = None self._display = None
self._window = None self._surface = None
self._egl_display = None self._context = None
self._egl_context = None
self._egl_surface = None
self._osmesa_ctx = None
self._osmesa_buffer = None
self._vao = None self._vao = None
# Try backends in order: GLFW → EGL → OSMesa
errors = []
logger.debug("GLContext.__init__: trying GLFW backend")
try: try:
self._window, glfw = _init_glfw() self._display, self._egl_major, self._egl_minor = _get_egl_display()
self._backend = "glfw"
logger.debug("GLContext.__init__: GLFW backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
errors.append(("GLFW", e))
if self._backend is None: if not EGL.eglBindAPI(EGL.EGL_OPENGL_ES_API):
logger.debug("GLContext.__init__: trying EGL backend") raise RuntimeError("eglBindAPI(EGL_OPENGL_ES_API) failed")
try:
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
self._backend = "egl"
logger.debug("GLContext.__init__: EGL backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
errors.append(("EGL", e))
if self._backend is None: config = EGL.EGLConfig()
logger.debug("GLContext.__init__: trying OSMesa backend") n_configs = ctypes.c_int32(0)
try: if not EGL.eglChooseConfig(
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa() self._display,
self._backend = "osmesa" _egl_attribs(
logger.debug("GLContext.__init__: OSMesa backend succeeded") EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_ES3_BIT,
except Exception as e: EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT,
logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}") EGL.EGL_RED_SIZE, 8, EGL.EGL_GREEN_SIZE, 8,
errors.append(("OSMesa", e)) EGL.EGL_BLUE_SIZE, 8, EGL.EGL_ALPHA_SIZE, 8,
),
ctypes.byref(config), 1, ctypes.byref(n_configs),
) or n_configs.value == 0:
raise RuntimeError("eglChooseConfig() failed")
if self._backend is None: self._surface = EGL.eglCreatePbufferSurface(
if sys.platform == "win32": self._display, config,
platform_help = ( _egl_attribs(EGL.EGL_WIDTH, 64, EGL.EGL_HEIGHT, 64),
"Windows: Ensure GPU drivers are installed and display is available.\n"
" CPU-only/headless mode is not supported on Windows."
)
elif sys.platform == "darwin":
platform_help = (
"macOS: GLFW is not supported.\n"
" Install OSMesa via Homebrew: brew install mesa\n"
" Then: pip install PyOpenGL PyOpenGL-accelerate"
)
else:
platform_help = (
"Linux: Install one of these backends:\n"
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
" Headless (CPU): sudo apt install libosmesa6"
)
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
raise RuntimeError(
f"Failed to create OpenGL context.\n\n"
f"Backend errors:\n{error_details}\n\n"
f"{platform_help}"
) )
if not self._surface:
raise RuntimeError("eglCreatePbufferSurface() failed")
# Now import OpenGL.GL (after context is current) self._context = EGL.eglCreateContext(
logger.debug("GLContext.__init__: importing OpenGL.GL") self._display, config, EGL.EGL_NO_CONTEXT,
_import_opengl() _egl_attribs(EGL.EGL_CONTEXT_CLIENT_VERSION, 3),
)
if not self._context:
raise RuntimeError("eglCreateContext() failed")
# Create VAO (required for core profile, but OSMesa may use compat profile) if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
logger.debug("GLContext.__init__: creating VAO") raise RuntimeError("eglMakeCurrent() failed")
try:
vao = gl.glGenVertexArrays(1) self._vao = gl.glGenVertexArrays(1)
gl.glBindVertexArray(vao) gl.glBindVertexArray(self._vao)
self._vao = vao # Only store after successful bind
logger.debug("GLContext.__init__: VAO created successfully") except Exception:
except Exception as e: self._cleanup()
logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}") raise
# OSMesa with older Mesa may not support VAOs
# Clean up if we created but couldn't bind
if vao:
try:
gl.glDeleteVertexArrays(1, [vao])
except Exception:
pass
elapsed = (time.perf_counter() - start) * 1000 elapsed = (time.perf_counter() - start) * 1000
# Log device info renderer = _gl_str(gl.GL_RENDERER)
renderer = gl.glGetString(gl.GL_RENDERER) vendor = _gl_str(gl.GL_VENDOR)
vendor = gl.glGetString(gl.GL_VENDOR) version = _gl_str(gl.GL_VERSION)
version = gl.glGetString(gl.GL_VERSION)
renderer = renderer.decode() if renderer else "Unknown"
vendor = vendor.decode() if vendor else "Unknown"
version = version.decode() if version else "Unknown"
GLContext._initialized = True GLContext._initialized = True
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}") logger.info(f"GLSL context initialized in {elapsed:.1f}ms - EGL {self._egl_major}.{self._egl_minor}, {renderer} ({vendor}), GL {version}")
def make_current(self): def make_current(self):
if self._backend == "glfw": if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
glfw.make_context_current(self._window) err = EGL.eglGetError()
elif self._backend == "egl": raise RuntimeError(f"eglMakeCurrent() failed (EGL error: 0x{err:04X})")
from OpenGL.EGL import eglMakeCurrent
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
elif self._backend == "osmesa":
from OpenGL.osmesa import OSMesaMakeCurrent
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
if self._vao is not None: if self._vao is not None:
gl.glBindVertexArray(self._vao) gl.glBindVertexArray(self._vao)
def _cleanup(self):
if not self._display:
return
try:
if self._vao is not None:
gl.glDeleteVertexArrays(1, [self._vao])
self._vao = None
except Exception:
pass
try:
EGL.eglMakeCurrent(self._display, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
except Exception:
pass
try:
if self._context:
EGL.eglDestroyContext(self._display, self._context)
except Exception:
pass
try:
if self._surface:
EGL.eglDestroySurface(self._display, self._surface)
except Exception:
pass
try:
EGL.eglTerminate(self._display)
except Exception:
pass
self._display = None
def _compile_shader(source: str, shader_type: int) -> int: def _compile_shader(source: str, shader_type: int) -> int:
"""Compile a shader and return its ID.""" """Compile a shader and return its ID."""
@ -459,8 +346,10 @@ def _compile_shader(source: str, shader_type: int) -> int:
gl.glShaderSource(shader, source) gl.glShaderSource(shader, source)
gl.glCompileShader(shader) gl.glCompileShader(shader)
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE: if not gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS):
error = gl.glGetShaderInfoLog(shader).decode() error = gl.glGetShaderInfoLog(shader)
if isinstance(error, bytes):
error = error.decode(errors="replace")
gl.glDeleteShader(shader) gl.glDeleteShader(shader)
raise RuntimeError(f"Shader compilation failed:\n{error}") raise RuntimeError(f"Shader compilation failed:\n{error}")
@ -484,8 +373,10 @@ def _create_program(vertex_source: str, fragment_source: str) -> int:
gl.glDeleteShader(vertex_shader) gl.glDeleteShader(vertex_shader)
gl.glDeleteShader(fragment_shader) gl.glDeleteShader(fragment_shader)
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE: if not gl.glGetProgramiv(program, gl.GL_LINK_STATUS):
error = gl.glGetProgramInfoLog(program).decode() error = gl.glGetProgramInfoLog(program)
if isinstance(error, bytes):
error = error.decode(errors="replace")
gl.glDeleteProgram(program) gl.glDeleteProgram(program)
raise RuntimeError(f"Program linking failed:\n{error}") raise RuntimeError(f"Program linking failed:\n{error}")
@ -530,9 +421,6 @@ def _render_shader_batch(
ctx = GLContext() ctx = GLContext()
ctx.make_current() ctx.make_current()
# Convert from GLSL ES to desktop GLSL 330
fragment_source = _convert_es_to_desktop(fragment_code)
# Detect how many outputs the shader actually uses # Detect how many outputs the shader actually uses
num_outputs = _detect_output_count(fragment_code) num_outputs = _detect_output_count(fragment_code)
@ -558,9 +446,9 @@ def _render_shader_batch(
try: try:
# Compile shaders (once for all batches) # Compile shaders (once for all batches)
try: try:
program = _create_program(VERTEX_SHADER, fragment_source) program = _create_program(VERTEX_SHADER, fragment_code)
except RuntimeError: except RuntimeError:
logger.error(f"Fragment shader:\n{fragment_source}") logger.error(f"Fragment shader:\n{fragment_code}")
raise raise
gl.glUseProgram(program) gl.glUseProgram(program)
@ -723,13 +611,13 @@ def _render_shader_batch(
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3) gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
# Read back outputs for this batch # Read back outputs for this batch
# (glGetTexImage is synchronous, implicitly waits for rendering) gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
batch_outputs = [] batch_outputs = []
for tex in output_textures: for i in range(num_outputs):
gl.glBindTexture(gl.GL_TEXTURE_2D, tex) gl.glReadBuffer(gl.GL_COLOR_ATTACHMENT0 + i)
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT) buf = np.empty((height, width, 4), dtype=np.float32)
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4) gl.glReadPixels(0, 0, width, height, gl.GL_RGBA, gl.GL_FLOAT, buf)
batch_outputs.append(img[::-1, :, :].copy()) batch_outputs.append(buf[::-1, :, :].copy())
# Pad with black images for unused outputs # Pad with black images for unused outputs
black_img = np.zeros((height, width, 4), dtype=np.float32) black_img = np.zeros((height, width, 4), dtype=np.float32)
@ -750,18 +638,18 @@ def _render_shader_batch(
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0) gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
gl.glUseProgram(0) gl.glUseProgram(0)
for tex in input_textures: if input_textures:
gl.glDeleteTextures(int(tex)) gl.glDeleteTextures(len(input_textures), input_textures)
for tex in curve_textures: if curve_textures:
gl.glDeleteTextures(int(tex)) gl.glDeleteTextures(len(curve_textures), curve_textures)
for tex in output_textures: if output_textures:
gl.glDeleteTextures(int(tex)) gl.glDeleteTextures(len(output_textures), output_textures)
for tex in ping_pong_textures: if ping_pong_textures:
gl.glDeleteTextures(int(tex)) gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures)
if fbo is not None: if fbo is not None:
gl.glDeleteFramebuffers(1, [fbo]) gl.glDeleteFramebuffers(1, [fbo])
for pp_fbo in ping_pong_fbos: if ping_pong_fbos:
gl.glDeleteFramebuffers(1, [pp_fbo]) gl.glDeleteFramebuffers(len(ping_pong_fbos), ping_pong_fbos)
if program is not None: if program is not None:
gl.glDeleteProgram(program) gl.glDeleteProgram(program)

View File

@ -0,0 +1,130 @@
"""LoRA stacking loaders built on io.DynamicGroup.
Two nodes that let you stack any number of LoRAs in a single node, each row
carrying only a LoRA name and a strength:
LoadLoraModel
Applies a stack of LoRAs to a diffusion MODEL.
LoadLoraTextEncoder
Applies a stack of LoRAs to a CLIP text encoder.
Both are modelled on DynamicGroupLoraStyleTest in nodes_dynamic_group_test.py,
but operate on real models and real LoRA files.
"""
from __future__ import annotations
from typing_extensions import override
import comfy.sd
import comfy.utils
import folder_paths
from comfy_api.latest import ComfyExtension, io
# Module-level cache so repeated executions don't re-read the same file from disk.
_LORA_CACHE: dict[str, tuple] = {}
def _load_lora_file(lora_name: str):
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
cached = _LORA_CACHE.get(lora_path)
if cached is not None:
return cached
lora, metadata = comfy.utils.load_torch_file(lora_path, safe_load=True, return_metadata=True)
_LORA_CACHE[lora_path] = (lora, metadata)
return lora, metadata
def _lora_template() -> list[io.Input]:
return [
io.Combo.Input("lora_name", options=folder_paths.get_filename_list("loras"),
tooltip="The name of the LoRA file to apply."),
io.Float.Input("strength", default=1.0, min=-100.0, max=100.0, step=0.01,
tooltip="How strongly to apply this LoRA. 0 = off, negative inverts the effect."),
]
class LoadLoraModel(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoadLoraModel",
display_name="Load LoRA (Model)",
search_aliases=["lora", "load lora", "apply lora", "lora model", "lora stack"],
category="model/loaders",
description="Apply a stack of LoRAs to a diffusion model. Add one row per LoRA; "
"each row picks a LoRA file and its strength.",
inputs=[
io.Model.Input("model", tooltip="The diffusion model the LoRAs will be applied to."),
io.DynamicGroup.Input(
"loras",
template=_lora_template(),
min=1,
max=50,
tooltip="Each row applies one LoRA to the model.",
group_name="LoRA",
),
],
outputs=[io.Model.Output(tooltip="The modified diffusion model.")],
)
@classmethod
def execute(cls, model, loras: list[dict]) -> io.NodeOutput:
for row in loras:
lora_name = row.get("lora_name")
strength = row.get("strength", 1.0)
if not lora_name or lora_name == "none" or strength == 0:
continue
lora, metadata = _load_lora_file(lora_name)
model, _ = comfy.sd.load_lora_for_models(model, None, lora, strength, 0, lora_metadata=metadata)
return io.NodeOutput(model)
class LoadLoraTextEncoder(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoadLoraTextEncoder",
display_name="Load LoRA (Text Encoder)",
search_aliases=["lora", "load lora", "apply lora", "clip lora", "lora stack"],
category="model/loaders",
description="Apply a stack of LoRAs to a CLIP text encoder. Add one row per LoRA; "
"each row picks a LoRA file and its strength.",
inputs=[
io.Clip.Input("clip", tooltip="The CLIP text encoder the LoRAs will be applied to."),
io.DynamicGroup.Input(
"loras",
template=_lora_template(),
min=1,
max=50,
tooltip="Each row applies one LoRA to the text encoder.",
group_name="LoRA",
),
],
outputs=[io.Clip.Output(tooltip="The modified CLIP text encoder.")],
)
@classmethod
def execute(cls, clip, loras: list[dict]) -> io.NodeOutput:
for row in loras:
lora_name = row.get("lora_name")
strength = row.get("strength", 1.0)
if not lora_name or lora_name == "none" or strength == 0:
continue
lora, metadata = _load_lora_file(lora_name)
_, clip = comfy.sd.load_lora_for_models(None, clip, lora, 0, strength, lora_metadata=metadata)
return io.NodeOutput(clip)
class LoraStackExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
LoadLoraModel,
LoadLoraTextEncoder,
]
async def comfy_entrypoint() -> LoraStackExtension:
return LoraStackExtension()

View File

@ -1113,32 +1113,6 @@ def full_type_name(klass):
return klass.__qualname__ return klass.__qualname__
return module + '.' + klass.__qualname__ return module + '.' + klass.__qualname__
def node_not_executable_reason(class_def, class_type):
"""Return a human-readable reason the node cannot be executed, or None if it's fine.
Catches a node whose declared entry point doesn't resolve to a real method
(e.g. a V1 ``FUNCTION = "invert"`` where the method is misspelled, or a V3 node
missing its ``execute`` override). Running this during validation surfaces the
problem before execution starts, instead of after upstream nodes have run.
Only the class is inspected; the node is never instantiated here, so a node's
``__init__`` side effects cannot run (or fail) during validation.
"""
try:
if issubclass(class_def, _ComfyNodeInternal):
# V3: validates that execute()/define_schema() overrides exist.
class_def.VALIDATE_CLASS()
return None
# V1: FUNCTION names the method to call; it must exist on the class.
function_name = getattr(class_def, "FUNCTION", None)
if function_name is None:
return f"'{class_type}' does not define FUNCTION"
if not callable(getattr(class_def, function_name, None)):
return f"'{class_type}' has no method '{function_name}' (declared in FUNCTION)"
return None
except Exception as ex:
return str(ex)
async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]): async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]):
outputs = set() outputs = set()
for x in prompt: for x in prompt:
@ -1174,35 +1148,6 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[
} }
return (False, error, [], {}) return (False, error, [], {})
# Make sure the node is actually executable (its FUNCTION/execute entry
# point resolves to a real method) before we touch any schema-derived
# attributes below or start execution. Catches code typos up front and
# attributes the error to the offending node.
not_executable = node_not_executable_reason(class_, class_type)
if not_executable is not None:
node_title = prompt[x].get('_meta', {}).get('title', class_type)
error = {
"type": "invalid_node_definition",
"message": "Node is not executable",
"details": f"{not_executable} (Node ID '#{x}')",
"extra_info": {
"node_id": x,
"class_type": class_type,
"node_title": node_title,
}
}
node_errors = {x: {
"errors": [{
"type": "invalid_node_definition",
"message": "Node is not executable",
"details": not_executable,
"extra_info": {},
}],
"dependent_outputs": [],
"class_type": class_type,
}}
return (False, error, [], node_errors)
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True: if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
if partial_execution_list is None or x in partial_execution_list: if partial_execution_list is None or x in partial_execution_list:
outputs.add(x) outputs.add(x)

View File

@ -2476,6 +2476,7 @@ async def init_builtin_extra_nodes():
"nodes_triposplat.py", "nodes_triposplat.py",
"nodes_depth_anything_3.py", "nodes_depth_anything_3.py",
"nodes_seed.py", "nodes_seed.py",
"nodes_lora_stack.py",
] ]
import_failed = [] import_failed = []

View File

@ -22,7 +22,7 @@ alembic
SQLAlchemy>=2.0.0 SQLAlchemy>=2.0.0
filelock filelock
av>=16.0.0 av>=16.0.0
comfy-kitchen==0.2.12 comfy-kitchen==0.2.13
comfy-aimdo==0.4.10 comfy-aimdo==0.4.10
requests requests
simpleeval>=1.0.0 simpleeval>=1.0.0
@ -33,5 +33,5 @@ kornia>=0.7.1
spandrel spandrel
pydantic~=2.0 pydantic~=2.0
pydantic-settings~=2.0 pydantic-settings~=2.0
PyOpenGL PyOpenGL>=3.1.8
glfw comfy-angle

View File

@ -0,0 +1,204 @@
"""Unit tests for io.DynamicGroup: expansion/reconstruction (0-row and N-row cases)."""
import sys
import types
import pytest
# Stub torch (type-hint only in _io.py; real torch not available in unit-test env)
if "torch" not in sys.modules:
_torch_stub = types.ModuleType("torch")
_torch_stub.Tensor = object # type: ignore[attr-defined]
sys.modules["torch"] = _torch_stub
from comfy_api.latest._io import ( # noqa: E402
DynamicGroup,
Float,
Int,
String,
Boolean,
get_finalized_class_inputs,
build_nested_inputs,
create_input_dict_v1,
setup_dynamic_input_funcs,
)
# Make sure dynamic input funcs are registered (may already be done at import time)
setup_dynamic_input_funcs()
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_class_inputs(group_input: DynamicGroup.Input) -> dict:
"""Wrap a DynamicGroup.Input into the required/optional dict structure."""
return create_input_dict_v1([group_input])
def _run(group_input: DynamicGroup.Input, live_values: dict) -> dict:
"""End-to-end helper: expand schema + reconstruct values.
Mirrors the production split in execution.py:
1. get_finalized_class_inputs (schema expansion, line 162)
2. build_nested_inputs (value reconstruction, line 281)
The two steps are separate in production because the engine resolves
linked node outputs between them, but in tests we supply values directly.
"""
class_inputs = _make_class_inputs(group_input)
_, _, v3_data = get_finalized_class_inputs(class_inputs, live_values)
return build_nested_inputs(dict(live_values), v3_data)
# ---------------------------------------------------------------------------
# Schema construction
# ---------------------------------------------------------------------------
class TestDynamicGroupInputConstruction:
def test_basic_construction(self):
inp = DynamicGroup.Input(
"loras",
template=[
Float.Input("strength", default=1.0),
String.Input("name"),
],
min=0,
max=10,
)
assert inp.id == "loras"
assert inp.min == 0
assert inp.max == 10
assert len(inp.template) == 2
def test_get_all_includes_self_and_template(self):
inp = DynamicGroup.Input(
"items",
template=[Float.Input("value")],
)
all_inputs = inp.get_all()
assert all_inputs[0] is inp
assert all_inputs[1].id == "value"
def test_as_dict_has_template_min_max(self):
inp = DynamicGroup.Input(
"items",
template=[Float.Input("val", default=0.5)],
min=1,
max=5,
)
d = inp.as_dict()
assert "template" in d
assert d["min"] == 1
assert d["max"] == 5
def test_duplicate_field_ids_raises(self):
with pytest.raises(AssertionError):
DynamicGroup.Input(
"bad",
template=[Float.Input("x"), Float.Input("x")],
)
def test_empty_template_raises(self):
with pytest.raises(AssertionError):
DynamicGroup.Input("bad", template=[])
def test_min_gt_max_raises(self):
with pytest.raises(AssertionError):
DynamicGroup.Input("bad", template=[Float.Input("x")], min=5, max=3)
def test_max_exceeds_limit_raises(self):
with pytest.raises(AssertionError):
DynamicGroup.Input("bad", template=[Float.Input("x")], max=101)
def test_dynamic_input_in_template_raises(self):
with pytest.raises(AssertionError):
DynamicGroup.Input(
"bad",
template=[DynamicGroup.Input("nested", template=[Float.Input("x")])],
)
def test_validate_calls_through(self):
inp = DynamicGroup.Input("items", template=[Float.Input("val", min=-1.0, max=1.0)])
inp.validate() # should not raise
# ---------------------------------------------------------------------------
# 0-row case
# ---------------------------------------------------------------------------
class TestZeroRows:
def test_empty_live_inputs_produces_empty_list(self):
"""With min=0 and no live values, the result should be an empty list."""
inp = DynamicGroup.Input("loras", template=[Float.Input("strength", default=1.0)], min=0, max=10)
assert _run(inp, {}).get("loras") == []
def test_min_zero_with_values(self):
"""min=0 but 2 rows of live data."""
inp = DynamicGroup.Input("loras", template=[Float.Input("strength", default=1.0)], min=0, max=10)
result = _run(inp, {"loras.0.strength": 0.8, "loras.1.strength": 0.5})
assert result["loras"] == [{"strength": 0.8}, {"strength": 0.5}]
# ---------------------------------------------------------------------------
# N-row case
# ---------------------------------------------------------------------------
class TestNRows:
def test_two_rows_two_fields(self):
"""Two rows with two fields each produce a list[dict]."""
inp = DynamicGroup.Input(
"loras",
template=[String.Input("lora_name"), Float.Input("strength", default=1.0)],
min=0, max=50,
)
result = _run(inp, {
"loras.0.lora_name": "model_a.safetensors", "loras.0.strength": 0.9,
"loras.1.lora_name": "model_b.safetensors", "loras.1.strength": 0.4,
})
assert result["loras"] == [
{"lora_name": "model_a.safetensors", "strength": 0.9},
{"lora_name": "model_b.safetensors", "strength": 0.4},
]
def test_rows_are_sorted_by_index(self):
"""Rows must be in ascending index order even if dict iteration is unordered."""
inp = DynamicGroup.Input("items", template=[Int.Input("v", default=0)], min=0, max=10)
result = _run(inp, {"items.0.v": 10, "items.2.v": 30, "items.1.v": 20})
assert [row["v"] for row in result["items"]] == [10, 20, 30]
def test_min_rows_schema_slots(self):
"""With min=2 and no live data, 2 slots must appear in the expanded schema."""
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
out, _, _ = get_finalized_class_inputs(_make_class_inputs(inp), {})
all_slots = {**out.get("required", {}), **out.get("optional", {})}
assert "items.0.val" in all_slots
assert "items.1.val" in all_slots
def test_min_rows_reconstructs_when_no_values(self):
"""min=2 with NO live values must still yield a 2-element list,
not collapse to [] (regression: parent-path clobber)."""
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
result = _run(inp, {})
assert len(result["items"]) == 2
assert all("val" in row for row in result["items"])
def test_min_rows_reconstructs_with_partial_values(self):
"""min=2 with only the first row's value present still yields 2 rows."""
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
result = _run(inp, {"items.0.val": 0.7})
assert len(result["items"]) == 2
assert result["items"][0]["val"] == 0.7
assert result["items"][1]["val"] is None
def test_list_paths_in_v3_data(self):
"""list_paths must contain the group id so build_nested_inputs knows to convert."""
inp = DynamicGroup.Input("things", template=[Boolean.Input("flag")], min=0, max=5)
_, _, v3_data = get_finalized_class_inputs(_make_class_inputs(inp), {})
assert "things" in v3_data.get("list_paths", set())
def test_no_leftover_flat_keys(self):
"""Flat keys must be consumed; only the reconstructed list remains."""
inp = DynamicGroup.Input("rows", template=[Float.Input("x", default=0.0)], min=0, max=5)
result = _run(inp, {"rows.0.x": 1.0, "rows.1.x": 2.0})
assert "rows.0.x" not in result
assert "rows.1.x" not in result
assert isinstance(result["rows"], list)

View File

@ -1,137 +0,0 @@
"""Tests for pre-execution validation that a node is actually executable.
validate_prompt rejects a node whose declared entry point does not resolve to a
real method (a V1 FUNCTION typo, or a V3 node missing its execute override) before
any node runs, attributing the error to the offending node.
"""
import asyncio
import nodes
from comfy_api.latest import io
from execution import node_not_executable_reason, validate_prompt
class _GoodV1Node:
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "run"
OUTPUT_NODE = True
CATEGORY = "Test"
def run(self):
return (None,)
class _TypoV1Node:
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "invert" # method below is misspelled
OUTPUT_NODE = True
CATEGORY = "Test"
def invvert(self):
return (None,)
class _SideEffectInitV1Node:
"""Valid class-level method, but a constructor that must never run in validation."""
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "run"
OUTPUT_NODE = True
CATEGORY = "Test"
def __init__(self):
raise RuntimeError("__init__ must not run during validation")
def run(self):
return (None,)
def _v3_schema(node_id):
return io.Schema(
node_id=node_id,
display_name=node_id,
category="Test",
inputs=[],
outputs=[io.Image.Output()],
is_output_node=True,
)
class _GoodV3Node(io.ComfyNode):
@classmethod
def define_schema(cls):
return _v3_schema("GoodV3Node")
@classmethod
def execute(cls):
return io.NodeOutput(None)
class _TypoV3Node(io.ComfyNode):
@classmethod
def define_schema(cls):
return _v3_schema("TypoV3Node")
@classmethod
def exicute(cls): # typo: should be "execute"
return io.NodeOutput(None)
def _register(class_type, class_def):
nodes.NODE_CLASS_MAPPINGS[class_type] = class_def
def _validate(class_type):
prompt = {"1": {"class_type": class_type, "inputs": {}}}
return asyncio.run(validate_prompt("pid", prompt, None))
def test_good_node_passes():
_register("GoodV1Node", _GoodV1Node)
assert node_not_executable_reason(_GoodV1Node, "GoodV1Node") is None
valid, _, _, _ = _validate("GoodV1Node")
assert valid is True
def test_typo_node_rejected_with_node_error():
_register("TypoV1Node", _TypoV1Node)
valid, error, _, node_errors = _validate("TypoV1Node")
assert valid is False
assert error["type"] == "invalid_node_definition"
assert node_errors["1"]["class_type"] == "TypoV1Node"
assert node_errors["1"]["errors"][0]["type"] == "invalid_node_definition"
assert "invert" in node_errors["1"]["errors"][0]["details"]
def test_validation_does_not_instantiate_node():
"""A valid node is not constructed during validation, so __init__ never runs."""
_register("SideEffectInitV1Node", _SideEffectInitV1Node)
assert node_not_executable_reason(_SideEffectInitV1Node, "SideEffectInitV1Node") is None
valid, _, _, _ = _validate("SideEffectInitV1Node")
assert valid is True
def test_good_v3_node_passes():
_register("GoodV3Node", _GoodV3Node)
assert node_not_executable_reason(_GoodV3Node, "GoodV3Node") is None
valid, _, _, _ = _validate("GoodV3Node")
assert valid is True
def test_typo_v3_node_rejected_with_node_error():
_register("TypoV3Node", _TypoV3Node)
valid, error, _, node_errors = _validate("TypoV3Node")
assert valid is False
assert error["type"] == "invalid_node_definition"
assert node_errors["1"]["errors"][0]["type"] == "invalid_node_definition"