mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-28 00:46:52 +08:00
Compare commits
8 Commits
fix/valida
...
LoraNodes
| Author | SHA1 | Date | |
|---|---|---|---|
| 4badc89490 | |||
| 330a37db94 | |||
| 30b19c6872 | |||
| 2dd281d8a6 | |||
| 911e0b2acf | |||
| 46c7e8055c | |||
| 603d891eaf | |||
| 470ac36a0a |
@ -256,7 +256,7 @@ def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, w
|
|||||||
if (want_requant and len(fns) == 0 or update_weight):
|
if (want_requant and len(fns) == 0 or update_weight):
|
||||||
seed = comfy.utils.string_to_seed(s.seed_key)
|
seed = comfy.utils.string_to_seed(s.seed_key)
|
||||||
if isinstance(orig, QuantizedTensor):
|
if isinstance(orig, QuantizedTensor):
|
||||||
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
y = orig.requantize_from_float(x, scale="recalculate", stochastic_rounding=seed)
|
||||||
else:
|
else:
|
||||||
y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
|
y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
|
||||||
if want_requant and len(fns) == 0:
|
if want_requant and len(fns) == 0:
|
||||||
@ -1306,8 +1306,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
|
|
||||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||||
if getattr(self, 'layout_type', None) is not None:
|
if getattr(self, 'layout_type', None) is not None:
|
||||||
# dtype is now implicit in the layout class
|
weight = self.weight.requantize_from_float(weight, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
|
||||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
|
|
||||||
else:
|
else:
|
||||||
weight = weight.to(self.weight.dtype)
|
weight = weight.to(self.weight.dtype)
|
||||||
if return_weight:
|
if return_weight:
|
||||||
|
|||||||
@ -1261,6 +1261,158 @@ class DynamicSlot(ComfyTypeI):
|
|||||||
out_dict[input_type][finalized_id] = value
|
out_dict[input_type][finalized_id] = value
|
||||||
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
|
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
|
||||||
|
|
||||||
|
@comfytype(io_type="COMFY_DYNAMICGROUP_V3")
|
||||||
|
class DynamicGroup(ComfyTypeI):
|
||||||
|
"""A repeatable group of widget inputs (e.g. lora_name + strength stacked into N rows).
|
||||||
|
|
||||||
|
At execution time the node receives a ``list[dict]`` where each element is a row.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
io.DynamicGroup.Input(
|
||||||
|
"loras",
|
||||||
|
template=[
|
||||||
|
io.Combo.Input("lora_name", options=folder_paths.get_filename_list("loras")),
|
||||||
|
io.Float.Input("strength", default=1.0, min=-100, max=100, step=0.01),
|
||||||
|
],
|
||||||
|
min=0,
|
||||||
|
max=50,
|
||||||
|
)
|
||||||
|
# execute receives: loras: list[dict] = [{"lora_name": "x.safetensors", "strength": 1.0}, ...]
|
||||||
|
"""
|
||||||
|
|
||||||
|
Type = list[dict[str, Any]]
|
||||||
|
_MaxRows = 100
|
||||||
|
|
||||||
|
class Input(DynamicInput):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
id: str,
|
||||||
|
template: list["Input"],
|
||||||
|
min: int = 0,
|
||||||
|
max: int = 50,
|
||||||
|
display_name: str = None,
|
||||||
|
optional: bool = False,
|
||||||
|
tooltip: str = None,
|
||||||
|
lazy: bool = None,
|
||||||
|
extra_dict=None,
|
||||||
|
group_name: str = "Group",
|
||||||
|
):
|
||||||
|
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||||
|
# Validate template entries: only WidgetInput subclasses, no nesting
|
||||||
|
assert len(template) > 0, "DynamicGroup template must have at least one field."
|
||||||
|
for t in template:
|
||||||
|
assert isinstance(t, WidgetInput), (
|
||||||
|
f"DynamicGroup template field '{t.id}' must be a WidgetInput subclass "
|
||||||
|
f"(Combo, Float, Int, String, Boolean, Color). Got {type(t).__name__}."
|
||||||
|
)
|
||||||
|
assert not isinstance(t, DynamicInput), (
|
||||||
|
f"DynamicGroup template field '{t.id}' must not be a DynamicInput. "
|
||||||
|
"Nesting dynamic inputs inside DynamicGroup is not supported."
|
||||||
|
)
|
||||||
|
# Enforce unique field ids within template
|
||||||
|
field_ids = [t.id for t in template]
|
||||||
|
assert len(field_ids) == len(set(field_ids)), (
|
||||||
|
f"DynamicGroup template field ids must be unique within a row. Got: {field_ids}"
|
||||||
|
)
|
||||||
|
# Reject "." in group id and template field ids: slot_id encoding uses "." as a
|
||||||
|
# delimiter (<group_id>.<row>.<field_id>), so any "." in these names would cause
|
||||||
|
# path.split(".") to produce the wrong number of segments during decoding.
|
||||||
|
assert "." not in id, (
|
||||||
|
f"DynamicGroup id must not contain '.'. Got: '{id}'"
|
||||||
|
)
|
||||||
|
for t in template:
|
||||||
|
assert "." not in t.id, (
|
||||||
|
f"DynamicGroup template field id must not contain '.'. Got: '{t.id}'"
|
||||||
|
)
|
||||||
|
assert min >= 0, "DynamicGroup min must be >= 0."
|
||||||
|
assert max >= 1, "DynamicGroup max must be >= 1."
|
||||||
|
assert max <= DynamicGroup._MaxRows, f"DynamicGroup max must be <= {DynamicGroup._MaxRows}."
|
||||||
|
assert min <= max, "DynamicGroup min must be <= max."
|
||||||
|
self.template = template
|
||||||
|
self.min = min
|
||||||
|
self.max = max
|
||||||
|
self.group_name = group_name
|
||||||
|
|
||||||
|
def get_all(self) -> list["Input"]:
|
||||||
|
return [self] + list(self.template)
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
return super().as_dict() | prune_dict({
|
||||||
|
"template": create_input_dict_v1(self.template),
|
||||||
|
"min": self.min,
|
||||||
|
"max": self.max,
|
||||||
|
"group_name": self.group_name,
|
||||||
|
})
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
for t in self.template:
|
||||||
|
t.validate()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _expand_schema_for_dynamic(
|
||||||
|
out_dict: dict[str, Any],
|
||||||
|
live_inputs: dict[str, Any],
|
||||||
|
value: tuple[str, dict[str, Any]],
|
||||||
|
input_type: str,
|
||||||
|
curr_prefix: list[str] | None,
|
||||||
|
):
|
||||||
|
info = value[1]
|
||||||
|
min_rows: int = info.get("min", 0)
|
||||||
|
max_rows: int = info.get("max", DynamicGroup._MaxRows)
|
||||||
|
template: dict[str, Any] = info.get("template", {})
|
||||||
|
|
||||||
|
# Collect all template field specs across required/optional sections
|
||||||
|
field_specs: list[tuple[str, tuple[str, dict[str, Any]], bool]] = []
|
||||||
|
for field_required_key in ("required", "optional"):
|
||||||
|
section = template.get(field_required_key, {})
|
||||||
|
is_required_field = field_required_key == "required"
|
||||||
|
for field_id, field_value in section.items():
|
||||||
|
field_specs.append((field_id, field_value, is_required_field))
|
||||||
|
|
||||||
|
# Determine how many rows are currently present by scanning live_inputs
|
||||||
|
finalized_prefix = finalize_prefix(curr_prefix)
|
||||||
|
present_rows = 0
|
||||||
|
for live_key in live_inputs:
|
||||||
|
# Keys look like "<prefix>.<row>.<field_id>"
|
||||||
|
if live_key.startswith(finalized_prefix + "."):
|
||||||
|
remainder = live_key[len(finalized_prefix) + 1:]
|
||||||
|
parts = remainder.split(".", 1)
|
||||||
|
if len(parts) >= 1:
|
||||||
|
try:
|
||||||
|
row_idx = int(parts[0])
|
||||||
|
present_rows = max(present_rows, row_idx + 1)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if present_rows > max_rows:
|
||||||
|
raise ValueError(
|
||||||
|
f"DynamicGroup input '{finalized_prefix}' received {present_rows} rows but max is {max_rows}."
|
||||||
|
)
|
||||||
|
row_count = max(min_rows, present_rows)
|
||||||
|
|
||||||
|
for row in range(row_count):
|
||||||
|
for field_id, field_value, is_required_field in field_specs:
|
||||||
|
slot_id = f"{finalized_prefix}.{row}.{field_id}"
|
||||||
|
# The first `min_rows` rows are required if the field itself is required
|
||||||
|
if row < min_rows and is_required_field:
|
||||||
|
out_dict["required"][slot_id] = field_value
|
||||||
|
else:
|
||||||
|
out_dict["optional"][slot_id] = field_value
|
||||||
|
# Register into dynamic_paths so build_nested_inputs places value at the right path
|
||||||
|
out_dict["dynamic_paths"][slot_id] = slot_id
|
||||||
|
|
||||||
|
# Track the list root path so build_nested_inputs can convert the index dict to a list
|
||||||
|
out_dict.setdefault("list_paths", set()).add(finalized_prefix)
|
||||||
|
|
||||||
|
# Handle the empty case (0 rows) – emit an empty-list default for the parent.
|
||||||
|
# This must only fire when there are genuinely no rows; otherwise the parent
|
||||||
|
# path would clobber the per-row dict built from the slot ids above.
|
||||||
|
if row_count == 0:
|
||||||
|
out_dict["dynamic_paths"][finalized_prefix] = finalized_prefix
|
||||||
|
out_dict["dynamic_paths_default_value"][finalized_prefix] = DynamicPathsDefaultValue.EMPTY_LIST
|
||||||
|
|
||||||
|
|
||||||
@comfytype(io_type="IMAGECOMPARE")
|
@comfytype(io_type="IMAGECOMPARE")
|
||||||
class ImageCompare(ComfyTypeI):
|
class ImageCompare(ComfyTypeI):
|
||||||
Type = dict
|
Type = dict
|
||||||
@ -1418,6 +1570,8 @@ def setup_dynamic_input_funcs():
|
|||||||
register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic)
|
register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic)
|
||||||
# DynamicSlot.Input
|
# DynamicSlot.Input
|
||||||
register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic)
|
register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic)
|
||||||
|
# DynamicGroup.Input
|
||||||
|
register_dynamic_input_func(DynamicGroup.io_type, DynamicGroup._expand_schema_for_dynamic)
|
||||||
|
|
||||||
if len(DYNAMIC_INPUT_LOOKUP) == 0:
|
if len(DYNAMIC_INPUT_LOOKUP) == 0:
|
||||||
setup_dynamic_input_funcs()
|
setup_dynamic_input_funcs()
|
||||||
@ -1429,6 +1583,8 @@ class V3Data(TypedDict):
|
|||||||
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
|
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
|
||||||
dynamic_paths_default_value: dict[str, Any]
|
dynamic_paths_default_value: dict[str, Any]
|
||||||
'Dictionary where the keys are the input ids and the values are a string from DynamicPathsDefaultValue for the inputs if value is None.'
|
'Dictionary where the keys are the input ids and the values are a string from DynamicPathsDefaultValue for the inputs if value is None.'
|
||||||
|
list_paths: set[str]
|
||||||
|
'Set of top-level keys whose index-keyed dict values should be converted to a sorted list[dict] after build_nested_inputs runs.'
|
||||||
create_dynamic_tuple: bool
|
create_dynamic_tuple: bool
|
||||||
'When True, the value of the dynamic input will be in the format (value, path_key).'
|
'When True, the value of the dynamic input will be in the format (value, path_key).'
|
||||||
|
|
||||||
@ -1770,6 +1926,7 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
|
|||||||
"optional": {},
|
"optional": {},
|
||||||
"dynamic_paths": {},
|
"dynamic_paths": {},
|
||||||
"dynamic_paths_default_value": {},
|
"dynamic_paths_default_value": {},
|
||||||
|
"list_paths": set(),
|
||||||
}
|
}
|
||||||
d = d.copy()
|
d = d.copy()
|
||||||
# ignore hidden for parsing
|
# ignore hidden for parsing
|
||||||
@ -1785,6 +1942,10 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
|
|||||||
dynamic_paths_default_value = out_dict.pop("dynamic_paths_default_value", None)
|
dynamic_paths_default_value = out_dict.pop("dynamic_paths_default_value", None)
|
||||||
if dynamic_paths_default_value is not None and len(dynamic_paths_default_value) > 0:
|
if dynamic_paths_default_value is not None and len(dynamic_paths_default_value) > 0:
|
||||||
v3_data["dynamic_paths_default_value"] = dynamic_paths_default_value
|
v3_data["dynamic_paths_default_value"] = dynamic_paths_default_value
|
||||||
|
# list_paths: keys whose nested dict should be post-converted to a sorted list[dict]
|
||||||
|
list_paths = out_dict.pop("list_paths", None)
|
||||||
|
if list_paths:
|
||||||
|
v3_data["list_paths"] = list_paths
|
||||||
return out_dict, hidden, v3_data
|
return out_dict, hidden, v3_data
|
||||||
|
|
||||||
def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
|
def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
|
||||||
@ -1820,10 +1981,12 @@ def add_to_dict_v1(i: Input, d: dict):
|
|||||||
|
|
||||||
class DynamicPathsDefaultValue:
|
class DynamicPathsDefaultValue:
|
||||||
EMPTY_DICT = "empty_dict"
|
EMPTY_DICT = "empty_dict"
|
||||||
|
EMPTY_LIST = "empty_list"
|
||||||
|
|
||||||
def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||||
paths = v3_data.get("dynamic_paths", None)
|
paths = v3_data.get("dynamic_paths", None)
|
||||||
default_value_dict = v3_data.get("dynamic_paths_default_value", {})
|
default_value_dict = v3_data.get("dynamic_paths_default_value", {})
|
||||||
|
list_paths: set[str] = v3_data.get("list_paths", set()) or set()
|
||||||
if paths is None:
|
if paths is None:
|
||||||
return values
|
return values
|
||||||
values = values.copy()
|
values = values.copy()
|
||||||
@ -1846,6 +2009,8 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
|||||||
default_option = default_value_dict.get(key, None)
|
default_option = default_value_dict.get(key, None)
|
||||||
if default_option == DynamicPathsDefaultValue.EMPTY_DICT:
|
if default_option == DynamicPathsDefaultValue.EMPTY_DICT:
|
||||||
value = {}
|
value = {}
|
||||||
|
elif default_option == DynamicPathsDefaultValue.EMPTY_LIST:
|
||||||
|
value = []
|
||||||
if create_tuple:
|
if create_tuple:
|
||||||
value = (value, key)
|
value = (value, key)
|
||||||
current[p] = value
|
current[p] = value
|
||||||
@ -1853,6 +2018,34 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
|||||||
current = current.setdefault(p, {})
|
current = current.setdefault(p, {})
|
||||||
|
|
||||||
values.update(result)
|
values.update(result)
|
||||||
|
|
||||||
|
# Post-pass: convert index-keyed dicts to sorted lists for io.DynamicGroup fields
|
||||||
|
for list_path in list_paths:
|
||||||
|
parts = list_path.split(".")
|
||||||
|
# Navigate to the parent container, then convert the leaf
|
||||||
|
container = values
|
||||||
|
for part in parts[:-1]:
|
||||||
|
if not isinstance(container, dict) or part not in container:
|
||||||
|
container = None
|
||||||
|
break
|
||||||
|
container = container[part]
|
||||||
|
if container is None:
|
||||||
|
continue
|
||||||
|
leaf_key = parts[-1]
|
||||||
|
leaf = container.get(leaf_key, None)
|
||||||
|
if isinstance(leaf, dict):
|
||||||
|
try:
|
||||||
|
sorted_rows = [leaf[k] for k in sorted(leaf.keys(), key=int)]
|
||||||
|
container[leaf_key] = sorted_rows
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
# Keys are not all integers; leave as-is
|
||||||
|
pass
|
||||||
|
elif isinstance(leaf, list):
|
||||||
|
# Already a list (e.g. the EMPTY_LIST default was applied above)
|
||||||
|
pass
|
||||||
|
elif leaf is None:
|
||||||
|
container[leaf_key] = []
|
||||||
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
@ -2417,7 +2610,9 @@ __all__ = [
|
|||||||
# Dynamic Types
|
# Dynamic Types
|
||||||
"MatchType",
|
"MatchType",
|
||||||
"DynamicCombo",
|
"DynamicCombo",
|
||||||
|
"DynamicSlot",
|
||||||
"Autogrow",
|
"Autogrow",
|
||||||
|
"DynamicGroup",
|
||||||
# Other classes
|
# Other classes
|
||||||
"HiddenHolder",
|
"HiddenHolder",
|
||||||
"Hidden",
|
"Hidden",
|
||||||
|
|||||||
@ -1,85 +1,68 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import re
|
import re
|
||||||
|
import ctypes
|
||||||
import logging
|
import logging
|
||||||
import ctypes.util
|
|
||||||
import importlib.util
|
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import nodes
|
import nodes
|
||||||
|
import comfy_angle
|
||||||
from comfy_api.latest import ComfyExtension, io, ui
|
from comfy_api.latest import ComfyExtension, io, ui
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from utils.install_util import get_missing_requirements_message
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _check_opengl_availability():
|
def _preload_angle():
|
||||||
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
|
egl_path = comfy_angle.get_egl_path()
|
||||||
logger.debug("_check_opengl_availability: starting")
|
gles_path = comfy_angle.get_glesv2_path()
|
||||||
missing = []
|
|
||||||
|
|
||||||
# Check Python packages (using find_spec to avoid importing)
|
if sys.platform == "win32":
|
||||||
logger.debug("_check_opengl_availability: checking for glfw package")
|
angle_dir = comfy_angle.get_lib_dir()
|
||||||
if importlib.util.find_spec("glfw") is None:
|
os.add_dll_directory(angle_dir)
|
||||||
missing.append("glfw")
|
os.environ["PATH"] = angle_dir + os.pathsep + os.environ.get("PATH", "")
|
||||||
|
|
||||||
logger.debug("_check_opengl_availability: checking for OpenGL package")
|
mode = 0 if sys.platform == "win32" else ctypes.RTLD_GLOBAL
|
||||||
if importlib.util.find_spec("OpenGL") is None:
|
ctypes.CDLL(str(egl_path), mode=mode)
|
||||||
missing.append("PyOpenGL")
|
ctypes.CDLL(str(gles_path), mode=mode)
|
||||||
|
|
||||||
if missing:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
# On Linux without display, check if headless backends are available
|
|
||||||
logger.debug(f"_check_opengl_availability: platform={sys.platform}")
|
|
||||||
if sys.platform.startswith("linux"):
|
|
||||||
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
|
|
||||||
logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
|
|
||||||
if not has_display:
|
|
||||||
# Check for EGL or OSMesa libraries
|
|
||||||
logger.debug("_check_opengl_availability: checking for EGL library")
|
|
||||||
has_egl = ctypes.util.find_library("EGL")
|
|
||||||
logger.debug("_check_opengl_availability: checking for OSMesa library")
|
|
||||||
has_osmesa = ctypes.util.find_library("OSMesa")
|
|
||||||
|
|
||||||
# Error disabled for CI as it fails this check
|
|
||||||
# if not has_egl and not has_osmesa:
|
|
||||||
# raise RuntimeError(
|
|
||||||
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
|
|
||||||
# "See error below for installation instructions."
|
|
||||||
# )
|
|
||||||
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
|
|
||||||
|
|
||||||
logger.debug("_check_opengl_availability: completed")
|
|
||||||
|
|
||||||
|
|
||||||
# Run early check at import time
|
# Pre-load ANGLE *before* any PyOpenGL import so that the EGL platform
|
||||||
logger.debug("nodes_glsl: running _check_opengl_availability at import time")
|
# plugin picks up ANGLE's libEGL / libGLESv2 instead of system libs.
|
||||||
_check_opengl_availability()
|
_preload_angle()
|
||||||
|
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
|
||||||
# OpenGL modules - initialized lazily when context is created
|
|
||||||
gl = None
|
|
||||||
glfw = None
|
|
||||||
EGL = None
|
|
||||||
|
|
||||||
|
|
||||||
def _import_opengl():
|
import OpenGL
|
||||||
"""Import OpenGL module. Called after context is created."""
|
OpenGL.USE_ACCELERATE = False
|
||||||
global gl
|
|
||||||
if gl is None:
|
|
||||||
logger.debug("_import_opengl: importing OpenGL.GL")
|
|
||||||
import OpenGL.GL as _gl
|
|
||||||
gl = _gl
|
|
||||||
logger.debug("_import_opengl: import completed")
|
|
||||||
return gl
|
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_find_library():
|
||||||
|
"""PyOpenGL's EGL platform looks for 'EGL' and 'GLESv2' by short name
|
||||||
|
via ctypes.util.find_library, but ANGLE ships as 'libEGL' and
|
||||||
|
'libGLESv2'. Patch find_library to return the full ANGLE paths so
|
||||||
|
PyOpenGL loads the same libraries we pre-loaded."""
|
||||||
|
if sys.platform == "linux":
|
||||||
|
return
|
||||||
|
import ctypes.util
|
||||||
|
_orig = ctypes.util.find_library
|
||||||
|
def _patched(name):
|
||||||
|
if name == 'EGL':
|
||||||
|
return comfy_angle.get_egl_path()
|
||||||
|
if name == 'GLESv2':
|
||||||
|
return comfy_angle.get_glesv2_path()
|
||||||
|
return _orig(name)
|
||||||
|
ctypes.util.find_library = _patched
|
||||||
|
|
||||||
|
|
||||||
|
_patch_find_library()
|
||||||
|
|
||||||
|
from OpenGL import EGL
|
||||||
|
from OpenGL import GLES3 as gl
|
||||||
|
|
||||||
class SizeModeInput(TypedDict):
|
class SizeModeInput(TypedDict):
|
||||||
size_mode: str
|
size_mode: str
|
||||||
width: int
|
width: int
|
||||||
@ -102,7 +85,7 @@ MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
|||||||
# (-1,-1)---(3,-1)
|
# (-1,-1)---(3,-1)
|
||||||
#
|
#
|
||||||
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
|
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
|
||||||
VERTEX_SHADER = """#version 330 core
|
VERTEX_SHADER = """#version 300 es
|
||||||
out vec2 v_texCoord;
|
out vec2 v_texCoord;
|
||||||
void main() {
|
void main() {
|
||||||
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
|
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
|
||||||
@ -126,14 +109,99 @@ void main() {
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _convert_es_to_desktop(source: str) -> str:
|
|
||||||
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
|
def _egl_attribs(*values):
|
||||||
# Remove any existing #version directive
|
"""Build an EGL_NONE-terminated EGLint attribute array."""
|
||||||
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
|
vals = list(values) + [EGL.EGL_NONE]
|
||||||
# Remove precision qualifiers (not needed in desktop GLSL)
|
return (ctypes.c_int32 * len(vals))(*vals)
|
||||||
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
|
|
||||||
# Prepend desktop GLSL version
|
|
||||||
return "#version 330 core\n" + source
|
# EGL platform extension constants
|
||||||
|
EGL_PLATFORM_ANGLE_ANGLE = 0x3202
|
||||||
|
EGL_PLATFORM_ANGLE_TYPE_ANGLE = 0x3203
|
||||||
|
EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE = 0x3450
|
||||||
|
EGL_MESA_PLATFORM_SURFACELESS = 0x31DD
|
||||||
|
|
||||||
|
|
||||||
|
_eglGetPlatformDisplayEXT = None
|
||||||
|
|
||||||
|
def _get_egl_platform_display_ext(platform, native_display, attribs):
|
||||||
|
"""Call eglGetPlatformDisplayEXT via ctypes (extension, not in PyOpenGL)."""
|
||||||
|
global _eglGetPlatformDisplayEXT
|
||||||
|
if _eglGetPlatformDisplayEXT is None:
|
||||||
|
from OpenGL import platform as _plat
|
||||||
|
egl_lib = _plat.PLATFORM.EGL
|
||||||
|
_get_proc = egl_lib.eglGetProcAddress
|
||||||
|
_get_proc.restype = ctypes.c_void_p
|
||||||
|
_get_proc.argtypes = [ctypes.c_char_p]
|
||||||
|
ptr = _get_proc(b"eglGetPlatformDisplayEXT")
|
||||||
|
if not ptr:
|
||||||
|
return None
|
||||||
|
func_type = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_uint32, ctypes.c_void_p, ctypes.c_void_p)
|
||||||
|
_eglGetPlatformDisplayEXT = func_type(ptr)
|
||||||
|
|
||||||
|
raw = _eglGetPlatformDisplayEXT(platform, native_display, attribs)
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
return ctypes.cast(raw, EGL.EGLDisplay)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_egl_display():
|
||||||
|
"""Get an EGL display, trying the default first then ANGLE's Vulkan
|
||||||
|
platform for headless environments without a display server."""
|
||||||
|
failures = []
|
||||||
|
|
||||||
|
# Try the default display first (works when X11/Wayland is available)
|
||||||
|
display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY)
|
||||||
|
if display:
|
||||||
|
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
|
||||||
|
try:
|
||||||
|
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
|
||||||
|
return display, major.value, minor.value
|
||||||
|
except Exception as e:
|
||||||
|
failures.append(f"default: {e}")
|
||||||
|
|
||||||
|
logger.info("Default EGL display unavailable, trying headless fallbacks")
|
||||||
|
|
||||||
|
# Headless fallback strategies, tried in order:
|
||||||
|
headless_strategies = [
|
||||||
|
("surfaceless", EGL_MESA_PLATFORM_SURFACELESS, None, None),
|
||||||
|
("ANGLE Vulkan", EGL_PLATFORM_ANGLE_ANGLE, None,
|
||||||
|
_egl_attribs(EGL_PLATFORM_ANGLE_TYPE_ANGLE, EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE)),
|
||||||
|
]
|
||||||
|
|
||||||
|
for name, platform, native_display, attribs in headless_strategies:
|
||||||
|
display = _get_egl_platform_display_ext(platform, native_display, attribs)
|
||||||
|
if not display:
|
||||||
|
failures.append(f"{name}: eglGetPlatformDisplayEXT returned no display")
|
||||||
|
continue
|
||||||
|
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
|
||||||
|
try:
|
||||||
|
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
|
||||||
|
logger.info(f"Using EGL {name} platform (headless)")
|
||||||
|
return display, major.value, minor.value
|
||||||
|
failures.append(f"{name}: eglInitialize returned false")
|
||||||
|
except Exception as e:
|
||||||
|
failures.append(f"{name}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
details = "\n".join(f" - {f}" for f in failures)
|
||||||
|
raise RuntimeError(
|
||||||
|
"Failed to initialize EGL display.\n"
|
||||||
|
"No display server and no headless EGL platform available.\n"
|
||||||
|
f"Tried:\n{details}\n"
|
||||||
|
"Ensure GPU drivers are installed or set DISPLAY for a virtual framebuffer."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _gl_str(name):
|
||||||
|
"""Get an OpenGL string parameter."""
|
||||||
|
v = gl.glGetString(name)
|
||||||
|
if not v:
|
||||||
|
return "Unknown"
|
||||||
|
if isinstance(v, bytes):
|
||||||
|
return v.decode(errors="replace")
|
||||||
|
return ctypes.string_at(v).decode(errors="replace")
|
||||||
|
|
||||||
|
|
||||||
def _detect_output_count(source: str) -> int:
|
def _detect_output_count(source: str) -> int:
|
||||||
@ -159,163 +227,8 @@ def _detect_pass_count(source: str) -> int:
|
|||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
|
||||||
def _init_glfw():
|
|
||||||
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
|
|
||||||
logger.debug("_init_glfw: starting")
|
|
||||||
# On macOS, glfw.init() must be called from main thread or it hangs forever
|
|
||||||
if sys.platform == "darwin":
|
|
||||||
logger.debug("_init_glfw: skipping on macOS")
|
|
||||||
raise RuntimeError("GLFW backend not supported on macOS")
|
|
||||||
|
|
||||||
logger.debug("_init_glfw: importing glfw module")
|
|
||||||
import glfw as _glfw
|
|
||||||
|
|
||||||
logger.debug("_init_glfw: calling glfw.init()")
|
|
||||||
if not _glfw.init():
|
|
||||||
raise RuntimeError("glfw.init() failed")
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.debug("_init_glfw: setting window hints")
|
|
||||||
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
|
|
||||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
|
|
||||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
|
|
||||||
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
|
|
||||||
|
|
||||||
logger.debug("_init_glfw: calling create_window()")
|
|
||||||
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
|
|
||||||
if not window:
|
|
||||||
raise RuntimeError("glfw.create_window() failed")
|
|
||||||
|
|
||||||
logger.debug("_init_glfw: calling make_context_current()")
|
|
||||||
_glfw.make_context_current(window)
|
|
||||||
logger.debug("_init_glfw: completed successfully")
|
|
||||||
return window, _glfw
|
|
||||||
except Exception:
|
|
||||||
logger.debug("_init_glfw: failed, terminating glfw")
|
|
||||||
_glfw.terminate()
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
def _init_egl():
|
|
||||||
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
|
|
||||||
logger.debug("_init_egl: starting")
|
|
||||||
from OpenGL import EGL as _EGL
|
|
||||||
from OpenGL.EGL import (
|
|
||||||
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
|
|
||||||
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
|
|
||||||
eglTerminate, eglDestroyContext, eglDestroySurface,
|
|
||||||
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
|
|
||||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
|
||||||
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
|
|
||||||
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
|
|
||||||
)
|
|
||||||
logger.debug("_init_egl: imports completed")
|
|
||||||
|
|
||||||
display = None
|
|
||||||
context = None
|
|
||||||
surface = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.debug("_init_egl: calling eglGetDisplay()")
|
|
||||||
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
|
|
||||||
if display == _EGL.EGL_NO_DISPLAY:
|
|
||||||
raise RuntimeError("eglGetDisplay() failed")
|
|
||||||
|
|
||||||
logger.debug("_init_egl: calling eglInitialize()")
|
|
||||||
major, minor = _EGL.EGLint(), _EGL.EGLint()
|
|
||||||
if not eglInitialize(display, major, minor):
|
|
||||||
display = None # Not initialized, don't terminate
|
|
||||||
raise RuntimeError("eglInitialize() failed")
|
|
||||||
logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
|
|
||||||
|
|
||||||
config_attribs = [
|
|
||||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
|
|
||||||
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
|
||||||
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
|
|
||||||
EGL_DEPTH_SIZE, 0, EGL_NONE
|
|
||||||
]
|
|
||||||
configs = (_EGL.EGLConfig * 1)()
|
|
||||||
num_configs = _EGL.EGLint()
|
|
||||||
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
|
|
||||||
raise RuntimeError("eglChooseConfig() failed")
|
|
||||||
config = configs[0]
|
|
||||||
logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
|
|
||||||
|
|
||||||
if not eglBindAPI(EGL_OPENGL_API):
|
|
||||||
raise RuntimeError("eglBindAPI() failed")
|
|
||||||
|
|
||||||
logger.debug("_init_egl: calling eglCreateContext()")
|
|
||||||
context_attribs = [
|
|
||||||
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
|
|
||||||
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
|
|
||||||
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
|
|
||||||
EGL_NONE
|
|
||||||
]
|
|
||||||
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
|
|
||||||
if context == EGL_NO_CONTEXT:
|
|
||||||
raise RuntimeError("eglCreateContext() failed")
|
|
||||||
|
|
||||||
logger.debug("_init_egl: calling eglCreatePbufferSurface()")
|
|
||||||
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
|
|
||||||
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
|
|
||||||
if surface == _EGL.EGL_NO_SURFACE:
|
|
||||||
raise RuntimeError("eglCreatePbufferSurface() failed")
|
|
||||||
|
|
||||||
logger.debug("_init_egl: calling eglMakeCurrent()")
|
|
||||||
if not eglMakeCurrent(display, surface, surface, context):
|
|
||||||
raise RuntimeError("eglMakeCurrent() failed")
|
|
||||||
|
|
||||||
logger.debug("_init_egl: completed successfully")
|
|
||||||
return display, context, surface, _EGL
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
logger.debug("_init_egl: failed, cleaning up")
|
|
||||||
# Clean up any resources on failure
|
|
||||||
if surface is not None:
|
|
||||||
eglDestroySurface(display, surface)
|
|
||||||
if context is not None:
|
|
||||||
eglDestroyContext(display, context)
|
|
||||||
if display is not None:
|
|
||||||
eglTerminate(display)
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
def _init_osmesa():
|
|
||||||
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
|
|
||||||
import ctypes
|
|
||||||
|
|
||||||
logger.debug("_init_osmesa: starting")
|
|
||||||
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
|
|
||||||
|
|
||||||
logger.debug("_init_osmesa: importing OpenGL.osmesa")
|
|
||||||
from OpenGL import GL as _gl
|
|
||||||
from OpenGL.osmesa import (
|
|
||||||
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
|
|
||||||
OSMESA_RGBA,
|
|
||||||
)
|
|
||||||
logger.debug("_init_osmesa: imports completed")
|
|
||||||
|
|
||||||
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
|
|
||||||
if not ctx:
|
|
||||||
raise RuntimeError("OSMesaCreateContextExt() failed")
|
|
||||||
|
|
||||||
width, height = 64, 64
|
|
||||||
buffer = (ctypes.c_ubyte * (width * height * 4))()
|
|
||||||
|
|
||||||
logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
|
|
||||||
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
|
|
||||||
OSMesaDestroyContext(ctx)
|
|
||||||
raise RuntimeError("OSMesaMakeCurrent() failed")
|
|
||||||
|
|
||||||
logger.debug("_init_osmesa: completed successfully")
|
|
||||||
return ctx, buffer
|
|
||||||
|
|
||||||
|
|
||||||
class GLContext:
|
class GLContext:
|
||||||
"""Manages OpenGL context and resources for shader execution.
|
"""Manages an OpenGL ES 3.0 context via EGL/ANGLE (singleton)."""
|
||||||
|
|
||||||
Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software).
|
|
||||||
"""
|
|
||||||
|
|
||||||
_instance = None
|
_instance = None
|
||||||
_initialized = False
|
_initialized = False
|
||||||
@ -327,131 +240,105 @@ class GLContext:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
if GLContext._initialized:
|
if GLContext._initialized:
|
||||||
logger.debug("GLContext.__init__: already initialized, skipping")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.debug("GLContext.__init__: starting initialization")
|
|
||||||
|
|
||||||
global glfw, EGL
|
|
||||||
|
|
||||||
import time
|
import time
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
|
|
||||||
self._backend = None
|
self._display = None
|
||||||
self._window = None
|
self._surface = None
|
||||||
self._egl_display = None
|
self._context = None
|
||||||
self._egl_context = None
|
|
||||||
self._egl_surface = None
|
|
||||||
self._osmesa_ctx = None
|
|
||||||
self._osmesa_buffer = None
|
|
||||||
self._vao = None
|
self._vao = None
|
||||||
|
|
||||||
# Try backends in order: GLFW → EGL → OSMesa
|
|
||||||
errors = []
|
|
||||||
|
|
||||||
logger.debug("GLContext.__init__: trying GLFW backend")
|
|
||||||
try:
|
try:
|
||||||
self._window, glfw = _init_glfw()
|
self._display, self._egl_major, self._egl_minor = _get_egl_display()
|
||||||
self._backend = "glfw"
|
|
||||||
logger.debug("GLContext.__init__: GLFW backend succeeded")
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
|
|
||||||
errors.append(("GLFW", e))
|
|
||||||
|
|
||||||
if self._backend is None:
|
if not EGL.eglBindAPI(EGL.EGL_OPENGL_ES_API):
|
||||||
logger.debug("GLContext.__init__: trying EGL backend")
|
raise RuntimeError("eglBindAPI(EGL_OPENGL_ES_API) failed")
|
||||||
try:
|
|
||||||
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
|
|
||||||
self._backend = "egl"
|
|
||||||
logger.debug("GLContext.__init__: EGL backend succeeded")
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
|
|
||||||
errors.append(("EGL", e))
|
|
||||||
|
|
||||||
if self._backend is None:
|
config = EGL.EGLConfig()
|
||||||
logger.debug("GLContext.__init__: trying OSMesa backend")
|
n_configs = ctypes.c_int32(0)
|
||||||
try:
|
if not EGL.eglChooseConfig(
|
||||||
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
|
self._display,
|
||||||
self._backend = "osmesa"
|
_egl_attribs(
|
||||||
logger.debug("GLContext.__init__: OSMesa backend succeeded")
|
EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_ES3_BIT,
|
||||||
except Exception as e:
|
EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT,
|
||||||
logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
|
EGL.EGL_RED_SIZE, 8, EGL.EGL_GREEN_SIZE, 8,
|
||||||
errors.append(("OSMesa", e))
|
EGL.EGL_BLUE_SIZE, 8, EGL.EGL_ALPHA_SIZE, 8,
|
||||||
|
),
|
||||||
|
ctypes.byref(config), 1, ctypes.byref(n_configs),
|
||||||
|
) or n_configs.value == 0:
|
||||||
|
raise RuntimeError("eglChooseConfig() failed")
|
||||||
|
|
||||||
if self._backend is None:
|
self._surface = EGL.eglCreatePbufferSurface(
|
||||||
if sys.platform == "win32":
|
self._display, config,
|
||||||
platform_help = (
|
_egl_attribs(EGL.EGL_WIDTH, 64, EGL.EGL_HEIGHT, 64),
|
||||||
"Windows: Ensure GPU drivers are installed and display is available.\n"
|
|
||||||
" CPU-only/headless mode is not supported on Windows."
|
|
||||||
)
|
|
||||||
elif sys.platform == "darwin":
|
|
||||||
platform_help = (
|
|
||||||
"macOS: GLFW is not supported.\n"
|
|
||||||
" Install OSMesa via Homebrew: brew install mesa\n"
|
|
||||||
" Then: pip install PyOpenGL PyOpenGL-accelerate"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
platform_help = (
|
|
||||||
"Linux: Install one of these backends:\n"
|
|
||||||
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
|
|
||||||
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
|
|
||||||
" Headless (CPU): sudo apt install libosmesa6"
|
|
||||||
)
|
|
||||||
|
|
||||||
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Failed to create OpenGL context.\n\n"
|
|
||||||
f"Backend errors:\n{error_details}\n\n"
|
|
||||||
f"{platform_help}"
|
|
||||||
)
|
)
|
||||||
|
if not self._surface:
|
||||||
|
raise RuntimeError("eglCreatePbufferSurface() failed")
|
||||||
|
|
||||||
# Now import OpenGL.GL (after context is current)
|
self._context = EGL.eglCreateContext(
|
||||||
logger.debug("GLContext.__init__: importing OpenGL.GL")
|
self._display, config, EGL.EGL_NO_CONTEXT,
|
||||||
_import_opengl()
|
_egl_attribs(EGL.EGL_CONTEXT_CLIENT_VERSION, 3),
|
||||||
|
)
|
||||||
|
if not self._context:
|
||||||
|
raise RuntimeError("eglCreateContext() failed")
|
||||||
|
|
||||||
# Create VAO (required for core profile, but OSMesa may use compat profile)
|
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
|
||||||
logger.debug("GLContext.__init__: creating VAO")
|
raise RuntimeError("eglMakeCurrent() failed")
|
||||||
try:
|
|
||||||
vao = gl.glGenVertexArrays(1)
|
self._vao = gl.glGenVertexArrays(1)
|
||||||
gl.glBindVertexArray(vao)
|
gl.glBindVertexArray(self._vao)
|
||||||
self._vao = vao # Only store after successful bind
|
|
||||||
logger.debug("GLContext.__init__: VAO created successfully")
|
except Exception:
|
||||||
except Exception as e:
|
self._cleanup()
|
||||||
logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
|
raise
|
||||||
# OSMesa with older Mesa may not support VAOs
|
|
||||||
# Clean up if we created but couldn't bind
|
|
||||||
if vao:
|
|
||||||
try:
|
|
||||||
gl.glDeleteVertexArrays(1, [vao])
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
elapsed = (time.perf_counter() - start) * 1000
|
elapsed = (time.perf_counter() - start) * 1000
|
||||||
|
|
||||||
# Log device info
|
renderer = _gl_str(gl.GL_RENDERER)
|
||||||
renderer = gl.glGetString(gl.GL_RENDERER)
|
vendor = _gl_str(gl.GL_VENDOR)
|
||||||
vendor = gl.glGetString(gl.GL_VENDOR)
|
version = _gl_str(gl.GL_VERSION)
|
||||||
version = gl.glGetString(gl.GL_VERSION)
|
|
||||||
renderer = renderer.decode() if renderer else "Unknown"
|
|
||||||
vendor = vendor.decode() if vendor else "Unknown"
|
|
||||||
version = version.decode() if version else "Unknown"
|
|
||||||
|
|
||||||
GLContext._initialized = True
|
GLContext._initialized = True
|
||||||
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
|
logger.info(f"GLSL context initialized in {elapsed:.1f}ms - EGL {self._egl_major}.{self._egl_minor}, {renderer} ({vendor}), GL {version}")
|
||||||
|
|
||||||
def make_current(self):
|
def make_current(self):
|
||||||
if self._backend == "glfw":
|
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
|
||||||
glfw.make_context_current(self._window)
|
err = EGL.eglGetError()
|
||||||
elif self._backend == "egl":
|
raise RuntimeError(f"eglMakeCurrent() failed (EGL error: 0x{err:04X})")
|
||||||
from OpenGL.EGL import eglMakeCurrent
|
|
||||||
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
|
|
||||||
elif self._backend == "osmesa":
|
|
||||||
from OpenGL.osmesa import OSMesaMakeCurrent
|
|
||||||
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
|
|
||||||
|
|
||||||
if self._vao is not None:
|
if self._vao is not None:
|
||||||
gl.glBindVertexArray(self._vao)
|
gl.glBindVertexArray(self._vao)
|
||||||
|
|
||||||
|
def _cleanup(self):
|
||||||
|
if not self._display:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
if self._vao is not None:
|
||||||
|
gl.glDeleteVertexArrays(1, [self._vao])
|
||||||
|
self._vao = None
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
EGL.eglMakeCurrent(self._display, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
if self._context:
|
||||||
|
EGL.eglDestroyContext(self._display, self._context)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
if self._surface:
|
||||||
|
EGL.eglDestroySurface(self._display, self._surface)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
EGL.eglTerminate(self._display)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._display = None
|
||||||
|
|
||||||
|
|
||||||
def _compile_shader(source: str, shader_type: int) -> int:
|
def _compile_shader(source: str, shader_type: int) -> int:
|
||||||
"""Compile a shader and return its ID."""
|
"""Compile a shader and return its ID."""
|
||||||
@ -459,8 +346,10 @@ def _compile_shader(source: str, shader_type: int) -> int:
|
|||||||
gl.glShaderSource(shader, source)
|
gl.glShaderSource(shader, source)
|
||||||
gl.glCompileShader(shader)
|
gl.glCompileShader(shader)
|
||||||
|
|
||||||
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
|
if not gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS):
|
||||||
error = gl.glGetShaderInfoLog(shader).decode()
|
error = gl.glGetShaderInfoLog(shader)
|
||||||
|
if isinstance(error, bytes):
|
||||||
|
error = error.decode(errors="replace")
|
||||||
gl.glDeleteShader(shader)
|
gl.glDeleteShader(shader)
|
||||||
raise RuntimeError(f"Shader compilation failed:\n{error}")
|
raise RuntimeError(f"Shader compilation failed:\n{error}")
|
||||||
|
|
||||||
@ -484,8 +373,10 @@ def _create_program(vertex_source: str, fragment_source: str) -> int:
|
|||||||
gl.glDeleteShader(vertex_shader)
|
gl.glDeleteShader(vertex_shader)
|
||||||
gl.glDeleteShader(fragment_shader)
|
gl.glDeleteShader(fragment_shader)
|
||||||
|
|
||||||
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
|
if not gl.glGetProgramiv(program, gl.GL_LINK_STATUS):
|
||||||
error = gl.glGetProgramInfoLog(program).decode()
|
error = gl.glGetProgramInfoLog(program)
|
||||||
|
if isinstance(error, bytes):
|
||||||
|
error = error.decode(errors="replace")
|
||||||
gl.glDeleteProgram(program)
|
gl.glDeleteProgram(program)
|
||||||
raise RuntimeError(f"Program linking failed:\n{error}")
|
raise RuntimeError(f"Program linking failed:\n{error}")
|
||||||
|
|
||||||
@ -530,9 +421,6 @@ def _render_shader_batch(
|
|||||||
ctx = GLContext()
|
ctx = GLContext()
|
||||||
ctx.make_current()
|
ctx.make_current()
|
||||||
|
|
||||||
# Convert from GLSL ES to desktop GLSL 330
|
|
||||||
fragment_source = _convert_es_to_desktop(fragment_code)
|
|
||||||
|
|
||||||
# Detect how many outputs the shader actually uses
|
# Detect how many outputs the shader actually uses
|
||||||
num_outputs = _detect_output_count(fragment_code)
|
num_outputs = _detect_output_count(fragment_code)
|
||||||
|
|
||||||
@ -558,9 +446,9 @@ def _render_shader_batch(
|
|||||||
try:
|
try:
|
||||||
# Compile shaders (once for all batches)
|
# Compile shaders (once for all batches)
|
||||||
try:
|
try:
|
||||||
program = _create_program(VERTEX_SHADER, fragment_source)
|
program = _create_program(VERTEX_SHADER, fragment_code)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
logger.error(f"Fragment shader:\n{fragment_source}")
|
logger.error(f"Fragment shader:\n{fragment_code}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
gl.glUseProgram(program)
|
gl.glUseProgram(program)
|
||||||
@ -723,13 +611,13 @@ def _render_shader_batch(
|
|||||||
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
|
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
|
||||||
|
|
||||||
# Read back outputs for this batch
|
# Read back outputs for this batch
|
||||||
# (glGetTexImage is synchronous, implicitly waits for rendering)
|
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
|
||||||
batch_outputs = []
|
batch_outputs = []
|
||||||
for tex in output_textures:
|
for i in range(num_outputs):
|
||||||
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
gl.glReadBuffer(gl.GL_COLOR_ATTACHMENT0 + i)
|
||||||
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
|
buf = np.empty((height, width, 4), dtype=np.float32)
|
||||||
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
|
gl.glReadPixels(0, 0, width, height, gl.GL_RGBA, gl.GL_FLOAT, buf)
|
||||||
batch_outputs.append(img[::-1, :, :].copy())
|
batch_outputs.append(buf[::-1, :, :].copy())
|
||||||
|
|
||||||
# Pad with black images for unused outputs
|
# Pad with black images for unused outputs
|
||||||
black_img = np.zeros((height, width, 4), dtype=np.float32)
|
black_img = np.zeros((height, width, 4), dtype=np.float32)
|
||||||
@ -750,18 +638,18 @@ def _render_shader_batch(
|
|||||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
|
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
|
||||||
gl.glUseProgram(0)
|
gl.glUseProgram(0)
|
||||||
|
|
||||||
for tex in input_textures:
|
if input_textures:
|
||||||
gl.glDeleteTextures(int(tex))
|
gl.glDeleteTextures(len(input_textures), input_textures)
|
||||||
for tex in curve_textures:
|
if curve_textures:
|
||||||
gl.glDeleteTextures(int(tex))
|
gl.glDeleteTextures(len(curve_textures), curve_textures)
|
||||||
for tex in output_textures:
|
if output_textures:
|
||||||
gl.glDeleteTextures(int(tex))
|
gl.glDeleteTextures(len(output_textures), output_textures)
|
||||||
for tex in ping_pong_textures:
|
if ping_pong_textures:
|
||||||
gl.glDeleteTextures(int(tex))
|
gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures)
|
||||||
if fbo is not None:
|
if fbo is not None:
|
||||||
gl.glDeleteFramebuffers(1, [fbo])
|
gl.glDeleteFramebuffers(1, [fbo])
|
||||||
for pp_fbo in ping_pong_fbos:
|
if ping_pong_fbos:
|
||||||
gl.glDeleteFramebuffers(1, [pp_fbo])
|
gl.glDeleteFramebuffers(len(ping_pong_fbos), ping_pong_fbos)
|
||||||
if program is not None:
|
if program is not None:
|
||||||
gl.glDeleteProgram(program)
|
gl.glDeleteProgram(program)
|
||||||
|
|
||||||
|
|||||||
130
comfy_extras/nodes_lora_stack.py
Normal file
130
comfy_extras/nodes_lora_stack.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
"""LoRA stacking loaders built on io.DynamicGroup.
|
||||||
|
|
||||||
|
Two nodes that let you stack any number of LoRAs in a single node, each row
|
||||||
|
carrying only a LoRA name and a strength:
|
||||||
|
|
||||||
|
LoadLoraModel
|
||||||
|
Applies a stack of LoRAs to a diffusion MODEL.
|
||||||
|
|
||||||
|
LoadLoraTextEncoder
|
||||||
|
Applies a stack of LoRAs to a CLIP text encoder.
|
||||||
|
|
||||||
|
Both are modelled on DynamicGroupLoraStyleTest in nodes_dynamic_group_test.py,
|
||||||
|
but operate on real models and real LoRA files.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
import comfy.sd
|
||||||
|
import comfy.utils
|
||||||
|
import folder_paths
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
# Module-level cache so repeated executions don't re-read the same file from disk.
|
||||||
|
_LORA_CACHE: dict[str, tuple] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _load_lora_file(lora_name: str):
|
||||||
|
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
|
||||||
|
cached = _LORA_CACHE.get(lora_path)
|
||||||
|
if cached is not None:
|
||||||
|
return cached
|
||||||
|
lora, metadata = comfy.utils.load_torch_file(lora_path, safe_load=True, return_metadata=True)
|
||||||
|
_LORA_CACHE[lora_path] = (lora, metadata)
|
||||||
|
return lora, metadata
|
||||||
|
|
||||||
|
|
||||||
|
def _lora_template() -> list[io.Input]:
|
||||||
|
return [
|
||||||
|
io.Combo.Input("lora_name", options=folder_paths.get_filename_list("loras"),
|
||||||
|
tooltip="The name of the LoRA file to apply."),
|
||||||
|
io.Float.Input("strength", default=1.0, min=-100.0, max=100.0, step=0.01,
|
||||||
|
tooltip="How strongly to apply this LoRA. 0 = off, negative inverts the effect."),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class LoadLoraModel(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LoadLoraModel",
|
||||||
|
display_name="Load LoRA (Model)",
|
||||||
|
search_aliases=["lora", "load lora", "apply lora", "lora model", "lora stack"],
|
||||||
|
category="model/loaders",
|
||||||
|
description="Apply a stack of LoRAs to a diffusion model. Add one row per LoRA; "
|
||||||
|
"each row picks a LoRA file and its strength.",
|
||||||
|
inputs=[
|
||||||
|
io.Model.Input("model", tooltip="The diffusion model the LoRAs will be applied to."),
|
||||||
|
io.DynamicGroup.Input(
|
||||||
|
"loras",
|
||||||
|
template=_lora_template(),
|
||||||
|
min=1,
|
||||||
|
max=50,
|
||||||
|
tooltip="Each row applies one LoRA to the model.",
|
||||||
|
group_name="LoRA",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[io.Model.Output(tooltip="The modified diffusion model.")],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, model, loras: list[dict]) -> io.NodeOutput:
|
||||||
|
for row in loras:
|
||||||
|
lora_name = row.get("lora_name")
|
||||||
|
strength = row.get("strength", 1.0)
|
||||||
|
if not lora_name or lora_name == "none" or strength == 0:
|
||||||
|
continue
|
||||||
|
lora, metadata = _load_lora_file(lora_name)
|
||||||
|
model, _ = comfy.sd.load_lora_for_models(model, None, lora, strength, 0, lora_metadata=metadata)
|
||||||
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
|
|
||||||
|
class LoadLoraTextEncoder(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LoadLoraTextEncoder",
|
||||||
|
display_name="Load LoRA (Text Encoder)",
|
||||||
|
search_aliases=["lora", "load lora", "apply lora", "clip lora", "lora stack"],
|
||||||
|
category="model/loaders",
|
||||||
|
description="Apply a stack of LoRAs to a CLIP text encoder. Add one row per LoRA; "
|
||||||
|
"each row picks a LoRA file and its strength.",
|
||||||
|
inputs=[
|
||||||
|
io.Clip.Input("clip", tooltip="The CLIP text encoder the LoRAs will be applied to."),
|
||||||
|
io.DynamicGroup.Input(
|
||||||
|
"loras",
|
||||||
|
template=_lora_template(),
|
||||||
|
min=1,
|
||||||
|
max=50,
|
||||||
|
tooltip="Each row applies one LoRA to the text encoder.",
|
||||||
|
group_name="LoRA",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[io.Clip.Output(tooltip="The modified CLIP text encoder.")],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, clip, loras: list[dict]) -> io.NodeOutput:
|
||||||
|
for row in loras:
|
||||||
|
lora_name = row.get("lora_name")
|
||||||
|
strength = row.get("strength", 1.0)
|
||||||
|
if not lora_name or lora_name == "none" or strength == 0:
|
||||||
|
continue
|
||||||
|
lora, metadata = _load_lora_file(lora_name)
|
||||||
|
_, clip = comfy.sd.load_lora_for_models(None, clip, lora, 0, strength, lora_metadata=metadata)
|
||||||
|
return io.NodeOutput(clip)
|
||||||
|
|
||||||
|
|
||||||
|
class LoraStackExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
LoadLoraModel,
|
||||||
|
LoadLoraTextEncoder,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> LoraStackExtension:
|
||||||
|
return LoraStackExtension()
|
||||||
55
execution.py
55
execution.py
@ -1113,32 +1113,6 @@ def full_type_name(klass):
|
|||||||
return klass.__qualname__
|
return klass.__qualname__
|
||||||
return module + '.' + klass.__qualname__
|
return module + '.' + klass.__qualname__
|
||||||
|
|
||||||
def node_not_executable_reason(class_def, class_type):
|
|
||||||
"""Return a human-readable reason the node cannot be executed, or None if it's fine.
|
|
||||||
|
|
||||||
Catches a node whose declared entry point doesn't resolve to a real method
|
|
||||||
(e.g. a V1 ``FUNCTION = "invert"`` where the method is misspelled, or a V3 node
|
|
||||||
missing its ``execute`` override). Running this during validation surfaces the
|
|
||||||
problem before execution starts, instead of after upstream nodes have run.
|
|
||||||
|
|
||||||
Only the class is inspected; the node is never instantiated here, so a node's
|
|
||||||
``__init__`` side effects cannot run (or fail) during validation.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if issubclass(class_def, _ComfyNodeInternal):
|
|
||||||
# V3: validates that execute()/define_schema() overrides exist.
|
|
||||||
class_def.VALIDATE_CLASS()
|
|
||||||
return None
|
|
||||||
# V1: FUNCTION names the method to call; it must exist on the class.
|
|
||||||
function_name = getattr(class_def, "FUNCTION", None)
|
|
||||||
if function_name is None:
|
|
||||||
return f"'{class_type}' does not define FUNCTION"
|
|
||||||
if not callable(getattr(class_def, function_name, None)):
|
|
||||||
return f"'{class_type}' has no method '{function_name}' (declared in FUNCTION)"
|
|
||||||
return None
|
|
||||||
except Exception as ex:
|
|
||||||
return str(ex)
|
|
||||||
|
|
||||||
async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]):
|
async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]):
|
||||||
outputs = set()
|
outputs = set()
|
||||||
for x in prompt:
|
for x in prompt:
|
||||||
@ -1174,35 +1148,6 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[
|
|||||||
}
|
}
|
||||||
return (False, error, [], {})
|
return (False, error, [], {})
|
||||||
|
|
||||||
# Make sure the node is actually executable (its FUNCTION/execute entry
|
|
||||||
# point resolves to a real method) before we touch any schema-derived
|
|
||||||
# attributes below or start execution. Catches code typos up front and
|
|
||||||
# attributes the error to the offending node.
|
|
||||||
not_executable = node_not_executable_reason(class_, class_type)
|
|
||||||
if not_executable is not None:
|
|
||||||
node_title = prompt[x].get('_meta', {}).get('title', class_type)
|
|
||||||
error = {
|
|
||||||
"type": "invalid_node_definition",
|
|
||||||
"message": "Node is not executable",
|
|
||||||
"details": f"{not_executable} (Node ID '#{x}')",
|
|
||||||
"extra_info": {
|
|
||||||
"node_id": x,
|
|
||||||
"class_type": class_type,
|
|
||||||
"node_title": node_title,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
node_errors = {x: {
|
|
||||||
"errors": [{
|
|
||||||
"type": "invalid_node_definition",
|
|
||||||
"message": "Node is not executable",
|
|
||||||
"details": not_executable,
|
|
||||||
"extra_info": {},
|
|
||||||
}],
|
|
||||||
"dependent_outputs": [],
|
|
||||||
"class_type": class_type,
|
|
||||||
}}
|
|
||||||
return (False, error, [], node_errors)
|
|
||||||
|
|
||||||
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
|
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
|
||||||
if partial_execution_list is None or x in partial_execution_list:
|
if partial_execution_list is None or x in partial_execution_list:
|
||||||
outputs.add(x)
|
outputs.add(x)
|
||||||
|
|||||||
1
nodes.py
1
nodes.py
@ -2476,6 +2476,7 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_triposplat.py",
|
"nodes_triposplat.py",
|
||||||
"nodes_depth_anything_3.py",
|
"nodes_depth_anything_3.py",
|
||||||
"nodes_seed.py",
|
"nodes_seed.py",
|
||||||
|
"nodes_lora_stack.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
|
|||||||
@ -22,7 +22,7 @@ alembic
|
|||||||
SQLAlchemy>=2.0.0
|
SQLAlchemy>=2.0.0
|
||||||
filelock
|
filelock
|
||||||
av>=16.0.0
|
av>=16.0.0
|
||||||
comfy-kitchen==0.2.12
|
comfy-kitchen==0.2.13
|
||||||
comfy-aimdo==0.4.10
|
comfy-aimdo==0.4.10
|
||||||
requests
|
requests
|
||||||
simpleeval>=1.0.0
|
simpleeval>=1.0.0
|
||||||
@ -33,5 +33,5 @@ kornia>=0.7.1
|
|||||||
spandrel
|
spandrel
|
||||||
pydantic~=2.0
|
pydantic~=2.0
|
||||||
pydantic-settings~=2.0
|
pydantic-settings~=2.0
|
||||||
PyOpenGL
|
PyOpenGL>=3.1.8
|
||||||
glfw
|
comfy-angle
|
||||||
|
|||||||
204
tests-unit/comfy_api_test/io_dynamic_group_test.py
Normal file
204
tests-unit/comfy_api_test/io_dynamic_group_test.py
Normal file
@ -0,0 +1,204 @@
|
|||||||
|
"""Unit tests for io.DynamicGroup: expansion/reconstruction (0-row and N-row cases)."""
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Stub torch (type-hint only in _io.py; real torch not available in unit-test env)
|
||||||
|
if "torch" not in sys.modules:
|
||||||
|
_torch_stub = types.ModuleType("torch")
|
||||||
|
_torch_stub.Tensor = object # type: ignore[attr-defined]
|
||||||
|
sys.modules["torch"] = _torch_stub
|
||||||
|
|
||||||
|
from comfy_api.latest._io import ( # noqa: E402
|
||||||
|
DynamicGroup,
|
||||||
|
Float,
|
||||||
|
Int,
|
||||||
|
String,
|
||||||
|
Boolean,
|
||||||
|
get_finalized_class_inputs,
|
||||||
|
build_nested_inputs,
|
||||||
|
create_input_dict_v1,
|
||||||
|
setup_dynamic_input_funcs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make sure dynamic input funcs are registered (may already be done at import time)
|
||||||
|
setup_dynamic_input_funcs()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_class_inputs(group_input: DynamicGroup.Input) -> dict:
|
||||||
|
"""Wrap a DynamicGroup.Input into the required/optional dict structure."""
|
||||||
|
return create_input_dict_v1([group_input])
|
||||||
|
|
||||||
|
|
||||||
|
def _run(group_input: DynamicGroup.Input, live_values: dict) -> dict:
|
||||||
|
"""End-to-end helper: expand schema + reconstruct values.
|
||||||
|
|
||||||
|
Mirrors the production split in execution.py:
|
||||||
|
1. get_finalized_class_inputs (schema expansion, line 162)
|
||||||
|
2. build_nested_inputs (value reconstruction, line 281)
|
||||||
|
|
||||||
|
The two steps are separate in production because the engine resolves
|
||||||
|
linked node outputs between them, but in tests we supply values directly.
|
||||||
|
"""
|
||||||
|
class_inputs = _make_class_inputs(group_input)
|
||||||
|
_, _, v3_data = get_finalized_class_inputs(class_inputs, live_values)
|
||||||
|
return build_nested_inputs(dict(live_values), v3_data)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Schema construction
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestDynamicGroupInputConstruction:
|
||||||
|
def test_basic_construction(self):
|
||||||
|
inp = DynamicGroup.Input(
|
||||||
|
"loras",
|
||||||
|
template=[
|
||||||
|
Float.Input("strength", default=1.0),
|
||||||
|
String.Input("name"),
|
||||||
|
],
|
||||||
|
min=0,
|
||||||
|
max=10,
|
||||||
|
)
|
||||||
|
assert inp.id == "loras"
|
||||||
|
assert inp.min == 0
|
||||||
|
assert inp.max == 10
|
||||||
|
assert len(inp.template) == 2
|
||||||
|
|
||||||
|
def test_get_all_includes_self_and_template(self):
|
||||||
|
inp = DynamicGroup.Input(
|
||||||
|
"items",
|
||||||
|
template=[Float.Input("value")],
|
||||||
|
)
|
||||||
|
all_inputs = inp.get_all()
|
||||||
|
assert all_inputs[0] is inp
|
||||||
|
assert all_inputs[1].id == "value"
|
||||||
|
|
||||||
|
def test_as_dict_has_template_min_max(self):
|
||||||
|
inp = DynamicGroup.Input(
|
||||||
|
"items",
|
||||||
|
template=[Float.Input("val", default=0.5)],
|
||||||
|
min=1,
|
||||||
|
max=5,
|
||||||
|
)
|
||||||
|
d = inp.as_dict()
|
||||||
|
assert "template" in d
|
||||||
|
assert d["min"] == 1
|
||||||
|
assert d["max"] == 5
|
||||||
|
|
||||||
|
def test_duplicate_field_ids_raises(self):
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
DynamicGroup.Input(
|
||||||
|
"bad",
|
||||||
|
template=[Float.Input("x"), Float.Input("x")],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_empty_template_raises(self):
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
DynamicGroup.Input("bad", template=[])
|
||||||
|
|
||||||
|
def test_min_gt_max_raises(self):
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
DynamicGroup.Input("bad", template=[Float.Input("x")], min=5, max=3)
|
||||||
|
|
||||||
|
def test_max_exceeds_limit_raises(self):
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
DynamicGroup.Input("bad", template=[Float.Input("x")], max=101)
|
||||||
|
|
||||||
|
def test_dynamic_input_in_template_raises(self):
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
DynamicGroup.Input(
|
||||||
|
"bad",
|
||||||
|
template=[DynamicGroup.Input("nested", template=[Float.Input("x")])],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_validate_calls_through(self):
|
||||||
|
inp = DynamicGroup.Input("items", template=[Float.Input("val", min=-1.0, max=1.0)])
|
||||||
|
inp.validate() # should not raise
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 0-row case
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestZeroRows:
|
||||||
|
def test_empty_live_inputs_produces_empty_list(self):
|
||||||
|
"""With min=0 and no live values, the result should be an empty list."""
|
||||||
|
inp = DynamicGroup.Input("loras", template=[Float.Input("strength", default=1.0)], min=0, max=10)
|
||||||
|
assert _run(inp, {}).get("loras") == []
|
||||||
|
|
||||||
|
def test_min_zero_with_values(self):
|
||||||
|
"""min=0 but 2 rows of live data."""
|
||||||
|
inp = DynamicGroup.Input("loras", template=[Float.Input("strength", default=1.0)], min=0, max=10)
|
||||||
|
result = _run(inp, {"loras.0.strength": 0.8, "loras.1.strength": 0.5})
|
||||||
|
assert result["loras"] == [{"strength": 0.8}, {"strength": 0.5}]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# N-row case
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestNRows:
|
||||||
|
def test_two_rows_two_fields(self):
|
||||||
|
"""Two rows with two fields each produce a list[dict]."""
|
||||||
|
inp = DynamicGroup.Input(
|
||||||
|
"loras",
|
||||||
|
template=[String.Input("lora_name"), Float.Input("strength", default=1.0)],
|
||||||
|
min=0, max=50,
|
||||||
|
)
|
||||||
|
result = _run(inp, {
|
||||||
|
"loras.0.lora_name": "model_a.safetensors", "loras.0.strength": 0.9,
|
||||||
|
"loras.1.lora_name": "model_b.safetensors", "loras.1.strength": 0.4,
|
||||||
|
})
|
||||||
|
assert result["loras"] == [
|
||||||
|
{"lora_name": "model_a.safetensors", "strength": 0.9},
|
||||||
|
{"lora_name": "model_b.safetensors", "strength": 0.4},
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_rows_are_sorted_by_index(self):
|
||||||
|
"""Rows must be in ascending index order even if dict iteration is unordered."""
|
||||||
|
inp = DynamicGroup.Input("items", template=[Int.Input("v", default=0)], min=0, max=10)
|
||||||
|
result = _run(inp, {"items.0.v": 10, "items.2.v": 30, "items.1.v": 20})
|
||||||
|
assert [row["v"] for row in result["items"]] == [10, 20, 30]
|
||||||
|
|
||||||
|
def test_min_rows_schema_slots(self):
|
||||||
|
"""With min=2 and no live data, 2 slots must appear in the expanded schema."""
|
||||||
|
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
|
||||||
|
out, _, _ = get_finalized_class_inputs(_make_class_inputs(inp), {})
|
||||||
|
all_slots = {**out.get("required", {}), **out.get("optional", {})}
|
||||||
|
assert "items.0.val" in all_slots
|
||||||
|
assert "items.1.val" in all_slots
|
||||||
|
|
||||||
|
def test_min_rows_reconstructs_when_no_values(self):
|
||||||
|
"""min=2 with NO live values must still yield a 2-element list,
|
||||||
|
not collapse to [] (regression: parent-path clobber)."""
|
||||||
|
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
|
||||||
|
result = _run(inp, {})
|
||||||
|
assert len(result["items"]) == 2
|
||||||
|
assert all("val" in row for row in result["items"])
|
||||||
|
|
||||||
|
def test_min_rows_reconstructs_with_partial_values(self):
|
||||||
|
"""min=2 with only the first row's value present still yields 2 rows."""
|
||||||
|
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
|
||||||
|
result = _run(inp, {"items.0.val": 0.7})
|
||||||
|
assert len(result["items"]) == 2
|
||||||
|
assert result["items"][0]["val"] == 0.7
|
||||||
|
assert result["items"][1]["val"] is None
|
||||||
|
|
||||||
|
def test_list_paths_in_v3_data(self):
|
||||||
|
"""list_paths must contain the group id so build_nested_inputs knows to convert."""
|
||||||
|
inp = DynamicGroup.Input("things", template=[Boolean.Input("flag")], min=0, max=5)
|
||||||
|
_, _, v3_data = get_finalized_class_inputs(_make_class_inputs(inp), {})
|
||||||
|
assert "things" in v3_data.get("list_paths", set())
|
||||||
|
|
||||||
|
def test_no_leftover_flat_keys(self):
|
||||||
|
"""Flat keys must be consumed; only the reconstructed list remains."""
|
||||||
|
inp = DynamicGroup.Input("rows", template=[Float.Input("x", default=0.0)], min=0, max=5)
|
||||||
|
result = _run(inp, {"rows.0.x": 1.0, "rows.1.x": 2.0})
|
||||||
|
assert "rows.0.x" not in result
|
||||||
|
assert "rows.1.x" not in result
|
||||||
|
assert isinstance(result["rows"], list)
|
||||||
@ -1,137 +0,0 @@
|
|||||||
"""Tests for pre-execution validation that a node is actually executable.
|
|
||||||
|
|
||||||
validate_prompt rejects a node whose declared entry point does not resolve to a
|
|
||||||
real method (a V1 FUNCTION typo, or a V3 node missing its execute override) before
|
|
||||||
any node runs, attributing the error to the offending node.
|
|
||||||
"""
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
import nodes
|
|
||||||
from comfy_api.latest import io
|
|
||||||
from execution import node_not_executable_reason, validate_prompt
|
|
||||||
|
|
||||||
|
|
||||||
class _GoodV1Node:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {"required": {}}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
|
||||||
FUNCTION = "run"
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
CATEGORY = "Test"
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
return (None,)
|
|
||||||
|
|
||||||
|
|
||||||
class _TypoV1Node:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {"required": {}}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
|
||||||
FUNCTION = "invert" # method below is misspelled
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
CATEGORY = "Test"
|
|
||||||
|
|
||||||
def invvert(self):
|
|
||||||
return (None,)
|
|
||||||
|
|
||||||
|
|
||||||
class _SideEffectInitV1Node:
|
|
||||||
"""Valid class-level method, but a constructor that must never run in validation."""
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {"required": {}}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
|
||||||
FUNCTION = "run"
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
CATEGORY = "Test"
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
raise RuntimeError("__init__ must not run during validation")
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
return (None,)
|
|
||||||
|
|
||||||
|
|
||||||
def _v3_schema(node_id):
|
|
||||||
return io.Schema(
|
|
||||||
node_id=node_id,
|
|
||||||
display_name=node_id,
|
|
||||||
category="Test",
|
|
||||||
inputs=[],
|
|
||||||
outputs=[io.Image.Output()],
|
|
||||||
is_output_node=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _GoodV3Node(io.ComfyNode):
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls):
|
|
||||||
return _v3_schema("GoodV3Node")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def execute(cls):
|
|
||||||
return io.NodeOutput(None)
|
|
||||||
|
|
||||||
|
|
||||||
class _TypoV3Node(io.ComfyNode):
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls):
|
|
||||||
return _v3_schema("TypoV3Node")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def exicute(cls): # typo: should be "execute"
|
|
||||||
return io.NodeOutput(None)
|
|
||||||
|
|
||||||
|
|
||||||
def _register(class_type, class_def):
|
|
||||||
nodes.NODE_CLASS_MAPPINGS[class_type] = class_def
|
|
||||||
|
|
||||||
|
|
||||||
def _validate(class_type):
|
|
||||||
prompt = {"1": {"class_type": class_type, "inputs": {}}}
|
|
||||||
return asyncio.run(validate_prompt("pid", prompt, None))
|
|
||||||
|
|
||||||
|
|
||||||
def test_good_node_passes():
|
|
||||||
_register("GoodV1Node", _GoodV1Node)
|
|
||||||
assert node_not_executable_reason(_GoodV1Node, "GoodV1Node") is None
|
|
||||||
valid, _, _, _ = _validate("GoodV1Node")
|
|
||||||
assert valid is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_typo_node_rejected_with_node_error():
|
|
||||||
_register("TypoV1Node", _TypoV1Node)
|
|
||||||
valid, error, _, node_errors = _validate("TypoV1Node")
|
|
||||||
assert valid is False
|
|
||||||
assert error["type"] == "invalid_node_definition"
|
|
||||||
assert node_errors["1"]["class_type"] == "TypoV1Node"
|
|
||||||
assert node_errors["1"]["errors"][0]["type"] == "invalid_node_definition"
|
|
||||||
assert "invert" in node_errors["1"]["errors"][0]["details"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_validation_does_not_instantiate_node():
|
|
||||||
"""A valid node is not constructed during validation, so __init__ never runs."""
|
|
||||||
_register("SideEffectInitV1Node", _SideEffectInitV1Node)
|
|
||||||
assert node_not_executable_reason(_SideEffectInitV1Node, "SideEffectInitV1Node") is None
|
|
||||||
valid, _, _, _ = _validate("SideEffectInitV1Node")
|
|
||||||
assert valid is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_good_v3_node_passes():
|
|
||||||
_register("GoodV3Node", _GoodV3Node)
|
|
||||||
assert node_not_executable_reason(_GoodV3Node, "GoodV3Node") is None
|
|
||||||
valid, _, _, _ = _validate("GoodV3Node")
|
|
||||||
assert valid is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_typo_v3_node_rejected_with_node_error():
|
|
||||||
_register("TypoV3Node", _TypoV3Node)
|
|
||||||
valid, error, _, node_errors = _validate("TypoV3Node")
|
|
||||||
assert valid is False
|
|
||||||
assert error["type"] == "invalid_node_definition"
|
|
||||||
assert node_errors["1"]["errors"][0]["type"] == "invalid_node_definition"
|
|
||||||
Reference in New Issue
Block a user