mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-28 00:46:52 +08:00
Compare commits
13 Commits
fix/valida
...
pysssss/an
| Author | SHA1 | Date | |
|---|---|---|---|
| f94001460f | |||
| 470ac36a0a | |||
| fd5acc96a4 | |||
| ee600a3cce | |||
| 8114516ee6 | |||
| 3eb624ce6c | |||
| 54ff5464bd | |||
| 333ff2e8a0 | |||
| c821d8ee2a | |||
| 27b6f8a927 | |||
| 9ad848bd59 | |||
| efe6439ad0 | |||
| 8d76bb94fd |
@ -256,7 +256,7 @@ def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, w
|
||||
if (want_requant and len(fns) == 0 or update_weight):
|
||||
seed = comfy.utils.string_to_seed(s.seed_key)
|
||||
if isinstance(orig, QuantizedTensor):
|
||||
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
||||
y = orig.requantize_from_float(x, scale="recalculate", stochastic_rounding=seed)
|
||||
else:
|
||||
y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
|
||||
if want_requant and len(fns) == 0:
|
||||
@ -1306,8 +1306,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||
if getattr(self, 'layout_type', None) is not None:
|
||||
# dtype is now implicit in the layout class
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
|
||||
weight = self.weight.requantize_from_float(weight, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
|
||||
else:
|
||||
weight = weight.to(self.weight.dtype)
|
||||
if return_weight:
|
||||
|
||||
@ -1,85 +1,68 @@
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import ctypes
|
||||
import logging
|
||||
import ctypes.util
|
||||
import importlib.util
|
||||
from typing import TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import nodes
|
||||
import comfy_angle
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
from typing_extensions import override
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _check_opengl_availability():
|
||||
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
|
||||
logger.debug("_check_opengl_availability: starting")
|
||||
missing = []
|
||||
def _preload_angle():
|
||||
egl_path = comfy_angle.get_egl_path()
|
||||
gles_path = comfy_angle.get_glesv2_path()
|
||||
|
||||
# Check Python packages (using find_spec to avoid importing)
|
||||
logger.debug("_check_opengl_availability: checking for glfw package")
|
||||
if importlib.util.find_spec("glfw") is None:
|
||||
missing.append("glfw")
|
||||
if sys.platform == "win32":
|
||||
angle_dir = comfy_angle.get_lib_dir()
|
||||
os.add_dll_directory(angle_dir)
|
||||
os.environ["PATH"] = angle_dir + os.pathsep + os.environ.get("PATH", "")
|
||||
|
||||
logger.debug("_check_opengl_availability: checking for OpenGL package")
|
||||
if importlib.util.find_spec("OpenGL") is None:
|
||||
missing.append("PyOpenGL")
|
||||
|
||||
if missing:
|
||||
raise RuntimeError(
|
||||
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
|
||||
)
|
||||
|
||||
# On Linux without display, check if headless backends are available
|
||||
logger.debug(f"_check_opengl_availability: platform={sys.platform}")
|
||||
if sys.platform.startswith("linux"):
|
||||
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
|
||||
logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
|
||||
if not has_display:
|
||||
# Check for EGL or OSMesa libraries
|
||||
logger.debug("_check_opengl_availability: checking for EGL library")
|
||||
has_egl = ctypes.util.find_library("EGL")
|
||||
logger.debug("_check_opengl_availability: checking for OSMesa library")
|
||||
has_osmesa = ctypes.util.find_library("OSMesa")
|
||||
|
||||
# Error disabled for CI as it fails this check
|
||||
# if not has_egl and not has_osmesa:
|
||||
# raise RuntimeError(
|
||||
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
|
||||
# "See error below for installation instructions."
|
||||
# )
|
||||
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
|
||||
|
||||
logger.debug("_check_opengl_availability: completed")
|
||||
mode = 0 if sys.platform == "win32" else ctypes.RTLD_GLOBAL
|
||||
ctypes.CDLL(str(egl_path), mode=mode)
|
||||
ctypes.CDLL(str(gles_path), mode=mode)
|
||||
|
||||
|
||||
# Run early check at import time
|
||||
logger.debug("nodes_glsl: running _check_opengl_availability at import time")
|
||||
_check_opengl_availability()
|
||||
|
||||
# OpenGL modules - initialized lazily when context is created
|
||||
gl = None
|
||||
glfw = None
|
||||
EGL = None
|
||||
# Pre-load ANGLE *before* any PyOpenGL import so that the EGL platform
|
||||
# plugin picks up ANGLE's libEGL / libGLESv2 instead of system libs.
|
||||
_preload_angle()
|
||||
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
|
||||
|
||||
|
||||
def _import_opengl():
|
||||
"""Import OpenGL module. Called after context is created."""
|
||||
global gl
|
||||
if gl is None:
|
||||
logger.debug("_import_opengl: importing OpenGL.GL")
|
||||
import OpenGL.GL as _gl
|
||||
gl = _gl
|
||||
logger.debug("_import_opengl: import completed")
|
||||
return gl
|
||||
import OpenGL
|
||||
OpenGL.USE_ACCELERATE = False
|
||||
|
||||
|
||||
def _patch_find_library():
|
||||
"""PyOpenGL's EGL platform looks for 'EGL' and 'GLESv2' by short name
|
||||
via ctypes.util.find_library, but ANGLE ships as 'libEGL' and
|
||||
'libGLESv2'. Patch find_library to return the full ANGLE paths so
|
||||
PyOpenGL loads the same libraries we pre-loaded."""
|
||||
if sys.platform == "linux":
|
||||
return
|
||||
import ctypes.util
|
||||
_orig = ctypes.util.find_library
|
||||
def _patched(name):
|
||||
if name == 'EGL':
|
||||
return comfy_angle.get_egl_path()
|
||||
if name == 'GLESv2':
|
||||
return comfy_angle.get_glesv2_path()
|
||||
return _orig(name)
|
||||
ctypes.util.find_library = _patched
|
||||
|
||||
|
||||
_patch_find_library()
|
||||
|
||||
from OpenGL import EGL
|
||||
from OpenGL import GLES3 as gl
|
||||
|
||||
class SizeModeInput(TypedDict):
|
||||
size_mode: str
|
||||
width: int
|
||||
@ -102,7 +85,7 @@ MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
||||
# (-1,-1)---(3,-1)
|
||||
#
|
||||
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
|
||||
VERTEX_SHADER = """#version 330 core
|
||||
VERTEX_SHADER = """#version 300 es
|
||||
out vec2 v_texCoord;
|
||||
void main() {
|
||||
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
|
||||
@ -126,14 +109,99 @@ void main() {
|
||||
"""
|
||||
|
||||
|
||||
def _convert_es_to_desktop(source: str) -> str:
|
||||
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
|
||||
# Remove any existing #version directive
|
||||
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
|
||||
# Remove precision qualifiers (not needed in desktop GLSL)
|
||||
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
|
||||
# Prepend desktop GLSL version
|
||||
return "#version 330 core\n" + source
|
||||
|
||||
def _egl_attribs(*values):
|
||||
"""Build an EGL_NONE-terminated EGLint attribute array."""
|
||||
vals = list(values) + [EGL.EGL_NONE]
|
||||
return (ctypes.c_int32 * len(vals))(*vals)
|
||||
|
||||
|
||||
# EGL platform extension constants
|
||||
EGL_PLATFORM_ANGLE_ANGLE = 0x3202
|
||||
EGL_PLATFORM_ANGLE_TYPE_ANGLE = 0x3203
|
||||
EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE = 0x3450
|
||||
EGL_MESA_PLATFORM_SURFACELESS = 0x31DD
|
||||
|
||||
|
||||
_eglGetPlatformDisplayEXT = None
|
||||
|
||||
def _get_egl_platform_display_ext(platform, native_display, attribs):
|
||||
"""Call eglGetPlatformDisplayEXT via ctypes (extension, not in PyOpenGL)."""
|
||||
global _eglGetPlatformDisplayEXT
|
||||
if _eglGetPlatformDisplayEXT is None:
|
||||
from OpenGL import platform as _plat
|
||||
egl_lib = _plat.PLATFORM.EGL
|
||||
_get_proc = egl_lib.eglGetProcAddress
|
||||
_get_proc.restype = ctypes.c_void_p
|
||||
_get_proc.argtypes = [ctypes.c_char_p]
|
||||
ptr = _get_proc(b"eglGetPlatformDisplayEXT")
|
||||
if not ptr:
|
||||
return None
|
||||
func_type = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_uint32, ctypes.c_void_p, ctypes.c_void_p)
|
||||
_eglGetPlatformDisplayEXT = func_type(ptr)
|
||||
|
||||
raw = _eglGetPlatformDisplayEXT(platform, native_display, attribs)
|
||||
if not raw:
|
||||
return None
|
||||
return ctypes.cast(raw, EGL.EGLDisplay)
|
||||
|
||||
|
||||
def _get_egl_display():
|
||||
"""Get an EGL display, trying the default first then ANGLE's Vulkan
|
||||
platform for headless environments without a display server."""
|
||||
failures = []
|
||||
|
||||
# Try the default display first (works when X11/Wayland is available)
|
||||
display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY)
|
||||
if display:
|
||||
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
|
||||
try:
|
||||
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
|
||||
return display, major.value, minor.value
|
||||
except Exception as e:
|
||||
failures.append(f"default: {e}")
|
||||
|
||||
logger.info("Default EGL display unavailable, trying headless fallbacks")
|
||||
|
||||
# Headless fallback strategies, tried in order:
|
||||
headless_strategies = [
|
||||
("surfaceless", EGL_MESA_PLATFORM_SURFACELESS, None, None),
|
||||
("ANGLE Vulkan", EGL_PLATFORM_ANGLE_ANGLE, None,
|
||||
_egl_attribs(EGL_PLATFORM_ANGLE_TYPE_ANGLE, EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE)),
|
||||
]
|
||||
|
||||
for name, platform, native_display, attribs in headless_strategies:
|
||||
display = _get_egl_platform_display_ext(platform, native_display, attribs)
|
||||
if not display:
|
||||
failures.append(f"{name}: eglGetPlatformDisplayEXT returned no display")
|
||||
continue
|
||||
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
|
||||
try:
|
||||
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
|
||||
logger.info(f"Using EGL {name} platform (headless)")
|
||||
return display, major.value, minor.value
|
||||
failures.append(f"{name}: eglInitialize returned false")
|
||||
except Exception as e:
|
||||
failures.append(f"{name}: {e}")
|
||||
continue
|
||||
|
||||
details = "\n".join(f" - {f}" for f in failures)
|
||||
raise RuntimeError(
|
||||
"Failed to initialize EGL display.\n"
|
||||
"No display server and no headless EGL platform available.\n"
|
||||
f"Tried:\n{details}\n"
|
||||
"Ensure GPU drivers are installed or set DISPLAY for a virtual framebuffer."
|
||||
)
|
||||
|
||||
|
||||
def _gl_str(name):
|
||||
"""Get an OpenGL string parameter."""
|
||||
v = gl.glGetString(name)
|
||||
if not v:
|
||||
return "Unknown"
|
||||
if isinstance(v, bytes):
|
||||
return v.decode(errors="replace")
|
||||
return ctypes.string_at(v).decode(errors="replace")
|
||||
|
||||
|
||||
def _detect_output_count(source: str) -> int:
|
||||
@ -159,163 +227,8 @@ def _detect_pass_count(source: str) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
def _init_glfw():
|
||||
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
|
||||
logger.debug("_init_glfw: starting")
|
||||
# On macOS, glfw.init() must be called from main thread or it hangs forever
|
||||
if sys.platform == "darwin":
|
||||
logger.debug("_init_glfw: skipping on macOS")
|
||||
raise RuntimeError("GLFW backend not supported on macOS")
|
||||
|
||||
logger.debug("_init_glfw: importing glfw module")
|
||||
import glfw as _glfw
|
||||
|
||||
logger.debug("_init_glfw: calling glfw.init()")
|
||||
if not _glfw.init():
|
||||
raise RuntimeError("glfw.init() failed")
|
||||
|
||||
try:
|
||||
logger.debug("_init_glfw: setting window hints")
|
||||
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
|
||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
|
||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
|
||||
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
|
||||
|
||||
logger.debug("_init_glfw: calling create_window()")
|
||||
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
|
||||
if not window:
|
||||
raise RuntimeError("glfw.create_window() failed")
|
||||
|
||||
logger.debug("_init_glfw: calling make_context_current()")
|
||||
_glfw.make_context_current(window)
|
||||
logger.debug("_init_glfw: completed successfully")
|
||||
return window, _glfw
|
||||
except Exception:
|
||||
logger.debug("_init_glfw: failed, terminating glfw")
|
||||
_glfw.terminate()
|
||||
raise
|
||||
|
||||
|
||||
def _init_egl():
|
||||
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
|
||||
logger.debug("_init_egl: starting")
|
||||
from OpenGL import EGL as _EGL
|
||||
from OpenGL.EGL import (
|
||||
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
|
||||
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
|
||||
eglTerminate, eglDestroyContext, eglDestroySurface,
|
||||
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
|
||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
||||
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
|
||||
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
|
||||
)
|
||||
logger.debug("_init_egl: imports completed")
|
||||
|
||||
display = None
|
||||
context = None
|
||||
surface = None
|
||||
|
||||
try:
|
||||
logger.debug("_init_egl: calling eglGetDisplay()")
|
||||
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
|
||||
if display == _EGL.EGL_NO_DISPLAY:
|
||||
raise RuntimeError("eglGetDisplay() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglInitialize()")
|
||||
major, minor = _EGL.EGLint(), _EGL.EGLint()
|
||||
if not eglInitialize(display, major, minor):
|
||||
display = None # Not initialized, don't terminate
|
||||
raise RuntimeError("eglInitialize() failed")
|
||||
logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
|
||||
|
||||
config_attribs = [
|
||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
|
||||
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
||||
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
|
||||
EGL_DEPTH_SIZE, 0, EGL_NONE
|
||||
]
|
||||
configs = (_EGL.EGLConfig * 1)()
|
||||
num_configs = _EGL.EGLint()
|
||||
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
|
||||
raise RuntimeError("eglChooseConfig() failed")
|
||||
config = configs[0]
|
||||
logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
|
||||
|
||||
if not eglBindAPI(EGL_OPENGL_API):
|
||||
raise RuntimeError("eglBindAPI() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglCreateContext()")
|
||||
context_attribs = [
|
||||
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
|
||||
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
|
||||
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
|
||||
EGL_NONE
|
||||
]
|
||||
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
|
||||
if context == EGL_NO_CONTEXT:
|
||||
raise RuntimeError("eglCreateContext() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglCreatePbufferSurface()")
|
||||
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
|
||||
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
|
||||
if surface == _EGL.EGL_NO_SURFACE:
|
||||
raise RuntimeError("eglCreatePbufferSurface() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglMakeCurrent()")
|
||||
if not eglMakeCurrent(display, surface, surface, context):
|
||||
raise RuntimeError("eglMakeCurrent() failed")
|
||||
|
||||
logger.debug("_init_egl: completed successfully")
|
||||
return display, context, surface, _EGL
|
||||
|
||||
except Exception:
|
||||
logger.debug("_init_egl: failed, cleaning up")
|
||||
# Clean up any resources on failure
|
||||
if surface is not None:
|
||||
eglDestroySurface(display, surface)
|
||||
if context is not None:
|
||||
eglDestroyContext(display, context)
|
||||
if display is not None:
|
||||
eglTerminate(display)
|
||||
raise
|
||||
|
||||
|
||||
def _init_osmesa():
|
||||
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
|
||||
import ctypes
|
||||
|
||||
logger.debug("_init_osmesa: starting")
|
||||
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
|
||||
|
||||
logger.debug("_init_osmesa: importing OpenGL.osmesa")
|
||||
from OpenGL import GL as _gl
|
||||
from OpenGL.osmesa import (
|
||||
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
|
||||
OSMESA_RGBA,
|
||||
)
|
||||
logger.debug("_init_osmesa: imports completed")
|
||||
|
||||
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
|
||||
if not ctx:
|
||||
raise RuntimeError("OSMesaCreateContextExt() failed")
|
||||
|
||||
width, height = 64, 64
|
||||
buffer = (ctypes.c_ubyte * (width * height * 4))()
|
||||
|
||||
logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
|
||||
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
|
||||
OSMesaDestroyContext(ctx)
|
||||
raise RuntimeError("OSMesaMakeCurrent() failed")
|
||||
|
||||
logger.debug("_init_osmesa: completed successfully")
|
||||
return ctx, buffer
|
||||
|
||||
|
||||
class GLContext:
|
||||
"""Manages OpenGL context and resources for shader execution.
|
||||
|
||||
Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software).
|
||||
"""
|
||||
"""Manages an OpenGL ES 3.0 context via EGL/ANGLE (singleton)."""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
@ -327,131 +240,105 @@ class GLContext:
|
||||
|
||||
def __init__(self):
|
||||
if GLContext._initialized:
|
||||
logger.debug("GLContext.__init__: already initialized, skipping")
|
||||
return
|
||||
|
||||
logger.debug("GLContext.__init__: starting initialization")
|
||||
|
||||
global glfw, EGL
|
||||
|
||||
import time
|
||||
start = time.perf_counter()
|
||||
|
||||
self._backend = None
|
||||
self._window = None
|
||||
self._egl_display = None
|
||||
self._egl_context = None
|
||||
self._egl_surface = None
|
||||
self._osmesa_ctx = None
|
||||
self._osmesa_buffer = None
|
||||
self._display = None
|
||||
self._surface = None
|
||||
self._context = None
|
||||
self._vao = None
|
||||
|
||||
# Try backends in order: GLFW → EGL → OSMesa
|
||||
errors = []
|
||||
|
||||
logger.debug("GLContext.__init__: trying GLFW backend")
|
||||
try:
|
||||
self._window, glfw = _init_glfw()
|
||||
self._backend = "glfw"
|
||||
logger.debug("GLContext.__init__: GLFW backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
|
||||
errors.append(("GLFW", e))
|
||||
self._display, self._egl_major, self._egl_minor = _get_egl_display()
|
||||
|
||||
if self._backend is None:
|
||||
logger.debug("GLContext.__init__: trying EGL backend")
|
||||
try:
|
||||
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
|
||||
self._backend = "egl"
|
||||
logger.debug("GLContext.__init__: EGL backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
|
||||
errors.append(("EGL", e))
|
||||
if not EGL.eglBindAPI(EGL.EGL_OPENGL_ES_API):
|
||||
raise RuntimeError("eglBindAPI(EGL_OPENGL_ES_API) failed")
|
||||
|
||||
if self._backend is None:
|
||||
logger.debug("GLContext.__init__: trying OSMesa backend")
|
||||
try:
|
||||
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
|
||||
self._backend = "osmesa"
|
||||
logger.debug("GLContext.__init__: OSMesa backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
|
||||
errors.append(("OSMesa", e))
|
||||
config = EGL.EGLConfig()
|
||||
n_configs = ctypes.c_int32(0)
|
||||
if not EGL.eglChooseConfig(
|
||||
self._display,
|
||||
_egl_attribs(
|
||||
EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_ES3_BIT,
|
||||
EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT,
|
||||
EGL.EGL_RED_SIZE, 8, EGL.EGL_GREEN_SIZE, 8,
|
||||
EGL.EGL_BLUE_SIZE, 8, EGL.EGL_ALPHA_SIZE, 8,
|
||||
),
|
||||
ctypes.byref(config), 1, ctypes.byref(n_configs),
|
||||
) or n_configs.value == 0:
|
||||
raise RuntimeError("eglChooseConfig() failed")
|
||||
|
||||
if self._backend is None:
|
||||
if sys.platform == "win32":
|
||||
platform_help = (
|
||||
"Windows: Ensure GPU drivers are installed and display is available.\n"
|
||||
" CPU-only/headless mode is not supported on Windows."
|
||||
)
|
||||
elif sys.platform == "darwin":
|
||||
platform_help = (
|
||||
"macOS: GLFW is not supported.\n"
|
||||
" Install OSMesa via Homebrew: brew install mesa\n"
|
||||
" Then: pip install PyOpenGL PyOpenGL-accelerate"
|
||||
)
|
||||
else:
|
||||
platform_help = (
|
||||
"Linux: Install one of these backends:\n"
|
||||
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
|
||||
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
|
||||
" Headless (CPU): sudo apt install libosmesa6"
|
||||
)
|
||||
|
||||
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
|
||||
raise RuntimeError(
|
||||
f"Failed to create OpenGL context.\n\n"
|
||||
f"Backend errors:\n{error_details}\n\n"
|
||||
f"{platform_help}"
|
||||
self._surface = EGL.eglCreatePbufferSurface(
|
||||
self._display, config,
|
||||
_egl_attribs(EGL.EGL_WIDTH, 64, EGL.EGL_HEIGHT, 64),
|
||||
)
|
||||
if not self._surface:
|
||||
raise RuntimeError("eglCreatePbufferSurface() failed")
|
||||
|
||||
# Now import OpenGL.GL (after context is current)
|
||||
logger.debug("GLContext.__init__: importing OpenGL.GL")
|
||||
_import_opengl()
|
||||
self._context = EGL.eglCreateContext(
|
||||
self._display, config, EGL.EGL_NO_CONTEXT,
|
||||
_egl_attribs(EGL.EGL_CONTEXT_CLIENT_VERSION, 3),
|
||||
)
|
||||
if not self._context:
|
||||
raise RuntimeError("eglCreateContext() failed")
|
||||
|
||||
# Create VAO (required for core profile, but OSMesa may use compat profile)
|
||||
logger.debug("GLContext.__init__: creating VAO")
|
||||
try:
|
||||
vao = gl.glGenVertexArrays(1)
|
||||
gl.glBindVertexArray(vao)
|
||||
self._vao = vao # Only store after successful bind
|
||||
logger.debug("GLContext.__init__: VAO created successfully")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
|
||||
# OSMesa with older Mesa may not support VAOs
|
||||
# Clean up if we created but couldn't bind
|
||||
if vao:
|
||||
try:
|
||||
gl.glDeleteVertexArrays(1, [vao])
|
||||
except Exception:
|
||||
pass
|
||||
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
|
||||
raise RuntimeError("eglMakeCurrent() failed")
|
||||
|
||||
self._vao = gl.glGenVertexArrays(1)
|
||||
gl.glBindVertexArray(self._vao)
|
||||
|
||||
except Exception:
|
||||
self._cleanup()
|
||||
raise
|
||||
|
||||
elapsed = (time.perf_counter() - start) * 1000
|
||||
|
||||
# Log device info
|
||||
renderer = gl.glGetString(gl.GL_RENDERER)
|
||||
vendor = gl.glGetString(gl.GL_VENDOR)
|
||||
version = gl.glGetString(gl.GL_VERSION)
|
||||
renderer = renderer.decode() if renderer else "Unknown"
|
||||
vendor = vendor.decode() if vendor else "Unknown"
|
||||
version = version.decode() if version else "Unknown"
|
||||
renderer = _gl_str(gl.GL_RENDERER)
|
||||
vendor = _gl_str(gl.GL_VENDOR)
|
||||
version = _gl_str(gl.GL_VERSION)
|
||||
|
||||
GLContext._initialized = True
|
||||
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
|
||||
logger.info(f"GLSL context initialized in {elapsed:.1f}ms - EGL {self._egl_major}.{self._egl_minor}, {renderer} ({vendor}), GL {version}")
|
||||
|
||||
def make_current(self):
|
||||
if self._backend == "glfw":
|
||||
glfw.make_context_current(self._window)
|
||||
elif self._backend == "egl":
|
||||
from OpenGL.EGL import eglMakeCurrent
|
||||
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
|
||||
elif self._backend == "osmesa":
|
||||
from OpenGL.osmesa import OSMesaMakeCurrent
|
||||
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
|
||||
|
||||
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
|
||||
err = EGL.eglGetError()
|
||||
raise RuntimeError(f"eglMakeCurrent() failed (EGL error: 0x{err:04X})")
|
||||
if self._vao is not None:
|
||||
gl.glBindVertexArray(self._vao)
|
||||
|
||||
def _cleanup(self):
|
||||
if not self._display:
|
||||
return
|
||||
try:
|
||||
if self._vao is not None:
|
||||
gl.glDeleteVertexArrays(1, [self._vao])
|
||||
self._vao = None
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
EGL.eglMakeCurrent(self._display, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self._context:
|
||||
EGL.eglDestroyContext(self._display, self._context)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self._surface:
|
||||
EGL.eglDestroySurface(self._display, self._surface)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
EGL.eglTerminate(self._display)
|
||||
except Exception:
|
||||
pass
|
||||
self._display = None
|
||||
|
||||
|
||||
def _compile_shader(source: str, shader_type: int) -> int:
|
||||
"""Compile a shader and return its ID."""
|
||||
@ -459,8 +346,10 @@ def _compile_shader(source: str, shader_type: int) -> int:
|
||||
gl.glShaderSource(shader, source)
|
||||
gl.glCompileShader(shader)
|
||||
|
||||
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
|
||||
error = gl.glGetShaderInfoLog(shader).decode()
|
||||
if not gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS):
|
||||
error = gl.glGetShaderInfoLog(shader)
|
||||
if isinstance(error, bytes):
|
||||
error = error.decode(errors="replace")
|
||||
gl.glDeleteShader(shader)
|
||||
raise RuntimeError(f"Shader compilation failed:\n{error}")
|
||||
|
||||
@ -484,8 +373,10 @@ def _create_program(vertex_source: str, fragment_source: str) -> int:
|
||||
gl.glDeleteShader(vertex_shader)
|
||||
gl.glDeleteShader(fragment_shader)
|
||||
|
||||
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
|
||||
error = gl.glGetProgramInfoLog(program).decode()
|
||||
if not gl.glGetProgramiv(program, gl.GL_LINK_STATUS):
|
||||
error = gl.glGetProgramInfoLog(program)
|
||||
if isinstance(error, bytes):
|
||||
error = error.decode(errors="replace")
|
||||
gl.glDeleteProgram(program)
|
||||
raise RuntimeError(f"Program linking failed:\n{error}")
|
||||
|
||||
@ -530,9 +421,6 @@ def _render_shader_batch(
|
||||
ctx = GLContext()
|
||||
ctx.make_current()
|
||||
|
||||
# Convert from GLSL ES to desktop GLSL 330
|
||||
fragment_source = _convert_es_to_desktop(fragment_code)
|
||||
|
||||
# Detect how many outputs the shader actually uses
|
||||
num_outputs = _detect_output_count(fragment_code)
|
||||
|
||||
@ -558,9 +446,9 @@ def _render_shader_batch(
|
||||
try:
|
||||
# Compile shaders (once for all batches)
|
||||
try:
|
||||
program = _create_program(VERTEX_SHADER, fragment_source)
|
||||
program = _create_program(VERTEX_SHADER, fragment_code)
|
||||
except RuntimeError:
|
||||
logger.error(f"Fragment shader:\n{fragment_source}")
|
||||
logger.error(f"Fragment shader:\n{fragment_code}")
|
||||
raise
|
||||
|
||||
gl.glUseProgram(program)
|
||||
@ -723,13 +611,13 @@ def _render_shader_batch(
|
||||
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
|
||||
|
||||
# Read back outputs for this batch
|
||||
# (glGetTexImage is synchronous, implicitly waits for rendering)
|
||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
|
||||
batch_outputs = []
|
||||
for tex in output_textures:
|
||||
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
||||
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
|
||||
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
|
||||
batch_outputs.append(img[::-1, :, :].copy())
|
||||
for i in range(num_outputs):
|
||||
gl.glReadBuffer(gl.GL_COLOR_ATTACHMENT0 + i)
|
||||
buf = np.empty((height, width, 4), dtype=np.float32)
|
||||
gl.glReadPixels(0, 0, width, height, gl.GL_RGBA, gl.GL_FLOAT, buf)
|
||||
batch_outputs.append(buf[::-1, :, :].copy())
|
||||
|
||||
# Pad with black images for unused outputs
|
||||
black_img = np.zeros((height, width, 4), dtype=np.float32)
|
||||
@ -750,18 +638,18 @@ def _render_shader_batch(
|
||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
|
||||
gl.glUseProgram(0)
|
||||
|
||||
for tex in input_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in curve_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in output_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in ping_pong_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
if input_textures:
|
||||
gl.glDeleteTextures(len(input_textures), input_textures)
|
||||
if curve_textures:
|
||||
gl.glDeleteTextures(len(curve_textures), curve_textures)
|
||||
if output_textures:
|
||||
gl.glDeleteTextures(len(output_textures), output_textures)
|
||||
if ping_pong_textures:
|
||||
gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures)
|
||||
if fbo is not None:
|
||||
gl.glDeleteFramebuffers(1, [fbo])
|
||||
for pp_fbo in ping_pong_fbos:
|
||||
gl.glDeleteFramebuffers(1, [pp_fbo])
|
||||
if ping_pong_fbos:
|
||||
gl.glDeleteFramebuffers(len(ping_pong_fbos), ping_pong_fbos)
|
||||
if program is not None:
|
||||
gl.glDeleteProgram(program)
|
||||
|
||||
|
||||
55
execution.py
55
execution.py
@ -1113,32 +1113,6 @@ def full_type_name(klass):
|
||||
return klass.__qualname__
|
||||
return module + '.' + klass.__qualname__
|
||||
|
||||
def node_not_executable_reason(class_def, class_type):
|
||||
"""Return a human-readable reason the node cannot be executed, or None if it's fine.
|
||||
|
||||
Catches a node whose declared entry point doesn't resolve to a real method
|
||||
(e.g. a V1 ``FUNCTION = "invert"`` where the method is misspelled, or a V3 node
|
||||
missing its ``execute`` override). Running this during validation surfaces the
|
||||
problem before execution starts, instead of after upstream nodes have run.
|
||||
|
||||
Only the class is inspected; the node is never instantiated here, so a node's
|
||||
``__init__`` side effects cannot run (or fail) during validation.
|
||||
"""
|
||||
try:
|
||||
if issubclass(class_def, _ComfyNodeInternal):
|
||||
# V3: validates that execute()/define_schema() overrides exist.
|
||||
class_def.VALIDATE_CLASS()
|
||||
return None
|
||||
# V1: FUNCTION names the method to call; it must exist on the class.
|
||||
function_name = getattr(class_def, "FUNCTION", None)
|
||||
if function_name is None:
|
||||
return f"'{class_type}' does not define FUNCTION"
|
||||
if not callable(getattr(class_def, function_name, None)):
|
||||
return f"'{class_type}' has no method '{function_name}' (declared in FUNCTION)"
|
||||
return None
|
||||
except Exception as ex:
|
||||
return str(ex)
|
||||
|
||||
async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]):
|
||||
outputs = set()
|
||||
for x in prompt:
|
||||
@ -1174,35 +1148,6 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[
|
||||
}
|
||||
return (False, error, [], {})
|
||||
|
||||
# Make sure the node is actually executable (its FUNCTION/execute entry
|
||||
# point resolves to a real method) before we touch any schema-derived
|
||||
# attributes below or start execution. Catches code typos up front and
|
||||
# attributes the error to the offending node.
|
||||
not_executable = node_not_executable_reason(class_, class_type)
|
||||
if not_executable is not None:
|
||||
node_title = prompt[x].get('_meta', {}).get('title', class_type)
|
||||
error = {
|
||||
"type": "invalid_node_definition",
|
||||
"message": "Node is not executable",
|
||||
"details": f"{not_executable} (Node ID '#{x}')",
|
||||
"extra_info": {
|
||||
"node_id": x,
|
||||
"class_type": class_type,
|
||||
"node_title": node_title,
|
||||
}
|
||||
}
|
||||
node_errors = {x: {
|
||||
"errors": [{
|
||||
"type": "invalid_node_definition",
|
||||
"message": "Node is not executable",
|
||||
"details": not_executable,
|
||||
"extra_info": {},
|
||||
}],
|
||||
"dependent_outputs": [],
|
||||
"class_type": class_type,
|
||||
}}
|
||||
return (False, error, [], node_errors)
|
||||
|
||||
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
|
||||
if partial_execution_list is None or x in partial_execution_list:
|
||||
outputs.add(x)
|
||||
|
||||
@ -22,7 +22,7 @@ alembic
|
||||
SQLAlchemy>=2.0.0
|
||||
filelock
|
||||
av>=16.0.0
|
||||
comfy-kitchen==0.2.12
|
||||
comfy-kitchen==0.2.13
|
||||
comfy-aimdo==0.4.10
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
@ -33,5 +33,5 @@ kornia>=0.7.1
|
||||
spandrel
|
||||
pydantic~=2.0
|
||||
pydantic-settings~=2.0
|
||||
PyOpenGL
|
||||
glfw
|
||||
PyOpenGL>=3.1.8
|
||||
comfy-angle
|
||||
|
||||
@ -1,137 +0,0 @@
|
||||
"""Tests for pre-execution validation that a node is actually executable.
|
||||
|
||||
validate_prompt rejects a node whose declared entry point does not resolve to a
|
||||
real method (a V1 FUNCTION typo, or a V3 node missing its execute override) before
|
||||
any node runs, attributing the error to the offending node.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
import nodes
|
||||
from comfy_api.latest import io
|
||||
from execution import node_not_executable_reason, validate_prompt
|
||||
|
||||
|
||||
class _GoodV1Node:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "run"
|
||||
OUTPUT_NODE = True
|
||||
CATEGORY = "Test"
|
||||
|
||||
def run(self):
|
||||
return (None,)
|
||||
|
||||
|
||||
class _TypoV1Node:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "invert" # method below is misspelled
|
||||
OUTPUT_NODE = True
|
||||
CATEGORY = "Test"
|
||||
|
||||
def invvert(self):
|
||||
return (None,)
|
||||
|
||||
|
||||
class _SideEffectInitV1Node:
|
||||
"""Valid class-level method, but a constructor that must never run in validation."""
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "run"
|
||||
OUTPUT_NODE = True
|
||||
CATEGORY = "Test"
|
||||
|
||||
def __init__(self):
|
||||
raise RuntimeError("__init__ must not run during validation")
|
||||
|
||||
def run(self):
|
||||
return (None,)
|
||||
|
||||
|
||||
def _v3_schema(node_id):
|
||||
return io.Schema(
|
||||
node_id=node_id,
|
||||
display_name=node_id,
|
||||
category="Test",
|
||||
inputs=[],
|
||||
outputs=[io.Image.Output()],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
|
||||
class _GoodV3Node(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return _v3_schema("GoodV3Node")
|
||||
|
||||
@classmethod
|
||||
def execute(cls):
|
||||
return io.NodeOutput(None)
|
||||
|
||||
|
||||
class _TypoV3Node(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return _v3_schema("TypoV3Node")
|
||||
|
||||
@classmethod
|
||||
def exicute(cls): # typo: should be "execute"
|
||||
return io.NodeOutput(None)
|
||||
|
||||
|
||||
def _register(class_type, class_def):
|
||||
nodes.NODE_CLASS_MAPPINGS[class_type] = class_def
|
||||
|
||||
|
||||
def _validate(class_type):
|
||||
prompt = {"1": {"class_type": class_type, "inputs": {}}}
|
||||
return asyncio.run(validate_prompt("pid", prompt, None))
|
||||
|
||||
|
||||
def test_good_node_passes():
|
||||
_register("GoodV1Node", _GoodV1Node)
|
||||
assert node_not_executable_reason(_GoodV1Node, "GoodV1Node") is None
|
||||
valid, _, _, _ = _validate("GoodV1Node")
|
||||
assert valid is True
|
||||
|
||||
|
||||
def test_typo_node_rejected_with_node_error():
|
||||
_register("TypoV1Node", _TypoV1Node)
|
||||
valid, error, _, node_errors = _validate("TypoV1Node")
|
||||
assert valid is False
|
||||
assert error["type"] == "invalid_node_definition"
|
||||
assert node_errors["1"]["class_type"] == "TypoV1Node"
|
||||
assert node_errors["1"]["errors"][0]["type"] == "invalid_node_definition"
|
||||
assert "invert" in node_errors["1"]["errors"][0]["details"]
|
||||
|
||||
|
||||
def test_validation_does_not_instantiate_node():
|
||||
"""A valid node is not constructed during validation, so __init__ never runs."""
|
||||
_register("SideEffectInitV1Node", _SideEffectInitV1Node)
|
||||
assert node_not_executable_reason(_SideEffectInitV1Node, "SideEffectInitV1Node") is None
|
||||
valid, _, _, _ = _validate("SideEffectInitV1Node")
|
||||
assert valid is True
|
||||
|
||||
|
||||
def test_good_v3_node_passes():
|
||||
_register("GoodV3Node", _GoodV3Node)
|
||||
assert node_not_executable_reason(_GoodV3Node, "GoodV3Node") is None
|
||||
valid, _, _, _ = _validate("GoodV3Node")
|
||||
assert valid is True
|
||||
|
||||
|
||||
def test_typo_v3_node_rejected_with_node_error():
|
||||
_register("TypoV3Node", _TypoV3Node)
|
||||
valid, error, _, node_errors = _validate("TypoV3Node")
|
||||
assert valid is False
|
||||
assert error["type"] == "invalid_node_definition"
|
||||
assert node_errors["1"]["errors"][0]["type"] == "invalid_node_definition"
|
||||
Reference in New Issue
Block a user