mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-30 18:07:30 +08:00
Compare commits
2 Commits
feat/api-n
...
fix/valida
| Author | SHA1 | Date | |
|---|---|---|---|
| bf00c39705 | |||
| 82c954bd2a |
38
.github/workflows/ci-cursor-review.yml
vendored
38
.github/workflows/ci-cursor-review.yml
vendored
@ -1,38 +0,0 @@
|
||||
name: CI - Cursor Review
|
||||
|
||||
# Thin caller for the shared reusable cursor-review workflow in
|
||||
# Comfy-Org/github-workflows. The review logic (panel matrix, judge
|
||||
# consolidation, prompts, extract/post/notify scripts) lives there as the
|
||||
# single source of truth, so this repo only carries the repo-specific diff
|
||||
# excludes.
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [labeled, unlabeled]
|
||||
|
||||
concurrency:
|
||||
group: cursor-review-pr-${{ github.event.pull_request.number }}-${{ github.event.label.name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
cursor-review:
|
||||
if: github.event.label.name == 'cursor-review'
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
# SHA-pinned per zizmor `unpinned-uses: hash-pin`. Bump this SHA to pick up
|
||||
# upstream changes; keep `workflows_ref` matching so prompts/scripts load
|
||||
# from the same commit as the workflow definition.
|
||||
uses: Comfy-Org/github-workflows/.github/workflows/cursor-review.yml@047ca48febe3a6647608ed2e0c4331b491cb9d6a # github-workflows#9
|
||||
with:
|
||||
workflows_ref: 047ca48febe3a6647608ed2e0c4331b491cb9d6a
|
||||
diff_excludes: >-
|
||||
:!**/.claude/**
|
||||
:!**/dist/**
|
||||
:!**/vendor/**
|
||||
:!**/*.generated.*
|
||||
:!**/*.min.js
|
||||
:!**/*.min.css
|
||||
secrets:
|
||||
CURSOR_API_KEY: ${{ secrets.CURSOR_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
@ -256,7 +256,7 @@ def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, w
|
||||
if (want_requant and len(fns) == 0 or update_weight):
|
||||
seed = comfy.utils.string_to_seed(s.seed_key)
|
||||
if isinstance(orig, QuantizedTensor):
|
||||
y = orig.requantize_from_float(x, scale="recalculate", stochastic_rounding=seed)
|
||||
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
||||
else:
|
||||
y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
|
||||
if want_requant and len(fns) == 0:
|
||||
@ -1216,7 +1216,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
bias_dtype=input.dtype,
|
||||
offloadable=True,
|
||||
compute_dtype=compute_dtype,
|
||||
want_requant=True,
|
||||
want_requant=want_requant,
|
||||
)
|
||||
weight = weight.to(dtype=input.dtype)
|
||||
else:
|
||||
@ -1306,7 +1306,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||
if getattr(self, 'layout_type', None) is not None:
|
||||
weight = self.weight.requantize_from_float(weight, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
|
||||
# dtype is now implicit in the layout class
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
|
||||
else:
|
||||
weight = weight.to(self.weight.dtype)
|
||||
if return_weight:
|
||||
|
||||
@ -1,82 +0,0 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class FalQueueSubmit(BaseModel):
|
||||
request_id: str = Field(...)
|
||||
status: str | None = Field(None)
|
||||
|
||||
|
||||
class FalQueueStatus(BaseModel):
|
||||
status: str | None = Field(None)
|
||||
|
||||
|
||||
class PatinaImage(BaseModel):
|
||||
url: str = Field(...)
|
||||
map_type: str | None = Field(None, description="PBR map type; None for a base texture image.")
|
||||
width: int | None = Field(None)
|
||||
height: int | None = Field(None)
|
||||
content_type: str | None = Field(None)
|
||||
|
||||
|
||||
class PatinaResult(BaseModel):
|
||||
images: list[PatinaImage] = Field(default_factory=list)
|
||||
seed: int | None = Field(None)
|
||||
prompt: str | None = Field(None)
|
||||
|
||||
|
||||
class ImageSize(BaseModel):
|
||||
width: int = Field(...)
|
||||
height: int = Field(...)
|
||||
|
||||
|
||||
class PatinaPBRMapsRequest(BaseModel):
|
||||
"""fal-ai/patina — image -> PBR maps."""
|
||||
|
||||
image_url: str = Field(...)
|
||||
maps: list[str] | None = Field(None)
|
||||
seed: int | None = Field(None)
|
||||
output_format: str = Field("png")
|
||||
enable_safety_checker: bool = Field(False)
|
||||
|
||||
|
||||
class PatinaMaterialRequest(BaseModel):
|
||||
"""fal-ai/patina/material — text (+optional img2img/inpaint) -> tileable material."""
|
||||
|
||||
prompt: str = Field(...)
|
||||
image_size: str | ImageSize = Field("square_hd")
|
||||
maps: list[str] | None = Field(None)
|
||||
upscale_factor: int = Field(0, description="0, 2, or 4 - SeedVR upscaling of the PBR maps.")
|
||||
tiling_mode: str = Field("both")
|
||||
num_inference_steps: int = Field(8)
|
||||
enable_prompt_expansion: bool = Field(False)
|
||||
enable_safety_checker: bool = Field(False)
|
||||
tile_size: int = Field(128, description="Tile size in latent space (64 = 512px, 128 = 1024px).")
|
||||
tile_stride: int = Field(64, description="Tile stride in latent space.")
|
||||
image_url: str | None = Field(
|
||||
None, description="Optional source for img2img, or inpaint when combined with mask_url."
|
||||
)
|
||||
mask_url: str | None = Field(
|
||||
None, description="Inpaint mask (white = regenerate, black = keep); requires image_url."
|
||||
)
|
||||
strength: float = Field(0.6)
|
||||
seed: int | None = Field(None)
|
||||
output_format: str = Field("png")
|
||||
|
||||
|
||||
class PatinaExtractRequest(BaseModel):
|
||||
"""fal-ai/patina/material/extract — image + prompt -> tileable material (no inpainting)."""
|
||||
|
||||
prompt: str = Field(...)
|
||||
image_url: str = Field(...)
|
||||
image_size: str | ImageSize = Field("square_hd")
|
||||
maps: list[str] | None = Field(None)
|
||||
upscale_factor: int = Field(0, description="0, 2, or 4 - SeedVR upscaling of the PBR maps.")
|
||||
tiling_mode: str = Field("both")
|
||||
num_inference_steps: int = Field(8)
|
||||
enable_prompt_expansion: bool = Field(False)
|
||||
enable_safety_checker: bool = Field(False)
|
||||
tile_size: int = Field(128, description="Tile size in latent space (64 = 512px, 128 = 1024px).")
|
||||
tile_stride: int = Field(64, description="Tile stride in latent space.")
|
||||
strength: float = Field(0.6)
|
||||
seed: int | None = Field(None)
|
||||
output_format: str = Field("png")
|
||||
@ -1,582 +0,0 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.patina import (
|
||||
FalQueueStatus,
|
||||
FalQueueSubmit,
|
||||
ImageSize,
|
||||
PatinaExtractRequest,
|
||||
PatinaMaterialRequest,
|
||||
PatinaPBRMapsRequest,
|
||||
PatinaResult,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
bytesio_to_image_tensor,
|
||||
convert_mask_to_image,
|
||||
download_url_as_bytesio,
|
||||
downscale_image_tensor_by_max_side,
|
||||
poll_op,
|
||||
resize_mask_to_image,
|
||||
sync_op,
|
||||
upload_image_to_comfyapi,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
|
||||
PATINA_MAPS = ["basecolor", "normal", "roughness", "metalness", "height"]
|
||||
_IMAGE_SIZES = [
|
||||
("1:1 (1024x1024)", "square_hd", 1024, 1024),
|
||||
("1:1 (512x512)", "square", 512, 512),
|
||||
("4:3 (1024x768)", "landscape_4_3", 1024, 768),
|
||||
("3:4 (768x1024)", "portrait_4_3", 768, 1024),
|
||||
("16:9 (1024x576)", "landscape_16_9", 1024, 576),
|
||||
("9:16 (576x1024)", "portrait_16_9", 576, 1024),
|
||||
]
|
||||
_LABEL_TO_PRESET = {label: preset for label, preset, _, _ in _IMAGE_SIZES}
|
||||
_PRESET_MP = json.dumps({label: w * h / 1048576 for label, _, w, h in _IMAGE_SIZES})
|
||||
# nMaps from the five boolean map toggles (BOOLEAN widgets reach JSONata as true/false).
|
||||
_NMAPS = (
|
||||
"(widgets.basecolor?1:0)+(widgets.normal?1:0)+(widgets.roughness?1:0)+(widgets.metalness?1:0)+(widgets.height?1:0)"
|
||||
)
|
||||
|
||||
|
||||
async def _run_patina(cls: type[IO.ComfyNode], model_id: str, request) -> PatinaResult:
|
||||
submit = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/fal/{model_id}", method="POST"),
|
||||
response_model=FalQueueSubmit,
|
||||
data=request,
|
||||
)
|
||||
await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/fal/fal-ai/patina/requests/{submit.request_id}/status"),
|
||||
response_model=FalQueueStatus,
|
||||
status_extractor=lambda r: r.status,
|
||||
poll_interval=3.0,
|
||||
)
|
||||
return await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/fal/fal-ai/patina/requests/{submit.request_id}"),
|
||||
response_model=PatinaResult,
|
||||
)
|
||||
|
||||
|
||||
async def _download_rgb(cls: type[IO.ComfyNode], url: str) -> torch.Tensor:
|
||||
"""Download an image as a 3-channel (B,H,W,3) tensor, matching the blank-map placeholder."""
|
||||
return bytesio_to_image_tensor(await download_url_as_bytesio(url, cls=cls), mode="RGB")
|
||||
|
||||
|
||||
async def _map_outputs(cls: type[IO.ComfyNode], result: PatinaResult) -> tuple[torch.Tensor, ...]:
|
||||
"""One tensor per entry in PATINA_MAPS; a 1x1 black placeholder for any map not returned."""
|
||||
by_type = {img.map_type: img for img in result.images if img.map_type}
|
||||
outputs = []
|
||||
for name in PATINA_MAPS:
|
||||
img = by_type.get(name)
|
||||
outputs.append(await _download_rgb(cls, img.url) if img else torch.zeros(1, 1, 1, 3))
|
||||
return tuple(outputs)
|
||||
|
||||
|
||||
async def _base_texture(cls: type[IO.ComfyNode], result: PatinaResult) -> torch.Tensor:
|
||||
"""The single tileable base texture (the item without a map_type); blank 1x1 if absent."""
|
||||
texture = next((img for img in result.images if not img.map_type), None)
|
||||
if texture is None:
|
||||
return torch.zeros(1, 1, 1, 3)
|
||||
return await _download_rgb(cls, texture.url)
|
||||
|
||||
|
||||
def _selected_maps(basecolor: bool, normal: bool, roughness: bool, metalness: bool, height: bool) -> list[str]:
|
||||
flags = {
|
||||
"basecolor": basecolor,
|
||||
"normal": normal,
|
||||
"roughness": roughness,
|
||||
"metalness": metalness,
|
||||
"height": height,
|
||||
}
|
||||
return [m for m in PATINA_MAPS if flags[m]]
|
||||
|
||||
|
||||
def _resolve_image_size(image_size: dict[str, Any]) -> str | ImageSize:
|
||||
"""DynamicCombo -> a preset string, or an ImageSize object when 'custom' is selected."""
|
||||
key = image_size.get("image_size") if isinstance(image_size, dict) else None
|
||||
if key == "custom":
|
||||
return ImageSize(width=int(image_size["width"]), height=int(image_size["height"]))
|
||||
return _LABEL_TO_PRESET.get(key, "square_hd")
|
||||
|
||||
|
||||
def _image_size_input() -> IO.DynamicCombo.Input:
|
||||
return IO.DynamicCombo.Input(
|
||||
"image_size",
|
||||
options=[IO.DynamicCombo.Option(label, []) for label, _, _, _ in _IMAGE_SIZES]
|
||||
+ [
|
||||
IO.DynamicCombo.Option(
|
||||
"custom",
|
||||
[
|
||||
IO.Int.Input("width", default=1024, min=512, max=2048, step=8),
|
||||
IO.Int.Input("height", default=1024, min=512, max=2048, step=8),
|
||||
],
|
||||
)
|
||||
],
|
||||
tooltip="Output texture size. Choose 'custom' for a width/height between 512 and 2048 "
|
||||
"(FAL's base-texture limits; an 8K result comes from 4x upscaling the maps).",
|
||||
)
|
||||
|
||||
|
||||
def _map_toggle_inputs() -> list[IO.Boolean.Input]:
|
||||
"""Five per-map toggles; each maps 1:1 to its output socket."""
|
||||
return [
|
||||
IO.Boolean.Input("basecolor", default=True, tooltip="Generate the basecolor (albedo) map."),
|
||||
IO.Boolean.Input("normal", default=True, tooltip="Generate the normal map."),
|
||||
IO.Boolean.Input("roughness", default=False, tooltip="Generate the roughness map."),
|
||||
IO.Boolean.Input("metalness", default=False, tooltip="Generate the metalness map."),
|
||||
IO.Boolean.Input("height", default=False, tooltip="Generate the height/displacement map."),
|
||||
]
|
||||
|
||||
class PatinaPBRMapsNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PatinaPBRMapsNode",
|
||||
display_name="Patina PBR Maps",
|
||||
category="partner/3d/FAL",
|
||||
essentials_category="3D",
|
||||
description="Generate seamless PBR maps (basecolor, normal, roughness, metalness, height) "
|
||||
"from a photo or render via fal.ai PATINA.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="Input photograph or render to derive PBR maps from."),
|
||||
*_map_toggle_inputs(),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483646,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for reproducible denoising.",
|
||||
),
|
||||
IO.Boolean.Input("safety_checker", default=False, advanced=True),
|
||||
IO.Boolean.Input(
|
||||
"auto_downscale",
|
||||
default=True,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Automatically downscale an input image whose longest side exceeds 2048px "
|
||||
"(fal.ai PATINA's input limit), preserving aspect ratio; smaller images are left as-is. "
|
||||
"Disable to raise an error on oversized images instead.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
*[IO.Image.Output(m) for m in PATINA_MAPS],
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=list(PATINA_MAPS)),
|
||||
expr=f"""
|
||||
(
|
||||
$n := {_NMAPS};
|
||||
{{"type":"range_usd","min_usd": 0.0143 + 0.0143*$n, "max_usd": 0.0143 + 0.0572*$n, "format":{{"approximate":true}}}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
basecolor: bool = True,
|
||||
normal: bool = True,
|
||||
roughness: bool = False,
|
||||
metalness: bool = False,
|
||||
height: bool = False,
|
||||
seed: int = 0,
|
||||
safety_checker: bool = False,
|
||||
auto_downscale: bool = True,
|
||||
) -> IO.NodeOutput:
|
||||
maps = _selected_maps(basecolor, normal, roughness, metalness, height)
|
||||
if not maps:
|
||||
raise ValueError("Enable at least one PBR map to generate.")
|
||||
if auto_downscale:
|
||||
image = downscale_image_tensor_by_max_side(image, max_side=2048)
|
||||
else:
|
||||
validate_image_dimensions(image, max_width=2048, max_height=2048)
|
||||
image_url = await upload_image_to_comfyapi(cls, image, mime_type="image/png")
|
||||
result = await _run_patina(
|
||||
cls,
|
||||
"fal-ai/patina",
|
||||
PatinaPBRMapsRequest(
|
||||
image_url=image_url,
|
||||
maps=maps,
|
||||
seed=seed,
|
||||
enable_safety_checker=safety_checker,
|
||||
),
|
||||
)
|
||||
basecolor_t, normal_t, roughness_t, metalness_t, height_t = await _map_outputs(cls, result)
|
||||
return IO.NodeOutput(basecolor_t, normal_t, roughness_t, metalness_t, height_t)
|
||||
|
||||
|
||||
class PatinaMaterialNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PatinaMaterialNode",
|
||||
display_name="Patina Material",
|
||||
category="partner/3d/FAL",
|
||||
essentials_category="3D",
|
||||
description="Generate a complete seamlessly tiling PBR material (base texture + maps, up to 8K) "
|
||||
"from a text prompt via fal.ai PATINA. Optionally drive it with an input image (img2img) "
|
||||
"or an image + mask (inpaint).",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Describe the material/texture to generate."),
|
||||
_image_size_input(),
|
||||
*_map_toggle_inputs(),
|
||||
IO.Int.Input(
|
||||
"upscale_factor",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4,
|
||||
step=2,
|
||||
tooltip="Seamless SeedVR upscaling of the PBR maps (the base texture is not upscaled).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483646,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for reproducible generation.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"tiling_mode",
|
||||
options=["both", "horizontal", "vertical"],
|
||||
default="both",
|
||||
advanced=True,
|
||||
tooltip="Tiling direction: omnidirectional, horizontal, or vertical.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"num_inference_steps",
|
||||
default=8,
|
||||
min=1,
|
||||
max=8,
|
||||
advanced=True,
|
||||
tooltip="Denoising steps for texture generation.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"tile_size",
|
||||
default=128,
|
||||
min=32,
|
||||
max=256,
|
||||
advanced=True,
|
||||
tooltip="Tile size in latent space (64 = 512px, 128 = 1024px).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"tile_stride", default=64, min=16, max=128, advanced=True, tooltip="Tile stride in latent space."
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
optional=True,
|
||||
tooltip="Optional source image. Provided alone = img2img; with mask = inpaint.",
|
||||
),
|
||||
IO.Mask.Input(
|
||||
"mask",
|
||||
optional=True,
|
||||
tooltip="Optional inpaint mask (requires image). White = regenerate, black = keep.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"strength",
|
||||
default=0.6,
|
||||
min=0.01,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
advanced=True,
|
||||
tooltip="How much to transform the input image. Only used when an image is provided.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"prompt_expansion",
|
||||
default=False,
|
||||
advanced=True,
|
||||
tooltip="Expand the prompt with an LLM for richer texture detail. Off by default: "
|
||||
"expansion reframes the prompt as a photo and tends to wash out the metalness map.",
|
||||
),
|
||||
IO.Boolean.Input("safety_checker", default=False, advanced=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output("texture"),
|
||||
*[IO.Image.Output(m) for m in PATINA_MAPS],
|
||||
IO.String.Output("expanded_prompt"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["image_size", "image_size.width", "image_size.height", *PATINA_MAPS, "upscale_factor"]
|
||||
),
|
||||
expr=f"""
|
||||
(
|
||||
$mp := $ceil(widgets.image_size = "custom"
|
||||
? ($lookup(widgets, "image_size.width") * $lookup(widgets, "image_size.height")) / 1048576
|
||||
: $lookup({_PRESET_MP}, widgets.image_size));
|
||||
$n := {_NMAPS};
|
||||
$up := widgets.upscale_factor = 4 ? 0.02288 : widgets.upscale_factor = 2 ? 0.00572 : 0;
|
||||
{{"type":"usd","usd": 0.0143 + 0.0286*$mp + $mp*$n*(0.0143+$up)}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
image_size: dict[str, Any],
|
||||
basecolor: bool = True,
|
||||
normal: bool = True,
|
||||
roughness: bool = False,
|
||||
metalness: bool = False,
|
||||
height: bool = False,
|
||||
upscale_factor: int = 0,
|
||||
seed: int = 0,
|
||||
tiling_mode: str = "both",
|
||||
num_inference_steps: int = 8,
|
||||
tile_size: int = 128,
|
||||
tile_stride: int = 64,
|
||||
image: Input.Image | None = None,
|
||||
mask: Input.Mask | None = None,
|
||||
strength: float = 0.6,
|
||||
prompt_expansion: bool = False,
|
||||
safety_checker: bool = False,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False, min_length=1)
|
||||
if mask is not None and image is None:
|
||||
raise ValueError("A mask requires an input image (inpaint mode).")
|
||||
image_url = None
|
||||
mask_url = None
|
||||
if image is not None:
|
||||
image_url = await upload_image_to_comfyapi(cls, image, mime_type="image/png")
|
||||
if mask is not None:
|
||||
mask_url = await upload_image_to_comfyapi(
|
||||
cls,
|
||||
convert_mask_to_image(resize_mask_to_image(mask, image, allow_gradient=False)),
|
||||
mime_type="image/png",
|
||||
wait_label="Uploading mask",
|
||||
)
|
||||
result = await _run_patina(
|
||||
cls,
|
||||
"fal-ai/patina/material",
|
||||
PatinaMaterialRequest(
|
||||
prompt=prompt,
|
||||
image_size=_resolve_image_size(image_size),
|
||||
maps=_selected_maps(basecolor, normal, roughness, metalness, height),
|
||||
upscale_factor=upscale_factor,
|
||||
tiling_mode=tiling_mode,
|
||||
num_inference_steps=num_inference_steps,
|
||||
enable_prompt_expansion=prompt_expansion,
|
||||
enable_safety_checker=safety_checker,
|
||||
tile_size=tile_size,
|
||||
tile_stride=tile_stride,
|
||||
image_url=image_url,
|
||||
mask_url=mask_url,
|
||||
strength=strength,
|
||||
seed=seed,
|
||||
),
|
||||
)
|
||||
texture = await _base_texture(cls, result)
|
||||
basecolor_t, normal_t, roughness_t, metalness_t, height_t = await _map_outputs(cls, result)
|
||||
return IO.NodeOutput(
|
||||
texture,
|
||||
basecolor_t,
|
||||
normal_t,
|
||||
roughness_t,
|
||||
metalness_t,
|
||||
height_t,
|
||||
result.prompt or prompt,
|
||||
)
|
||||
|
||||
|
||||
class PatinaMaterialExtractNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PatinaMaterialExtractNode",
|
||||
display_name="Patina Material Extract",
|
||||
category="partner/3d/FAL",
|
||||
essentials_category="3D",
|
||||
description="Extract a seamlessly tiling PBR material (base texture + maps) from a region of an "
|
||||
"input image, guided by a prompt, via fal.ai PATINA.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="Image to extract a texture from."),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip='Describe which texture to extract from the image (e.g. "the wall").',
|
||||
),
|
||||
_image_size_input(),
|
||||
*_map_toggle_inputs(),
|
||||
IO.Int.Input(
|
||||
"upscale_factor",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4,
|
||||
step=2,
|
||||
tooltip="Seamless SeedVR upscaling of the PBR maps (the base texture is not upscaled).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483646,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for reproducible generation.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"strength",
|
||||
default=0.6,
|
||||
min=0.01,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
advanced=True,
|
||||
tooltip="How much to transform the input image.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"tiling_mode",
|
||||
options=["both", "horizontal", "vertical"],
|
||||
default="both",
|
||||
advanced=True,
|
||||
tooltip="Tiling direction: omnidirectional, horizontal, or vertical.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"num_inference_steps",
|
||||
default=8,
|
||||
min=1,
|
||||
max=8,
|
||||
advanced=True,
|
||||
tooltip="Denoising steps for texture generation.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"tile_size",
|
||||
default=128,
|
||||
min=32,
|
||||
max=256,
|
||||
advanced=True,
|
||||
tooltip="Tile size in latent space (64 = 512px, 128 = 1024px).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"tile_stride", default=64, min=16, max=128, advanced=True, tooltip="Tile stride in latent space."
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"prompt_expansion",
|
||||
default=False,
|
||||
advanced=True,
|
||||
tooltip="Expand the prompt with an LLM for richer texture detail. Off by default: "
|
||||
"expansion reframes the prompt as a photo and tends to wash out the metalness map.",
|
||||
),
|
||||
IO.Boolean.Input("safety_checker", default=False, advanced=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output("texture"),
|
||||
*[IO.Image.Output(m) for m in PATINA_MAPS],
|
||||
IO.String.Output("expanded_prompt"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["image_size", "image_size.width", "image_size.height", *PATINA_MAPS, "upscale_factor"]
|
||||
),
|
||||
expr=f"""
|
||||
(
|
||||
$mp := $ceil(widgets.image_size = "custom"
|
||||
? ($lookup(widgets, "image_size.width") * $lookup(widgets, "image_size.height")) / 1048576
|
||||
: $lookup({_PRESET_MP}, widgets.image_size));
|
||||
$n := {_NMAPS};
|
||||
$up := widgets.upscale_factor = 4 ? 0.02288 : widgets.upscale_factor = 2 ? 0.00572 : 0;
|
||||
{{"type":"usd","usd": 0.143 + 0.0286*$mp + $mp*$n*(0.0143+$up)}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
image_size: dict[str, Any],
|
||||
basecolor: bool = True,
|
||||
normal: bool = True,
|
||||
roughness: bool = False,
|
||||
metalness: bool = False,
|
||||
height: bool = False,
|
||||
upscale_factor: int = 0,
|
||||
seed: int = 0,
|
||||
strength: float = 0.6,
|
||||
tiling_mode: str = "both",
|
||||
num_inference_steps: int = 8,
|
||||
tile_size: int = 128,
|
||||
tile_stride: int = 64,
|
||||
prompt_expansion: bool = False,
|
||||
safety_checker: bool = False,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False, min_length=1)
|
||||
image_url = await upload_image_to_comfyapi(cls, image, mime_type="image/png")
|
||||
result = await _run_patina(
|
||||
cls,
|
||||
"fal-ai/patina/material/extract",
|
||||
PatinaExtractRequest(
|
||||
prompt=prompt,
|
||||
image_url=image_url,
|
||||
image_size=_resolve_image_size(image_size),
|
||||
maps=_selected_maps(basecolor, normal, roughness, metalness, height),
|
||||
upscale_factor=upscale_factor,
|
||||
tiling_mode=tiling_mode,
|
||||
num_inference_steps=num_inference_steps,
|
||||
enable_prompt_expansion=prompt_expansion,
|
||||
enable_safety_checker=safety_checker,
|
||||
tile_size=tile_size,
|
||||
tile_stride=tile_stride,
|
||||
strength=strength,
|
||||
seed=seed,
|
||||
),
|
||||
)
|
||||
texture = await _base_texture(cls, result)
|
||||
basecolor_t, normal_t, roughness_t, metalness_t, height_t = await _map_outputs(cls, result)
|
||||
return IO.NodeOutput(
|
||||
texture,
|
||||
basecolor_t,
|
||||
normal_t,
|
||||
roughness_t,
|
||||
metalness_t,
|
||||
height_t,
|
||||
result.prompt or prompt,
|
||||
)
|
||||
|
||||
|
||||
class PatinaExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
PatinaPBRMapsNode,
|
||||
PatinaMaterialNode,
|
||||
PatinaMaterialExtractNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> PatinaExtension:
|
||||
return PatinaExtension()
|
||||
@ -1,68 +1,85 @@
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import ctypes
|
||||
import logging
|
||||
import ctypes.util
|
||||
import importlib.util
|
||||
from typing import TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import nodes
|
||||
import comfy_angle
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
from typing_extensions import override
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _preload_angle():
|
||||
egl_path = comfy_angle.get_egl_path()
|
||||
gles_path = comfy_angle.get_glesv2_path()
|
||||
def _check_opengl_availability():
|
||||
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
|
||||
logger.debug("_check_opengl_availability: starting")
|
||||
missing = []
|
||||
|
||||
if sys.platform == "win32":
|
||||
angle_dir = comfy_angle.get_lib_dir()
|
||||
os.add_dll_directory(angle_dir)
|
||||
os.environ["PATH"] = angle_dir + os.pathsep + os.environ.get("PATH", "")
|
||||
# Check Python packages (using find_spec to avoid importing)
|
||||
logger.debug("_check_opengl_availability: checking for glfw package")
|
||||
if importlib.util.find_spec("glfw") is None:
|
||||
missing.append("glfw")
|
||||
|
||||
mode = 0 if sys.platform == "win32" else ctypes.RTLD_GLOBAL
|
||||
ctypes.CDLL(str(egl_path), mode=mode)
|
||||
ctypes.CDLL(str(gles_path), mode=mode)
|
||||
logger.debug("_check_opengl_availability: checking for OpenGL package")
|
||||
if importlib.util.find_spec("OpenGL") is None:
|
||||
missing.append("PyOpenGL")
|
||||
|
||||
if missing:
|
||||
raise RuntimeError(
|
||||
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
|
||||
)
|
||||
|
||||
# On Linux without display, check if headless backends are available
|
||||
logger.debug(f"_check_opengl_availability: platform={sys.platform}")
|
||||
if sys.platform.startswith("linux"):
|
||||
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
|
||||
logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
|
||||
if not has_display:
|
||||
# Check for EGL or OSMesa libraries
|
||||
logger.debug("_check_opengl_availability: checking for EGL library")
|
||||
has_egl = ctypes.util.find_library("EGL")
|
||||
logger.debug("_check_opengl_availability: checking for OSMesa library")
|
||||
has_osmesa = ctypes.util.find_library("OSMesa")
|
||||
|
||||
# Error disabled for CI as it fails this check
|
||||
# if not has_egl and not has_osmesa:
|
||||
# raise RuntimeError(
|
||||
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
|
||||
# "See error below for installation instructions."
|
||||
# )
|
||||
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
|
||||
|
||||
logger.debug("_check_opengl_availability: completed")
|
||||
|
||||
|
||||
# Pre-load ANGLE *before* any PyOpenGL import so that the EGL platform
|
||||
# plugin picks up ANGLE's libEGL / libGLESv2 instead of system libs.
|
||||
_preload_angle()
|
||||
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
|
||||
# Run early check at import time
|
||||
logger.debug("nodes_glsl: running _check_opengl_availability at import time")
|
||||
_check_opengl_availability()
|
||||
|
||||
# OpenGL modules - initialized lazily when context is created
|
||||
gl = None
|
||||
glfw = None
|
||||
EGL = None
|
||||
|
||||
|
||||
import OpenGL
|
||||
OpenGL.USE_ACCELERATE = False
|
||||
def _import_opengl():
|
||||
"""Import OpenGL module. Called after context is created."""
|
||||
global gl
|
||||
if gl is None:
|
||||
logger.debug("_import_opengl: importing OpenGL.GL")
|
||||
import OpenGL.GL as _gl
|
||||
gl = _gl
|
||||
logger.debug("_import_opengl: import completed")
|
||||
return gl
|
||||
|
||||
|
||||
def _patch_find_library():
|
||||
"""PyOpenGL's EGL platform looks for 'EGL' and 'GLESv2' by short name
|
||||
via ctypes.util.find_library, but ANGLE ships as 'libEGL' and
|
||||
'libGLESv2'. Patch find_library to return the full ANGLE paths so
|
||||
PyOpenGL loads the same libraries we pre-loaded."""
|
||||
if sys.platform == "linux":
|
||||
return
|
||||
import ctypes.util
|
||||
_orig = ctypes.util.find_library
|
||||
def _patched(name):
|
||||
if name == 'EGL':
|
||||
return comfy_angle.get_egl_path()
|
||||
if name == 'GLESv2':
|
||||
return comfy_angle.get_glesv2_path()
|
||||
return _orig(name)
|
||||
ctypes.util.find_library = _patched
|
||||
|
||||
|
||||
_patch_find_library()
|
||||
|
||||
from OpenGL import EGL
|
||||
from OpenGL import GLES3 as gl
|
||||
|
||||
class SizeModeInput(TypedDict):
|
||||
size_mode: str
|
||||
width: int
|
||||
@ -85,7 +102,7 @@ MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
||||
# (-1,-1)---(3,-1)
|
||||
#
|
||||
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
|
||||
VERTEX_SHADER = """#version 300 es
|
||||
VERTEX_SHADER = """#version 330 core
|
||||
out vec2 v_texCoord;
|
||||
void main() {
|
||||
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
|
||||
@ -109,99 +126,14 @@ void main() {
|
||||
"""
|
||||
|
||||
|
||||
|
||||
def _egl_attribs(*values):
|
||||
"""Build an EGL_NONE-terminated EGLint attribute array."""
|
||||
vals = list(values) + [EGL.EGL_NONE]
|
||||
return (ctypes.c_int32 * len(vals))(*vals)
|
||||
|
||||
|
||||
# EGL platform extension constants
|
||||
EGL_PLATFORM_ANGLE_ANGLE = 0x3202
|
||||
EGL_PLATFORM_ANGLE_TYPE_ANGLE = 0x3203
|
||||
EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE = 0x3450
|
||||
EGL_MESA_PLATFORM_SURFACELESS = 0x31DD
|
||||
|
||||
|
||||
_eglGetPlatformDisplayEXT = None
|
||||
|
||||
def _get_egl_platform_display_ext(platform, native_display, attribs):
|
||||
"""Call eglGetPlatformDisplayEXT via ctypes (extension, not in PyOpenGL)."""
|
||||
global _eglGetPlatformDisplayEXT
|
||||
if _eglGetPlatformDisplayEXT is None:
|
||||
from OpenGL import platform as _plat
|
||||
egl_lib = _plat.PLATFORM.EGL
|
||||
_get_proc = egl_lib.eglGetProcAddress
|
||||
_get_proc.restype = ctypes.c_void_p
|
||||
_get_proc.argtypes = [ctypes.c_char_p]
|
||||
ptr = _get_proc(b"eglGetPlatformDisplayEXT")
|
||||
if not ptr:
|
||||
return None
|
||||
func_type = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_uint32, ctypes.c_void_p, ctypes.c_void_p)
|
||||
_eglGetPlatformDisplayEXT = func_type(ptr)
|
||||
|
||||
raw = _eglGetPlatformDisplayEXT(platform, native_display, attribs)
|
||||
if not raw:
|
||||
return None
|
||||
return ctypes.cast(raw, EGL.EGLDisplay)
|
||||
|
||||
|
||||
def _get_egl_display():
|
||||
"""Get an EGL display, trying the default first then ANGLE's Vulkan
|
||||
platform for headless environments without a display server."""
|
||||
failures = []
|
||||
|
||||
# Try the default display first (works when X11/Wayland is available)
|
||||
display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY)
|
||||
if display:
|
||||
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
|
||||
try:
|
||||
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
|
||||
return display, major.value, minor.value
|
||||
except Exception as e:
|
||||
failures.append(f"default: {e}")
|
||||
|
||||
logger.info("Default EGL display unavailable, trying headless fallbacks")
|
||||
|
||||
# Headless fallback strategies, tried in order:
|
||||
headless_strategies = [
|
||||
("surfaceless", EGL_MESA_PLATFORM_SURFACELESS, None, None),
|
||||
("ANGLE Vulkan", EGL_PLATFORM_ANGLE_ANGLE, None,
|
||||
_egl_attribs(EGL_PLATFORM_ANGLE_TYPE_ANGLE, EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE)),
|
||||
]
|
||||
|
||||
for name, platform, native_display, attribs in headless_strategies:
|
||||
display = _get_egl_platform_display_ext(platform, native_display, attribs)
|
||||
if not display:
|
||||
failures.append(f"{name}: eglGetPlatformDisplayEXT returned no display")
|
||||
continue
|
||||
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
|
||||
try:
|
||||
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
|
||||
logger.info(f"Using EGL {name} platform (headless)")
|
||||
return display, major.value, minor.value
|
||||
failures.append(f"{name}: eglInitialize returned false")
|
||||
except Exception as e:
|
||||
failures.append(f"{name}: {e}")
|
||||
continue
|
||||
|
||||
details = "\n".join(f" - {f}" for f in failures)
|
||||
raise RuntimeError(
|
||||
"Failed to initialize EGL display.\n"
|
||||
"No display server and no headless EGL platform available.\n"
|
||||
f"Tried:\n{details}\n"
|
||||
"Ensure GPU drivers are installed or set DISPLAY for a virtual framebuffer."
|
||||
)
|
||||
|
||||
|
||||
def _gl_str(name):
|
||||
"""Get an OpenGL string parameter."""
|
||||
v = gl.glGetString(name)
|
||||
if not v:
|
||||
return "Unknown"
|
||||
if isinstance(v, bytes):
|
||||
return v.decode(errors="replace")
|
||||
return ctypes.string_at(v).decode(errors="replace")
|
||||
def _convert_es_to_desktop(source: str) -> str:
|
||||
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
|
||||
# Remove any existing #version directive
|
||||
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
|
||||
# Remove precision qualifiers (not needed in desktop GLSL)
|
||||
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
|
||||
# Prepend desktop GLSL version
|
||||
return "#version 330 core\n" + source
|
||||
|
||||
|
||||
def _detect_output_count(source: str) -> int:
|
||||
@ -227,8 +159,163 @@ def _detect_pass_count(source: str) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
def _init_glfw():
|
||||
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
|
||||
logger.debug("_init_glfw: starting")
|
||||
# On macOS, glfw.init() must be called from main thread or it hangs forever
|
||||
if sys.platform == "darwin":
|
||||
logger.debug("_init_glfw: skipping on macOS")
|
||||
raise RuntimeError("GLFW backend not supported on macOS")
|
||||
|
||||
logger.debug("_init_glfw: importing glfw module")
|
||||
import glfw as _glfw
|
||||
|
||||
logger.debug("_init_glfw: calling glfw.init()")
|
||||
if not _glfw.init():
|
||||
raise RuntimeError("glfw.init() failed")
|
||||
|
||||
try:
|
||||
logger.debug("_init_glfw: setting window hints")
|
||||
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
|
||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
|
||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
|
||||
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
|
||||
|
||||
logger.debug("_init_glfw: calling create_window()")
|
||||
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
|
||||
if not window:
|
||||
raise RuntimeError("glfw.create_window() failed")
|
||||
|
||||
logger.debug("_init_glfw: calling make_context_current()")
|
||||
_glfw.make_context_current(window)
|
||||
logger.debug("_init_glfw: completed successfully")
|
||||
return window, _glfw
|
||||
except Exception:
|
||||
logger.debug("_init_glfw: failed, terminating glfw")
|
||||
_glfw.terminate()
|
||||
raise
|
||||
|
||||
|
||||
def _init_egl():
|
||||
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
|
||||
logger.debug("_init_egl: starting")
|
||||
from OpenGL import EGL as _EGL
|
||||
from OpenGL.EGL import (
|
||||
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
|
||||
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
|
||||
eglTerminate, eglDestroyContext, eglDestroySurface,
|
||||
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
|
||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
||||
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
|
||||
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
|
||||
)
|
||||
logger.debug("_init_egl: imports completed")
|
||||
|
||||
display = None
|
||||
context = None
|
||||
surface = None
|
||||
|
||||
try:
|
||||
logger.debug("_init_egl: calling eglGetDisplay()")
|
||||
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
|
||||
if display == _EGL.EGL_NO_DISPLAY:
|
||||
raise RuntimeError("eglGetDisplay() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglInitialize()")
|
||||
major, minor = _EGL.EGLint(), _EGL.EGLint()
|
||||
if not eglInitialize(display, major, minor):
|
||||
display = None # Not initialized, don't terminate
|
||||
raise RuntimeError("eglInitialize() failed")
|
||||
logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
|
||||
|
||||
config_attribs = [
|
||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
|
||||
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
||||
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
|
||||
EGL_DEPTH_SIZE, 0, EGL_NONE
|
||||
]
|
||||
configs = (_EGL.EGLConfig * 1)()
|
||||
num_configs = _EGL.EGLint()
|
||||
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
|
||||
raise RuntimeError("eglChooseConfig() failed")
|
||||
config = configs[0]
|
||||
logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
|
||||
|
||||
if not eglBindAPI(EGL_OPENGL_API):
|
||||
raise RuntimeError("eglBindAPI() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglCreateContext()")
|
||||
context_attribs = [
|
||||
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
|
||||
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
|
||||
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
|
||||
EGL_NONE
|
||||
]
|
||||
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
|
||||
if context == EGL_NO_CONTEXT:
|
||||
raise RuntimeError("eglCreateContext() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglCreatePbufferSurface()")
|
||||
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
|
||||
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
|
||||
if surface == _EGL.EGL_NO_SURFACE:
|
||||
raise RuntimeError("eglCreatePbufferSurface() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglMakeCurrent()")
|
||||
if not eglMakeCurrent(display, surface, surface, context):
|
||||
raise RuntimeError("eglMakeCurrent() failed")
|
||||
|
||||
logger.debug("_init_egl: completed successfully")
|
||||
return display, context, surface, _EGL
|
||||
|
||||
except Exception:
|
||||
logger.debug("_init_egl: failed, cleaning up")
|
||||
# Clean up any resources on failure
|
||||
if surface is not None:
|
||||
eglDestroySurface(display, surface)
|
||||
if context is not None:
|
||||
eglDestroyContext(display, context)
|
||||
if display is not None:
|
||||
eglTerminate(display)
|
||||
raise
|
||||
|
||||
|
||||
def _init_osmesa():
|
||||
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
|
||||
import ctypes
|
||||
|
||||
logger.debug("_init_osmesa: starting")
|
||||
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
|
||||
|
||||
logger.debug("_init_osmesa: importing OpenGL.osmesa")
|
||||
from OpenGL import GL as _gl
|
||||
from OpenGL.osmesa import (
|
||||
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
|
||||
OSMESA_RGBA,
|
||||
)
|
||||
logger.debug("_init_osmesa: imports completed")
|
||||
|
||||
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
|
||||
if not ctx:
|
||||
raise RuntimeError("OSMesaCreateContextExt() failed")
|
||||
|
||||
width, height = 64, 64
|
||||
buffer = (ctypes.c_ubyte * (width * height * 4))()
|
||||
|
||||
logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
|
||||
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
|
||||
OSMesaDestroyContext(ctx)
|
||||
raise RuntimeError("OSMesaMakeCurrent() failed")
|
||||
|
||||
logger.debug("_init_osmesa: completed successfully")
|
||||
return ctx, buffer
|
||||
|
||||
|
||||
class GLContext:
|
||||
"""Manages an OpenGL ES 3.0 context via EGL/ANGLE (singleton)."""
|
||||
"""Manages OpenGL context and resources for shader execution.
|
||||
|
||||
Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software).
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
@ -240,105 +327,131 @@ class GLContext:
|
||||
|
||||
def __init__(self):
|
||||
if GLContext._initialized:
|
||||
logger.debug("GLContext.__init__: already initialized, skipping")
|
||||
return
|
||||
|
||||
logger.debug("GLContext.__init__: starting initialization")
|
||||
|
||||
global glfw, EGL
|
||||
|
||||
import time
|
||||
start = time.perf_counter()
|
||||
|
||||
self._display = None
|
||||
self._surface = None
|
||||
self._context = None
|
||||
self._backend = None
|
||||
self._window = None
|
||||
self._egl_display = None
|
||||
self._egl_context = None
|
||||
self._egl_surface = None
|
||||
self._osmesa_ctx = None
|
||||
self._osmesa_buffer = None
|
||||
self._vao = None
|
||||
|
||||
# Try backends in order: GLFW → EGL → OSMesa
|
||||
errors = []
|
||||
|
||||
logger.debug("GLContext.__init__: trying GLFW backend")
|
||||
try:
|
||||
self._display, self._egl_major, self._egl_minor = _get_egl_display()
|
||||
self._window, glfw = _init_glfw()
|
||||
self._backend = "glfw"
|
||||
logger.debug("GLContext.__init__: GLFW backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
|
||||
errors.append(("GLFW", e))
|
||||
|
||||
if not EGL.eglBindAPI(EGL.EGL_OPENGL_ES_API):
|
||||
raise RuntimeError("eglBindAPI(EGL_OPENGL_ES_API) failed")
|
||||
if self._backend is None:
|
||||
logger.debug("GLContext.__init__: trying EGL backend")
|
||||
try:
|
||||
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
|
||||
self._backend = "egl"
|
||||
logger.debug("GLContext.__init__: EGL backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
|
||||
errors.append(("EGL", e))
|
||||
|
||||
config = EGL.EGLConfig()
|
||||
n_configs = ctypes.c_int32(0)
|
||||
if not EGL.eglChooseConfig(
|
||||
self._display,
|
||||
_egl_attribs(
|
||||
EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_ES3_BIT,
|
||||
EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT,
|
||||
EGL.EGL_RED_SIZE, 8, EGL.EGL_GREEN_SIZE, 8,
|
||||
EGL.EGL_BLUE_SIZE, 8, EGL.EGL_ALPHA_SIZE, 8,
|
||||
),
|
||||
ctypes.byref(config), 1, ctypes.byref(n_configs),
|
||||
) or n_configs.value == 0:
|
||||
raise RuntimeError("eglChooseConfig() failed")
|
||||
if self._backend is None:
|
||||
logger.debug("GLContext.__init__: trying OSMesa backend")
|
||||
try:
|
||||
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
|
||||
self._backend = "osmesa"
|
||||
logger.debug("GLContext.__init__: OSMesa backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
|
||||
errors.append(("OSMesa", e))
|
||||
|
||||
self._surface = EGL.eglCreatePbufferSurface(
|
||||
self._display, config,
|
||||
_egl_attribs(EGL.EGL_WIDTH, 64, EGL.EGL_HEIGHT, 64),
|
||||
if self._backend is None:
|
||||
if sys.platform == "win32":
|
||||
platform_help = (
|
||||
"Windows: Ensure GPU drivers are installed and display is available.\n"
|
||||
" CPU-only/headless mode is not supported on Windows."
|
||||
)
|
||||
elif sys.platform == "darwin":
|
||||
platform_help = (
|
||||
"macOS: GLFW is not supported.\n"
|
||||
" Install OSMesa via Homebrew: brew install mesa\n"
|
||||
" Then: pip install PyOpenGL PyOpenGL-accelerate"
|
||||
)
|
||||
else:
|
||||
platform_help = (
|
||||
"Linux: Install one of these backends:\n"
|
||||
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
|
||||
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
|
||||
" Headless (CPU): sudo apt install libosmesa6"
|
||||
)
|
||||
|
||||
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
|
||||
raise RuntimeError(
|
||||
f"Failed to create OpenGL context.\n\n"
|
||||
f"Backend errors:\n{error_details}\n\n"
|
||||
f"{platform_help}"
|
||||
)
|
||||
if not self._surface:
|
||||
raise RuntimeError("eglCreatePbufferSurface() failed")
|
||||
|
||||
self._context = EGL.eglCreateContext(
|
||||
self._display, config, EGL.EGL_NO_CONTEXT,
|
||||
_egl_attribs(EGL.EGL_CONTEXT_CLIENT_VERSION, 3),
|
||||
)
|
||||
if not self._context:
|
||||
raise RuntimeError("eglCreateContext() failed")
|
||||
# Now import OpenGL.GL (after context is current)
|
||||
logger.debug("GLContext.__init__: importing OpenGL.GL")
|
||||
_import_opengl()
|
||||
|
||||
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
|
||||
raise RuntimeError("eglMakeCurrent() failed")
|
||||
|
||||
self._vao = gl.glGenVertexArrays(1)
|
||||
gl.glBindVertexArray(self._vao)
|
||||
|
||||
except Exception:
|
||||
self._cleanup()
|
||||
raise
|
||||
# Create VAO (required for core profile, but OSMesa may use compat profile)
|
||||
logger.debug("GLContext.__init__: creating VAO")
|
||||
try:
|
||||
vao = gl.glGenVertexArrays(1)
|
||||
gl.glBindVertexArray(vao)
|
||||
self._vao = vao # Only store after successful bind
|
||||
logger.debug("GLContext.__init__: VAO created successfully")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
|
||||
# OSMesa with older Mesa may not support VAOs
|
||||
# Clean up if we created but couldn't bind
|
||||
if vao:
|
||||
try:
|
||||
gl.glDeleteVertexArrays(1, [vao])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
elapsed = (time.perf_counter() - start) * 1000
|
||||
|
||||
renderer = _gl_str(gl.GL_RENDERER)
|
||||
vendor = _gl_str(gl.GL_VENDOR)
|
||||
version = _gl_str(gl.GL_VERSION)
|
||||
# Log device info
|
||||
renderer = gl.glGetString(gl.GL_RENDERER)
|
||||
vendor = gl.glGetString(gl.GL_VENDOR)
|
||||
version = gl.glGetString(gl.GL_VERSION)
|
||||
renderer = renderer.decode() if renderer else "Unknown"
|
||||
vendor = vendor.decode() if vendor else "Unknown"
|
||||
version = version.decode() if version else "Unknown"
|
||||
|
||||
GLContext._initialized = True
|
||||
logger.info(f"GLSL context initialized in {elapsed:.1f}ms - EGL {self._egl_major}.{self._egl_minor}, {renderer} ({vendor}), GL {version}")
|
||||
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
|
||||
|
||||
def make_current(self):
|
||||
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
|
||||
err = EGL.eglGetError()
|
||||
raise RuntimeError(f"eglMakeCurrent() failed (EGL error: 0x{err:04X})")
|
||||
if self._backend == "glfw":
|
||||
glfw.make_context_current(self._window)
|
||||
elif self._backend == "egl":
|
||||
from OpenGL.EGL import eglMakeCurrent
|
||||
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
|
||||
elif self._backend == "osmesa":
|
||||
from OpenGL.osmesa import OSMesaMakeCurrent
|
||||
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
|
||||
|
||||
if self._vao is not None:
|
||||
gl.glBindVertexArray(self._vao)
|
||||
|
||||
def _cleanup(self):
|
||||
if not self._display:
|
||||
return
|
||||
try:
|
||||
if self._vao is not None:
|
||||
gl.glDeleteVertexArrays(1, [self._vao])
|
||||
self._vao = None
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
EGL.eglMakeCurrent(self._display, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self._context:
|
||||
EGL.eglDestroyContext(self._display, self._context)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self._surface:
|
||||
EGL.eglDestroySurface(self._display, self._surface)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
EGL.eglTerminate(self._display)
|
||||
except Exception:
|
||||
pass
|
||||
self._display = None
|
||||
|
||||
|
||||
def _compile_shader(source: str, shader_type: int) -> int:
|
||||
"""Compile a shader and return its ID."""
|
||||
@ -346,10 +459,8 @@ def _compile_shader(source: str, shader_type: int) -> int:
|
||||
gl.glShaderSource(shader, source)
|
||||
gl.glCompileShader(shader)
|
||||
|
||||
if not gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS):
|
||||
error = gl.glGetShaderInfoLog(shader)
|
||||
if isinstance(error, bytes):
|
||||
error = error.decode(errors="replace")
|
||||
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
|
||||
error = gl.glGetShaderInfoLog(shader).decode()
|
||||
gl.glDeleteShader(shader)
|
||||
raise RuntimeError(f"Shader compilation failed:\n{error}")
|
||||
|
||||
@ -373,10 +484,8 @@ def _create_program(vertex_source: str, fragment_source: str) -> int:
|
||||
gl.glDeleteShader(vertex_shader)
|
||||
gl.glDeleteShader(fragment_shader)
|
||||
|
||||
if not gl.glGetProgramiv(program, gl.GL_LINK_STATUS):
|
||||
error = gl.glGetProgramInfoLog(program)
|
||||
if isinstance(error, bytes):
|
||||
error = error.decode(errors="replace")
|
||||
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
|
||||
error = gl.glGetProgramInfoLog(program).decode()
|
||||
gl.glDeleteProgram(program)
|
||||
raise RuntimeError(f"Program linking failed:\n{error}")
|
||||
|
||||
@ -421,6 +530,9 @@ def _render_shader_batch(
|
||||
ctx = GLContext()
|
||||
ctx.make_current()
|
||||
|
||||
# Convert from GLSL ES to desktop GLSL 330
|
||||
fragment_source = _convert_es_to_desktop(fragment_code)
|
||||
|
||||
# Detect how many outputs the shader actually uses
|
||||
num_outputs = _detect_output_count(fragment_code)
|
||||
|
||||
@ -446,9 +558,9 @@ def _render_shader_batch(
|
||||
try:
|
||||
# Compile shaders (once for all batches)
|
||||
try:
|
||||
program = _create_program(VERTEX_SHADER, fragment_code)
|
||||
program = _create_program(VERTEX_SHADER, fragment_source)
|
||||
except RuntimeError:
|
||||
logger.error(f"Fragment shader:\n{fragment_code}")
|
||||
logger.error(f"Fragment shader:\n{fragment_source}")
|
||||
raise
|
||||
|
||||
gl.glUseProgram(program)
|
||||
@ -611,13 +723,13 @@ def _render_shader_batch(
|
||||
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
|
||||
|
||||
# Read back outputs for this batch
|
||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
|
||||
# (glGetTexImage is synchronous, implicitly waits for rendering)
|
||||
batch_outputs = []
|
||||
for i in range(num_outputs):
|
||||
gl.glReadBuffer(gl.GL_COLOR_ATTACHMENT0 + i)
|
||||
buf = np.empty((height, width, 4), dtype=np.float32)
|
||||
gl.glReadPixels(0, 0, width, height, gl.GL_RGBA, gl.GL_FLOAT, buf)
|
||||
batch_outputs.append(buf[::-1, :, :].copy())
|
||||
for tex in output_textures:
|
||||
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
||||
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
|
||||
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
|
||||
batch_outputs.append(img[::-1, :, :].copy())
|
||||
|
||||
# Pad with black images for unused outputs
|
||||
black_img = np.zeros((height, width, 4), dtype=np.float32)
|
||||
@ -638,18 +750,18 @@ def _render_shader_batch(
|
||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
|
||||
gl.glUseProgram(0)
|
||||
|
||||
if input_textures:
|
||||
gl.glDeleteTextures(len(input_textures), input_textures)
|
||||
if curve_textures:
|
||||
gl.glDeleteTextures(len(curve_textures), curve_textures)
|
||||
if output_textures:
|
||||
gl.glDeleteTextures(len(output_textures), output_textures)
|
||||
if ping_pong_textures:
|
||||
gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures)
|
||||
for tex in input_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in curve_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in output_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in ping_pong_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
if fbo is not None:
|
||||
gl.glDeleteFramebuffers(1, [fbo])
|
||||
if ping_pong_fbos:
|
||||
gl.glDeleteFramebuffers(len(ping_pong_fbos), ping_pong_fbos)
|
||||
for pp_fbo in ping_pong_fbos:
|
||||
gl.glDeleteFramebuffers(1, [pp_fbo])
|
||||
if program is not None:
|
||||
gl.glDeleteProgram(program)
|
||||
|
||||
|
||||
55
execution.py
55
execution.py
@ -1113,6 +1113,32 @@ def full_type_name(klass):
|
||||
return klass.__qualname__
|
||||
return module + '.' + klass.__qualname__
|
||||
|
||||
def node_not_executable_reason(class_def, class_type):
|
||||
"""Return a human-readable reason the node cannot be executed, or None if it's fine.
|
||||
|
||||
Catches a node whose declared entry point doesn't resolve to a real method
|
||||
(e.g. a V1 ``FUNCTION = "invert"`` where the method is misspelled, or a V3 node
|
||||
missing its ``execute`` override). Running this during validation surfaces the
|
||||
problem before execution starts, instead of after upstream nodes have run.
|
||||
|
||||
Only the class is inspected; the node is never instantiated here, so a node's
|
||||
``__init__`` side effects cannot run (or fail) during validation.
|
||||
"""
|
||||
try:
|
||||
if issubclass(class_def, _ComfyNodeInternal):
|
||||
# V3: validates that execute()/define_schema() overrides exist.
|
||||
class_def.VALIDATE_CLASS()
|
||||
return None
|
||||
# V1: FUNCTION names the method to call; it must exist on the class.
|
||||
function_name = getattr(class_def, "FUNCTION", None)
|
||||
if function_name is None:
|
||||
return f"'{class_type}' does not define FUNCTION"
|
||||
if not callable(getattr(class_def, function_name, None)):
|
||||
return f"'{class_type}' has no method '{function_name}' (declared in FUNCTION)"
|
||||
return None
|
||||
except Exception as ex:
|
||||
return str(ex)
|
||||
|
||||
async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]):
|
||||
outputs = set()
|
||||
for x in prompt:
|
||||
@ -1148,6 +1174,35 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[
|
||||
}
|
||||
return (False, error, [], {})
|
||||
|
||||
# Make sure the node is actually executable (its FUNCTION/execute entry
|
||||
# point resolves to a real method) before we touch any schema-derived
|
||||
# attributes below or start execution. Catches code typos up front and
|
||||
# attributes the error to the offending node.
|
||||
not_executable = node_not_executable_reason(class_, class_type)
|
||||
if not_executable is not None:
|
||||
node_title = prompt[x].get('_meta', {}).get('title', class_type)
|
||||
error = {
|
||||
"type": "invalid_node_definition",
|
||||
"message": "Node is not executable",
|
||||
"details": f"{not_executable} (Node ID '#{x}')",
|
||||
"extra_info": {
|
||||
"node_id": x,
|
||||
"class_type": class_type,
|
||||
"node_title": node_title,
|
||||
}
|
||||
}
|
||||
node_errors = {x: {
|
||||
"errors": [{
|
||||
"type": "invalid_node_definition",
|
||||
"message": "Node is not executable",
|
||||
"details": not_executable,
|
||||
"extra_info": {},
|
||||
}],
|
||||
"dependent_outputs": [],
|
||||
"class_type": class_type,
|
||||
}}
|
||||
return (False, error, [], node_errors)
|
||||
|
||||
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
|
||||
if partial_execution_list is None or x in partial_execution_list:
|
||||
outputs.add(x)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
comfyui-frontend-package==1.45.19
|
||||
comfyui-workflow-templates==0.10.7
|
||||
comfyui-embedded-docs==0.5.6
|
||||
comfyui-embedded-docs==0.5.5
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
@ -22,7 +22,7 @@ alembic
|
||||
SQLAlchemy>=2.0.0
|
||||
filelock
|
||||
av>=16.0.0
|
||||
comfy-kitchen==0.2.14
|
||||
comfy-kitchen==0.2.12
|
||||
comfy-aimdo==0.4.10
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
@ -33,5 +33,5 @@ kornia>=0.7.1
|
||||
spandrel
|
||||
pydantic~=2.0
|
||||
pydantic-settings~=2.0
|
||||
PyOpenGL>=3.1.8
|
||||
comfy-angle
|
||||
PyOpenGL
|
||||
glfw
|
||||
|
||||
137
tests-unit/execution_test/validate_node_executable_test.py
Normal file
137
tests-unit/execution_test/validate_node_executable_test.py
Normal file
@ -0,0 +1,137 @@
|
||||
"""Tests for pre-execution validation that a node is actually executable.
|
||||
|
||||
validate_prompt rejects a node whose declared entry point does not resolve to a
|
||||
real method (a V1 FUNCTION typo, or a V3 node missing its execute override) before
|
||||
any node runs, attributing the error to the offending node.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
import nodes
|
||||
from comfy_api.latest import io
|
||||
from execution import node_not_executable_reason, validate_prompt
|
||||
|
||||
|
||||
class _GoodV1Node:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "run"
|
||||
OUTPUT_NODE = True
|
||||
CATEGORY = "Test"
|
||||
|
||||
def run(self):
|
||||
return (None,)
|
||||
|
||||
|
||||
class _TypoV1Node:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "invert" # method below is misspelled
|
||||
OUTPUT_NODE = True
|
||||
CATEGORY = "Test"
|
||||
|
||||
def invvert(self):
|
||||
return (None,)
|
||||
|
||||
|
||||
class _SideEffectInitV1Node:
|
||||
"""Valid class-level method, but a constructor that must never run in validation."""
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "run"
|
||||
OUTPUT_NODE = True
|
||||
CATEGORY = "Test"
|
||||
|
||||
def __init__(self):
|
||||
raise RuntimeError("__init__ must not run during validation")
|
||||
|
||||
def run(self):
|
||||
return (None,)
|
||||
|
||||
|
||||
def _v3_schema(node_id):
|
||||
return io.Schema(
|
||||
node_id=node_id,
|
||||
display_name=node_id,
|
||||
category="Test",
|
||||
inputs=[],
|
||||
outputs=[io.Image.Output()],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
|
||||
class _GoodV3Node(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return _v3_schema("GoodV3Node")
|
||||
|
||||
@classmethod
|
||||
def execute(cls):
|
||||
return io.NodeOutput(None)
|
||||
|
||||
|
||||
class _TypoV3Node(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return _v3_schema("TypoV3Node")
|
||||
|
||||
@classmethod
|
||||
def exicute(cls): # typo: should be "execute"
|
||||
return io.NodeOutput(None)
|
||||
|
||||
|
||||
def _register(class_type, class_def):
|
||||
nodes.NODE_CLASS_MAPPINGS[class_type] = class_def
|
||||
|
||||
|
||||
def _validate(class_type):
|
||||
prompt = {"1": {"class_type": class_type, "inputs": {}}}
|
||||
return asyncio.run(validate_prompt("pid", prompt, None))
|
||||
|
||||
|
||||
def test_good_node_passes():
|
||||
_register("GoodV1Node", _GoodV1Node)
|
||||
assert node_not_executable_reason(_GoodV1Node, "GoodV1Node") is None
|
||||
valid, _, _, _ = _validate("GoodV1Node")
|
||||
assert valid is True
|
||||
|
||||
|
||||
def test_typo_node_rejected_with_node_error():
|
||||
_register("TypoV1Node", _TypoV1Node)
|
||||
valid, error, _, node_errors = _validate("TypoV1Node")
|
||||
assert valid is False
|
||||
assert error["type"] == "invalid_node_definition"
|
||||
assert node_errors["1"]["class_type"] == "TypoV1Node"
|
||||
assert node_errors["1"]["errors"][0]["type"] == "invalid_node_definition"
|
||||
assert "invert" in node_errors["1"]["errors"][0]["details"]
|
||||
|
||||
|
||||
def test_validation_does_not_instantiate_node():
|
||||
"""A valid node is not constructed during validation, so __init__ never runs."""
|
||||
_register("SideEffectInitV1Node", _SideEffectInitV1Node)
|
||||
assert node_not_executable_reason(_SideEffectInitV1Node, "SideEffectInitV1Node") is None
|
||||
valid, _, _, _ = _validate("SideEffectInitV1Node")
|
||||
assert valid is True
|
||||
|
||||
|
||||
def test_good_v3_node_passes():
|
||||
_register("GoodV3Node", _GoodV3Node)
|
||||
assert node_not_executable_reason(_GoodV3Node, "GoodV3Node") is None
|
||||
valid, _, _, _ = _validate("GoodV3Node")
|
||||
assert valid is True
|
||||
|
||||
|
||||
def test_typo_v3_node_rejected_with_node_error():
|
||||
_register("TypoV3Node", _TypoV3Node)
|
||||
valid, error, _, node_errors = _validate("TypoV3Node")
|
||||
assert valid is False
|
||||
assert error["type"] == "invalid_node_definition"
|
||||
assert node_errors["1"]["errors"][0]["type"] == "invalid_node_definition"
|
||||
Reference in New Issue
Block a user