mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-01 18:36:33 +08:00
Compare commits
19 Commits
fix/valida
...
matt/jobs-
| Author | SHA1 | Date | |
|---|---|---|---|
| b656217514 | |||
| 44fb02e510 | |||
| 50e5270b86 | |||
| bb131be9e8 | |||
| 6fca64780c | |||
| 6e11828d10 | |||
| b70944e710 | |||
| 1c59659a2f | |||
| d395813bcd | |||
| 8fe0243d97 | |||
| ba3f697dbb | |||
| 510ed5c384 | |||
| 7851410511 | |||
| a58473fd9b | |||
| 79c555ce6b | |||
| f19735759e | |||
| a95e461916 | |||
| 603d891eaf | |||
| 470ac36a0a |
38
.github/workflows/ci-cursor-review.yml
vendored
Normal file
38
.github/workflows/ci-cursor-review.yml
vendored
Normal file
@ -0,0 +1,38 @@
|
||||
name: CI - Cursor Review
|
||||
|
||||
# Thin caller for the shared reusable cursor-review workflow in
|
||||
# Comfy-Org/github-workflows. The review logic (panel matrix, judge
|
||||
# consolidation, prompts, extract/post/notify scripts) lives there as the
|
||||
# single source of truth, so this repo only carries the repo-specific diff
|
||||
# excludes.
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [labeled, unlabeled]
|
||||
|
||||
concurrency:
|
||||
group: cursor-review-pr-${{ github.event.pull_request.number }}-${{ github.event.label.name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
cursor-review:
|
||||
if: github.event.label.name == 'cursor-review'
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
# SHA-pinned per zizmor `unpinned-uses: hash-pin`. Bump this SHA to pick up
|
||||
# upstream changes; keep `workflows_ref` matching so prompts/scripts load
|
||||
# from the same commit as the workflow definition.
|
||||
uses: Comfy-Org/github-workflows/.github/workflows/cursor-review.yml@047ca48febe3a6647608ed2e0c4331b491cb9d6a # github-workflows#9
|
||||
with:
|
||||
workflows_ref: 047ca48febe3a6647608ed2e0c4331b491cb9d6a
|
||||
diff_excludes: >-
|
||||
:!**/.claude/**
|
||||
:!**/dist/**
|
||||
:!**/vendor/**
|
||||
:!**/*.generated.*
|
||||
:!**/*.min.js
|
||||
:!**/*.min.css
|
||||
secrets:
|
||||
CURSOR_API_KEY: ${{ secrets.CURSOR_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
78
AGENTS.md
Normal file
78
AGENTS.md
Normal file
@ -0,0 +1,78 @@
|
||||
## Engineering Style
|
||||
|
||||
- Keep changes small and direct. Most fixes should touch the narrowest code path
|
||||
that explains the bug, performance issue, dtype issue, model-format issue, or
|
||||
user-facing behavior.
|
||||
- Change the least amount of files possible. A change that touches many files is
|
||||
more likely to be a bad change than a good one unless the broader scope is
|
||||
directly required.
|
||||
- Prefer practical fixes over broad architecture work. Add abstractions only
|
||||
when they remove real repeated logic or match an existing ComfyUI pattern.
|
||||
- Delete obsolete code aggressively when newer infrastructure makes it useless.
|
||||
Remove dead fallbacks, migration paths, unused options, debug prints, and
|
||||
compatibility branches that are no longer needed.
|
||||
- Revert or disable problematic behavior quickly when it breaks users. It is
|
||||
better to remove a broken feature path than keep a complicated partial fix.
|
||||
- Preserve existing APIs, node names, model-loading behavior, file layout, and
|
||||
workflow compatibility unless the change is explicitly about replacing them.
|
||||
- Code must look hand-written for this repository. Changes that read like
|
||||
generic AI-generated code will be rejected automatically: unnecessary helper
|
||||
layers, vague names, boilerplate comments, defensive branches without a real
|
||||
failure mode, broad rewrites, or code that ignores the local style.
|
||||
|
||||
## Python Style
|
||||
|
||||
- Keep imports at module scope. Avoid inline imports unless they are already part
|
||||
of an established optional-backend probe or are needed to avoid an import
|
||||
cycle.
|
||||
- Do not add unnecessary `try`/`except` blocks. Use them for optional dependency,
|
||||
platform, or backend capability detection only when the program has a useful
|
||||
fallback. Prefer specific exception types when changing new code.
|
||||
- Let unsupported model formats, invalid quantization metadata, and bad states
|
||||
fail with clear errors instead of silently producing lower quality output.
|
||||
- Match the existing local style in the file you edit. This codebase tolerates
|
||||
long lines, simple helper functions, module-level state, and direct tensor
|
||||
operations when they make the code easier to follow.
|
||||
- Keep comments sparse and useful. Short TODOs are fine when they name the
|
||||
concrete missing follow-up.
|
||||
|
||||
## Model, Device, and Memory Behavior
|
||||
|
||||
- Treat dtype, device placement, VRAM usage, and offloading behavior as core
|
||||
correctness concerns. Check CPU, CUDA, ROCm, MPS, DirectML, XPU, NPU, and low
|
||||
VRAM implications when touching shared execution or loading code.
|
||||
- Prefer native ComfyUI formats and existing quantization/offload helpers over
|
||||
adding parallel code paths. Use `comfy.quant_ops`, `comfy.model_management`,
|
||||
`comfy.memory_management`, `comfy.pinned_memory`, `comfy_aimdo`, and
|
||||
`comfy-kitchen` helpers where they already solve the problem.
|
||||
- Avoid unnecessary casts and transfers. Preserve the intended compute dtype,
|
||||
storage dtype, bias dtype, and original tensor shape metadata.
|
||||
- When optimizing, favor small measurable changes: fewer allocations, fewer
|
||||
device transfers, less peak memory, better batching, or use of a faster
|
||||
existing backend op.
|
||||
|
||||
## Nodes and User-Facing Behavior
|
||||
|
||||
- Follow existing node conventions: `INPUT_TYPES`, `RETURN_TYPES`, `FUNCTION`,
|
||||
`CATEGORY`, and registration through the local mapping used by that file.
|
||||
- Keep node changes backward compatible by default. Add inputs with sensible
|
||||
defaults and avoid changing output types unless the request requires it.
|
||||
- The official mascot of ComfyUI is a very cute anime girl with massive fennec
|
||||
ears, a big fluffy tail, long blonde wavy hair, and blue eyes. Feel free to
|
||||
use her in ComfyUI materials, UI text, examples, tests, generated assets, or
|
||||
comments, but do not disrespect her.
|
||||
- Warning and info messages should be short and actionable. Remove noisy or
|
||||
misleading messages rather than adding more logging.
|
||||
- Documentation and README edits should be concise, factual, and tied to the
|
||||
changed behavior.
|
||||
|
||||
## Commit and Review Habits
|
||||
|
||||
- If asked to write commit messages, use short direct subjects like the existing
|
||||
history: `Fix ...`, `Add ...`, `Support ...`, `Remove ...`, `Update ...`,
|
||||
`Make ...`, `Use ...`, `Disable ...`, `Bump ...`, or `Revert ...`.
|
||||
- Prefer one coherent behavioral change per commit. Dependency pins, tests, and
|
||||
the code that needs them may be in the same commit when they are inseparable.
|
||||
- In reviews, prioritize real user impact: crashes, wrong dtype/device behavior,
|
||||
memory regressions, broken model loading, workflow incompatibility, and noisy
|
||||
or misleading user-facing output.
|
||||
@ -240,6 +240,7 @@ database_default_path = os.path.abspath(
|
||||
)
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
|
||||
parser.add_argument("--enable-asset-hashing", action="store_true", help="Compute blake3 content hashes when scanning assets. Hashing enables future asset-portability features (deduplication, cross-machine model resolution) but adds startup cost and per-output cost on large models directories. Off by default; enable to opt in.")
|
||||
parser.add_argument("--feature-flag", type=str, action='append', default=[], metavar="KEY[=VALUE]", help="Set a server feature flag. Use KEY=VALUE to set an explicit value, or bare KEY to set it to true. Can be specified multiple times. Boolean values (true/false) and numbers are auto-converted. Examples: --feature-flag show_signin_button=true or --feature-flag show_signin_button")
|
||||
parser.add_argument("--list-feature-flags", action="store_true", help="Print the registry of known CLI-settable feature flags as JSON and exit.")
|
||||
|
||||
|
||||
@ -256,7 +256,7 @@ def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, w
|
||||
if (want_requant and len(fns) == 0 or update_weight):
|
||||
seed = comfy.utils.string_to_seed(s.seed_key)
|
||||
if isinstance(orig, QuantizedTensor):
|
||||
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
||||
y = orig.requantize_from_float(x, scale="recalculate", stochastic_rounding=seed)
|
||||
else:
|
||||
y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
|
||||
if want_requant and len(fns) == 0:
|
||||
@ -1216,7 +1216,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
bias_dtype=input.dtype,
|
||||
offloadable=True,
|
||||
compute_dtype=compute_dtype,
|
||||
want_requant=want_requant,
|
||||
want_requant=True,
|
||||
)
|
||||
weight = weight.to(dtype=input.dtype)
|
||||
else:
|
||||
@ -1306,8 +1306,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||
if getattr(self, 'layout_type', None) is not None:
|
||||
# dtype is now implicit in the layout class
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
|
||||
weight = self.weight.requantize_from_float(weight, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
|
||||
else:
|
||||
weight = weight.to(self.weight.dtype)
|
||||
if return_weight:
|
||||
|
||||
@ -121,6 +121,7 @@ class GeminiGenerationConfig(BaseModel):
|
||||
topK: int | None = Field(None, ge=1)
|
||||
topP: float | None = Field(None, ge=0.0, le=1.0)
|
||||
thinkingConfig: GeminiThinkingConfig | None = Field(None)
|
||||
responseModalities: list[str] | None = Field(None)
|
||||
|
||||
|
||||
class GeminiImageOutputOptions(BaseModel):
|
||||
|
||||
@ -13,7 +13,7 @@ import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import folder_paths
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl, Types
|
||||
from comfy_api_nodes.apis.gemini import (
|
||||
GeminiContent,
|
||||
GeminiFileData,
|
||||
@ -37,6 +37,7 @@ from comfy_api_nodes.util import (
|
||||
audio_to_base64_string,
|
||||
bytesio_to_image_tensor,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
get_number_of_images,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
@ -45,6 +46,7 @@ from comfy_api_nodes.util import (
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_string,
|
||||
validate_video_duration,
|
||||
video_to_base64_string,
|
||||
)
|
||||
|
||||
@ -229,10 +231,29 @@ async def get_image_from_response(response: GeminiGenerateContentResponse, thoug
|
||||
return torch.cat(image_tensors, dim=0)
|
||||
|
||||
|
||||
async def get_video_from_response(
|
||||
response: GeminiGenerateContentResponse, cls: type[IO.ComfyNode] | None = None
|
||||
) -> InputImpl.VideoFromFile:
|
||||
parts = get_parts_by_type(response, "video/*")
|
||||
for part in parts:
|
||||
if part.inlineData and part.inlineData.data:
|
||||
return InputImpl.VideoFromFile(BytesIO(base64.b64decode(part.inlineData.data)))
|
||||
if part.fileData and part.fileData.fileUri:
|
||||
return await download_url_to_video_output(part.fileData.fileUri, cls=cls)
|
||||
model_message = get_text_from_response(response).strip()
|
||||
if model_message:
|
||||
raise ValueError(f"Gemini did not generate a video. Model response: {model_message}")
|
||||
raise ValueError(
|
||||
"Gemini did not generate a video. Try rephrasing your prompt, "
|
||||
"shortening the requested duration, or reducing the number of input images/videos."
|
||||
)
|
||||
|
||||
|
||||
def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | None:
|
||||
if not response.modelVersion:
|
||||
return None
|
||||
# Define prices (Cost per 1,000,000 tokens), see https://cloud.google.com/vertex-ai/generative-ai/pricing
|
||||
output_video_tokens_price = 0.0
|
||||
if response.modelVersion == "gemini-2.5-pro":
|
||||
input_tokens_price = 1.25
|
||||
output_text_tokens_price = 10.0
|
||||
@ -249,18 +270,27 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
|
||||
input_tokens_price = 2
|
||||
output_text_tokens_price = 12.0
|
||||
output_image_tokens_price = 0.0
|
||||
elif response.modelVersion == "gemini-3.1-flash-lite-preview":
|
||||
elif response.modelVersion in ("gemini-3.1-flash-lite-preview", "gemini-3.1-flash-lite"):
|
||||
input_tokens_price = 0.25
|
||||
output_text_tokens_price = 1.50
|
||||
output_image_tokens_price = 0.0
|
||||
elif response.modelVersion == "gemini-3-pro-image-preview":
|
||||
elif response.modelVersion in ("gemini-3-pro-image-preview", "gemini-3-pro-image"):
|
||||
input_tokens_price = 2
|
||||
output_text_tokens_price = 12.0
|
||||
output_image_tokens_price = 120.0
|
||||
elif response.modelVersion == "gemini-3.1-flash-image-preview":
|
||||
elif response.modelVersion in ("gemini-3.1-flash-image-preview", "gemini-3.1-flash-image"):
|
||||
input_tokens_price = 0.5
|
||||
output_text_tokens_price = 3.0
|
||||
output_image_tokens_price = 60.0
|
||||
elif response.modelVersion == "gemini-3.1-flash-lite-image":
|
||||
input_tokens_price = 0.25
|
||||
output_text_tokens_price = 1.50
|
||||
output_image_tokens_price = 30.0
|
||||
elif response.modelVersion == "gemini-omni-flash-preview":
|
||||
input_tokens_price = 2.145
|
||||
output_text_tokens_price = 12.87
|
||||
output_image_tokens_price = 0.0
|
||||
output_video_tokens_price = 25.025
|
||||
else:
|
||||
return None
|
||||
final_price = response.usageMetadata.promptTokenCount * input_tokens_price
|
||||
@ -268,6 +298,8 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
|
||||
for i in response.usageMetadata.candidatesTokensDetails:
|
||||
if i.modality == Modality.IMAGE:
|
||||
final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models
|
||||
elif i.modality == Modality.VIDEO:
|
||||
final_price += output_video_tokens_price * i.tokenCount # for Omni Flash
|
||||
else:
|
||||
final_price += output_text_tokens_price * i.tokenCount
|
||||
if response.usageMetadata.thoughtsTokenCount:
|
||||
@ -1302,7 +1334,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
def _nano_banana_2_v2_model_inputs():
|
||||
def _nano_banana_2_v2_model_inputs(resolutions: list[str]):
|
||||
return [
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
@ -1329,8 +1361,8 @@ def _nano_banana_2_v2_model_inputs():
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["1K", "2K", "4K"],
|
||||
tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
|
||||
options=resolutions,
|
||||
tooltip="Target output resolution.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"thinking_level",
|
||||
@ -1376,7 +1408,11 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"Nano Banana 2 (Gemini 3.1 Flash Image)",
|
||||
_nano_banana_2_v2_model_inputs(),
|
||||
_nano_banana_2_v2_model_inputs(resolutions=["1K", "2K", "4K"]),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Nano Banana 2 Lite",
|
||||
_nano_banana_2_v2_model_inputs(resolutions=["1K"]),
|
||||
),
|
||||
],
|
||||
),
|
||||
@ -1445,9 +1481,13 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$r := $lookup(widgets, "model.resolution");
|
||||
$prices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
|
||||
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
|
||||
$contains(widgets.model, "lite")
|
||||
? {"type":"usd","usd": 0.034, "format":{"suffix":"/Image","approximate":true}}
|
||||
: (
|
||||
$r := $lookup(widgets, "model.resolution");
|
||||
$prices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
|
||||
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
|
||||
)
|
||||
)
|
||||
""",
|
||||
),
|
||||
@ -1468,6 +1508,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
model_choice = model["model"]
|
||||
if model_choice == "Nano Banana 2 (Gemini 3.1 Flash Image)":
|
||||
model_id = "gemini-3.1-flash-image-preview"
|
||||
elif model_choice == "Nano Banana 2 Lite":
|
||||
model_id = "gemini-3.1-flash-lite-image"
|
||||
else:
|
||||
model_id = model_choice
|
||||
|
||||
@ -1517,6 +1559,149 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
OMNI_MAX_IMAGES = 14
|
||||
OMNI_MAX_VIDEOS = 3
|
||||
|
||||
OMNI_MODELS: dict[str, str] = {
|
||||
"Omni Flash": "gemini-omni-flash-preview",
|
||||
}
|
||||
|
||||
|
||||
def _omni_flash_inputs() -> list[Input]:
|
||||
"""Per-model inputs for the Omni video DynamicCombo (prompt + reference media + sampling)."""
|
||||
return [
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Describe the video to generate. Specify the length and aspect ratio directly in the "
|
||||
'prompt, e.g. "a 6-second clip in 16:9". Length may be 3-10 seconds; the aspect ratio must be '
|
||||
"16:9 (landscape) or 9:16 (portrait). The output is 720p, 24 FPS, with audio.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("image"),
|
||||
names=[f"image_{i}" for i in range(1, OMNI_MAX_IMAGES + 1)],
|
||||
min=0,
|
||||
),
|
||||
tooltip=f"Optional reference image(s) to guide or animate the video. Up to {OMNI_MAX_IMAGES} images.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"videos",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Video.Input("video"),
|
||||
names=[f"video_{i}" for i in range(1, OMNI_MAX_VIDEOS + 1)],
|
||||
min=0,
|
||||
),
|
||||
tooltip=f"Optional reference video(s) to guide or edit. Up to {OMNI_MAX_VIDEOS} videos, "
|
||||
f"each up to 10 seconds long.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"temperature",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=2.0,
|
||||
step=0.01,
|
||||
tooltip="Controls randomness. Lower is more focused/deterministic, higher is more varied.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"top_p",
|
||||
default=0.95,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip="Nucleus sampling: sample from the smallest token set whose cumulative probability reaches top_p.",
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class GeminiVideoOmni(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GeminiVideoOmni",
|
||||
display_name="Google Gemini Omni (Video)",
|
||||
category="partner/video/Gemini",
|
||||
essentials_category="Video Generation",
|
||||
description="Generate a video with audio from a text prompt using Google's Gemini Omni Flash model. "
|
||||
"Optionally provide reference images and/or videos to guide or edit the result. Describe the desired "
|
||||
"length (3-10s) and aspect ratio (16:9 or 9:16) directly in the prompt.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Omni Flash", _omni_flash_inputs()),
|
||||
],
|
||||
tooltip="The Gemini video model used to generate the video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=42,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
IO.String.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr='{"type":"usd","usd":0.146,"format":{"suffix":"/second","approximate":true}}'
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(cls, model: dict, seed: int) -> IO.NodeOutput:
|
||||
prompt = model.get("prompt") or ""
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
model_id = OMNI_MODELS[model["model"]]
|
||||
|
||||
images = [t for t in (model.get("images") or {}).values() if t is not None]
|
||||
videos = [v for v in (model.get("videos") or {}).values() if v is not None]
|
||||
if sum(get_number_of_images(t) for t in images) > OMNI_MAX_IMAGES:
|
||||
raise ValueError(f"The current maximum number of supported images is {OMNI_MAX_IMAGES}.")
|
||||
if len(videos) > OMNI_MAX_VIDEOS:
|
||||
raise ValueError(f"The current maximum number of supported videos is {OMNI_MAX_VIDEOS}.")
|
||||
for video in videos:
|
||||
validate_video_duration(video, max_duration=10)
|
||||
|
||||
parts: list[GeminiPart] = []
|
||||
if images or videos:
|
||||
parts.extend(await build_gemini_media_parts(cls, images, [], videos))
|
||||
parts.append(GeminiPart(text=prompt))
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model_id}", method="POST"),
|
||||
data=GeminiGenerateContentRequest(
|
||||
contents=[GeminiContent(role=GeminiRole.user, parts=parts)],
|
||||
generationConfig=GeminiGenerationConfig(
|
||||
responseModalities=["TEXT", "VIDEO"],
|
||||
temperature=model.get("temperature", 1.0),
|
||||
topP=model.get("top_p", 0.95),
|
||||
),
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
await get_video_from_response(response, cls=cls),
|
||||
get_text_from_response(response),
|
||||
)
|
||||
|
||||
|
||||
class GeminiExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -1527,6 +1712,7 @@ class GeminiExtension(ComfyExtension):
|
||||
GeminiImage2,
|
||||
GeminiNanoBanana2,
|
||||
GeminiNanoBanana2V2,
|
||||
GeminiVideoOmni,
|
||||
GeminiInputFiles,
|
||||
]
|
||||
|
||||
|
||||
@ -31,6 +31,12 @@ class JobStatus:
|
||||
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED]
|
||||
|
||||
|
||||
# Maximum number of (distinct) ids accepted by the `ids` filter on the jobs
|
||||
# listing. Caps request size; the bounded id-lookup in get_all_jobs then keeps
|
||||
# a batch-poll request at O(requested ids), not O(total history).
|
||||
MAX_JOB_IDS_FILTER = 100
|
||||
|
||||
|
||||
def validate_job_id(value) -> str:
|
||||
"""Validate a client-supplied job (prompt) id.
|
||||
|
||||
@ -50,6 +56,56 @@ def validate_job_id(value) -> str:
|
||||
return value
|
||||
|
||||
|
||||
class JobIdsFilterError(ValueError):
|
||||
"""Raised when the ``ids`` query-param value is malformed.
|
||||
|
||||
Carries an HTTP-ready ``payload`` dict so the caller can return it verbatim
|
||||
with a 400 without re-deriving the message.
|
||||
"""
|
||||
|
||||
def __init__(self, payload: dict):
|
||||
self.payload = payload
|
||||
super().__init__(payload.get("error", "invalid ids"))
|
||||
|
||||
|
||||
def parse_ids_filter(ids_param: Optional[str]) -> Optional[list[str]]:
|
||||
"""Parse the ``ids`` query-param value into a filter list.
|
||||
|
||||
Single source of truth for ``ids`` parsing/validation, shared by the HTTP
|
||||
handler and its tests so the two cannot drift.
|
||||
|
||||
Returns:
|
||||
- ``None`` when the param is absent (``ids_param is None``) -> no filter.
|
||||
- A de-duplicated list when present. An empty/blank value (``?ids=``,
|
||||
``?ids=,,``) yields ``[]``, which ``get_all_jobs`` treats as a
|
||||
zero-match filter -- NOT "return everything".
|
||||
|
||||
Raises:
|
||||
JobIdsFilterError: more than ``MAX_JOB_IDS_FILTER`` distinct ids, or any
|
||||
id not in canonical UUID form. ``.payload`` is a 400-ready dict.
|
||||
"""
|
||||
if ids_param is None:
|
||||
return None
|
||||
# De-dupe up front: a repeated id must not count toward the cap or be
|
||||
# looked up twice. dict.fromkeys keeps first-seen order.
|
||||
ids_filter = list(dict.fromkeys(i.strip() for i in ids_param.split(',') if i.strip()))
|
||||
if len(ids_filter) > MAX_JOB_IDS_FILTER:
|
||||
raise JobIdsFilterError(
|
||||
{"error": f"ids must contain at most {MAX_JOB_IDS_FILTER} values"}
|
||||
)
|
||||
invalid_ids = []
|
||||
for jid in ids_filter:
|
||||
try:
|
||||
validate_job_id(jid)
|
||||
except (ValueError, AttributeError):
|
||||
invalid_ids.append(jid)
|
||||
if invalid_ids:
|
||||
raise JobIdsFilterError(
|
||||
{"error": "ids contains invalid id(s)", "invalid_ids": invalid_ids}
|
||||
)
|
||||
return ids_filter
|
||||
|
||||
|
||||
# Media types that can be previewed in the frontend
|
||||
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'})
|
||||
|
||||
@ -362,6 +418,7 @@ def get_all_jobs(
|
||||
history: dict,
|
||||
status_filter: Optional[list[str]] = None,
|
||||
workflow_id: Optional[str] = None,
|
||||
ids: Optional[list[str]] = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
limit: Optional[int] = None,
|
||||
@ -376,6 +433,8 @@ def get_all_jobs(
|
||||
history: Dict of history items keyed by prompt_id
|
||||
status_filter: List of statuses to include (from JobStatus.ALL)
|
||||
workflow_id: Filter by workflow ID
|
||||
ids: Restrict the result to these job ids. None = no filter; a present
|
||||
list (including empty) restricts to that set, so [] = zero matches
|
||||
sort_by: Field to sort by ('created_at', 'execution_duration')
|
||||
sort_order: 'asc' or 'desc'
|
||||
limit: Maximum number of items to return
|
||||
@ -389,6 +448,10 @@ def get_all_jobs(
|
||||
if status_filter is None:
|
||||
status_filter = JobStatus.ALL
|
||||
|
||||
# None => no id filter; a present list (including empty) restricts to that
|
||||
# set (empty => zero matches).
|
||||
id_set = set(ids) if ids is not None else None
|
||||
|
||||
if JobStatus.IN_PROGRESS in status_filter:
|
||||
for item in running:
|
||||
jobs.append(normalize_queue_item(item, JobStatus.IN_PROGRESS))
|
||||
@ -400,14 +463,30 @@ def get_all_jobs(
|
||||
history_statuses = {JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED}
|
||||
requested_history_statuses = history_statuses & set(status_filter)
|
||||
if requested_history_statuses:
|
||||
for prompt_id, history_item in history.items():
|
||||
job = normalize_history_item(prompt_id, history_item)
|
||||
if job.get('status') in requested_history_statuses:
|
||||
jobs.append(job)
|
||||
if id_set is not None:
|
||||
# Batch-poll fast path: history is keyed by id, so look up only the
|
||||
# requested ids instead of normalizing the whole (unbounded) history.
|
||||
for prompt_id in id_set:
|
||||
history_item = history.get(prompt_id)
|
||||
if history_item is None:
|
||||
continue
|
||||
job = normalize_history_item(prompt_id, history_item)
|
||||
if job.get('status') in requested_history_statuses:
|
||||
jobs.append(job)
|
||||
else:
|
||||
for prompt_id, history_item in history.items():
|
||||
job = normalize_history_item(prompt_id, history_item)
|
||||
if job.get('status') in requested_history_statuses:
|
||||
jobs.append(job)
|
||||
|
||||
if workflow_id:
|
||||
jobs = [j for j in jobs if j.get('workflow_id') == workflow_id]
|
||||
|
||||
if id_set is not None:
|
||||
# `.get('id')` (not `j['id']`): prune_dict can drop a None id, and a
|
||||
# job missing its id should degrade to "no match", not raise KeyError.
|
||||
jobs = [j for j in jobs if j.get('id') in id_set]
|
||||
|
||||
jobs = apply_sorting(jobs, sort_by, sort_order)
|
||||
|
||||
total_count = len(jobs)
|
||||
|
||||
@ -8,7 +8,8 @@ class CLIPTextEncodeControlnet(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodeControlnet",
|
||||
category="experimental/conditioning",
|
||||
display_name="CLIP Text Encode (Controlnet)",
|
||||
category="model/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.Conditioning.Input("conditioning"),
|
||||
@ -35,11 +36,12 @@ class T5TokenizerOptions(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="T5TokenizerOptions",
|
||||
category="experimental/conditioning",
|
||||
display_name="T5 Tokenizer Options",
|
||||
category="model/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True),
|
||||
io.Int.Input("min_length", default=0, min=0, max=10000, step=1, advanced=True),
|
||||
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1),
|
||||
io.Int.Input("min_length", default=0, min=0, max=10000, step=1),
|
||||
],
|
||||
outputs=[io.Clip.Output()],
|
||||
is_experimental=True,
|
||||
|
||||
@ -1070,7 +1070,7 @@ class AddNoise(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="AddNoise",
|
||||
category="experimental/custom_sampling/noise",
|
||||
category="model/sampling/noise",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
@ -1120,7 +1120,7 @@ class ManualSigmas(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ManualSigmas",
|
||||
search_aliases=["custom noise schedule", "define sigmas"],
|
||||
category="experimental/custom_sampling",
|
||||
category="model/sampling/sigmas",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.String.Input("sigmas", default="1, 0.5", multiline=False)
|
||||
|
||||
@ -1,85 +1,68 @@
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import ctypes
|
||||
import logging
|
||||
import ctypes.util
|
||||
import importlib.util
|
||||
from typing import TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import nodes
|
||||
import comfy_angle
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
from typing_extensions import override
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _check_opengl_availability():
|
||||
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
|
||||
logger.debug("_check_opengl_availability: starting")
|
||||
missing = []
|
||||
def _preload_angle():
|
||||
egl_path = comfy_angle.get_egl_path()
|
||||
gles_path = comfy_angle.get_glesv2_path()
|
||||
|
||||
# Check Python packages (using find_spec to avoid importing)
|
||||
logger.debug("_check_opengl_availability: checking for glfw package")
|
||||
if importlib.util.find_spec("glfw") is None:
|
||||
missing.append("glfw")
|
||||
if sys.platform == "win32":
|
||||
angle_dir = comfy_angle.get_lib_dir()
|
||||
os.add_dll_directory(angle_dir)
|
||||
os.environ["PATH"] = angle_dir + os.pathsep + os.environ.get("PATH", "")
|
||||
|
||||
logger.debug("_check_opengl_availability: checking for OpenGL package")
|
||||
if importlib.util.find_spec("OpenGL") is None:
|
||||
missing.append("PyOpenGL")
|
||||
|
||||
if missing:
|
||||
raise RuntimeError(
|
||||
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
|
||||
)
|
||||
|
||||
# On Linux without display, check if headless backends are available
|
||||
logger.debug(f"_check_opengl_availability: platform={sys.platform}")
|
||||
if sys.platform.startswith("linux"):
|
||||
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
|
||||
logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
|
||||
if not has_display:
|
||||
# Check for EGL or OSMesa libraries
|
||||
logger.debug("_check_opengl_availability: checking for EGL library")
|
||||
has_egl = ctypes.util.find_library("EGL")
|
||||
logger.debug("_check_opengl_availability: checking for OSMesa library")
|
||||
has_osmesa = ctypes.util.find_library("OSMesa")
|
||||
|
||||
# Error disabled for CI as it fails this check
|
||||
# if not has_egl and not has_osmesa:
|
||||
# raise RuntimeError(
|
||||
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
|
||||
# "See error below for installation instructions."
|
||||
# )
|
||||
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
|
||||
|
||||
logger.debug("_check_opengl_availability: completed")
|
||||
mode = 0 if sys.platform == "win32" else ctypes.RTLD_GLOBAL
|
||||
ctypes.CDLL(str(egl_path), mode=mode)
|
||||
ctypes.CDLL(str(gles_path), mode=mode)
|
||||
|
||||
|
||||
# Run early check at import time
|
||||
logger.debug("nodes_glsl: running _check_opengl_availability at import time")
|
||||
_check_opengl_availability()
|
||||
|
||||
# OpenGL modules - initialized lazily when context is created
|
||||
gl = None
|
||||
glfw = None
|
||||
EGL = None
|
||||
# Pre-load ANGLE *before* any PyOpenGL import so that the EGL platform
|
||||
# plugin picks up ANGLE's libEGL / libGLESv2 instead of system libs.
|
||||
_preload_angle()
|
||||
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
|
||||
|
||||
|
||||
def _import_opengl():
|
||||
"""Import OpenGL module. Called after context is created."""
|
||||
global gl
|
||||
if gl is None:
|
||||
logger.debug("_import_opengl: importing OpenGL.GL")
|
||||
import OpenGL.GL as _gl
|
||||
gl = _gl
|
||||
logger.debug("_import_opengl: import completed")
|
||||
return gl
|
||||
import OpenGL
|
||||
OpenGL.USE_ACCELERATE = False
|
||||
|
||||
|
||||
def _patch_find_library():
|
||||
"""PyOpenGL's EGL platform looks for 'EGL' and 'GLESv2' by short name
|
||||
via ctypes.util.find_library, but ANGLE ships as 'libEGL' and
|
||||
'libGLESv2'. Patch find_library to return the full ANGLE paths so
|
||||
PyOpenGL loads the same libraries we pre-loaded."""
|
||||
if sys.platform == "linux":
|
||||
return
|
||||
import ctypes.util
|
||||
_orig = ctypes.util.find_library
|
||||
def _patched(name):
|
||||
if name == 'EGL':
|
||||
return comfy_angle.get_egl_path()
|
||||
if name == 'GLESv2':
|
||||
return comfy_angle.get_glesv2_path()
|
||||
return _orig(name)
|
||||
ctypes.util.find_library = _patched
|
||||
|
||||
|
||||
_patch_find_library()
|
||||
|
||||
from OpenGL import EGL
|
||||
from OpenGL import GLES3 as gl
|
||||
|
||||
class SizeModeInput(TypedDict):
|
||||
size_mode: str
|
||||
width: int
|
||||
@ -102,7 +85,7 @@ MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
||||
# (-1,-1)---(3,-1)
|
||||
#
|
||||
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
|
||||
VERTEX_SHADER = """#version 330 core
|
||||
VERTEX_SHADER = """#version 300 es
|
||||
out vec2 v_texCoord;
|
||||
void main() {
|
||||
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
|
||||
@ -126,14 +109,99 @@ void main() {
|
||||
"""
|
||||
|
||||
|
||||
def _convert_es_to_desktop(source: str) -> str:
|
||||
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
|
||||
# Remove any existing #version directive
|
||||
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
|
||||
# Remove precision qualifiers (not needed in desktop GLSL)
|
||||
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
|
||||
# Prepend desktop GLSL version
|
||||
return "#version 330 core\n" + source
|
||||
|
||||
def _egl_attribs(*values):
|
||||
"""Build an EGL_NONE-terminated EGLint attribute array."""
|
||||
vals = list(values) + [EGL.EGL_NONE]
|
||||
return (ctypes.c_int32 * len(vals))(*vals)
|
||||
|
||||
|
||||
# EGL platform extension constants
|
||||
EGL_PLATFORM_ANGLE_ANGLE = 0x3202
|
||||
EGL_PLATFORM_ANGLE_TYPE_ANGLE = 0x3203
|
||||
EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE = 0x3450
|
||||
EGL_MESA_PLATFORM_SURFACELESS = 0x31DD
|
||||
|
||||
|
||||
_eglGetPlatformDisplayEXT = None
|
||||
|
||||
def _get_egl_platform_display_ext(platform, native_display, attribs):
|
||||
"""Call eglGetPlatformDisplayEXT via ctypes (extension, not in PyOpenGL)."""
|
||||
global _eglGetPlatformDisplayEXT
|
||||
if _eglGetPlatformDisplayEXT is None:
|
||||
from OpenGL import platform as _plat
|
||||
egl_lib = _plat.PLATFORM.EGL
|
||||
_get_proc = egl_lib.eglGetProcAddress
|
||||
_get_proc.restype = ctypes.c_void_p
|
||||
_get_proc.argtypes = [ctypes.c_char_p]
|
||||
ptr = _get_proc(b"eglGetPlatformDisplayEXT")
|
||||
if not ptr:
|
||||
return None
|
||||
func_type = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_uint32, ctypes.c_void_p, ctypes.c_void_p)
|
||||
_eglGetPlatformDisplayEXT = func_type(ptr)
|
||||
|
||||
raw = _eglGetPlatformDisplayEXT(platform, native_display, attribs)
|
||||
if not raw:
|
||||
return None
|
||||
return ctypes.cast(raw, EGL.EGLDisplay)
|
||||
|
||||
|
||||
def _get_egl_display():
|
||||
"""Get an EGL display, trying the default first then ANGLE's Vulkan
|
||||
platform for headless environments without a display server."""
|
||||
failures = []
|
||||
|
||||
# Try the default display first (works when X11/Wayland is available)
|
||||
display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY)
|
||||
if display:
|
||||
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
|
||||
try:
|
||||
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
|
||||
return display, major.value, minor.value
|
||||
except Exception as e:
|
||||
failures.append(f"default: {e}")
|
||||
|
||||
logger.info("Default EGL display unavailable, trying headless fallbacks")
|
||||
|
||||
# Headless fallback strategies, tried in order:
|
||||
headless_strategies = [
|
||||
("surfaceless", EGL_MESA_PLATFORM_SURFACELESS, None, None),
|
||||
("ANGLE Vulkan", EGL_PLATFORM_ANGLE_ANGLE, None,
|
||||
_egl_attribs(EGL_PLATFORM_ANGLE_TYPE_ANGLE, EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE)),
|
||||
]
|
||||
|
||||
for name, platform, native_display, attribs in headless_strategies:
|
||||
display = _get_egl_platform_display_ext(platform, native_display, attribs)
|
||||
if not display:
|
||||
failures.append(f"{name}: eglGetPlatformDisplayEXT returned no display")
|
||||
continue
|
||||
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
|
||||
try:
|
||||
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
|
||||
logger.info(f"Using EGL {name} platform (headless)")
|
||||
return display, major.value, minor.value
|
||||
failures.append(f"{name}: eglInitialize returned false")
|
||||
except Exception as e:
|
||||
failures.append(f"{name}: {e}")
|
||||
continue
|
||||
|
||||
details = "\n".join(f" - {f}" for f in failures)
|
||||
raise RuntimeError(
|
||||
"Failed to initialize EGL display.\n"
|
||||
"No display server and no headless EGL platform available.\n"
|
||||
f"Tried:\n{details}\n"
|
||||
"Ensure GPU drivers are installed or set DISPLAY for a virtual framebuffer."
|
||||
)
|
||||
|
||||
|
||||
def _gl_str(name):
|
||||
"""Get an OpenGL string parameter."""
|
||||
v = gl.glGetString(name)
|
||||
if not v:
|
||||
return "Unknown"
|
||||
if isinstance(v, bytes):
|
||||
return v.decode(errors="replace")
|
||||
return ctypes.string_at(v).decode(errors="replace")
|
||||
|
||||
|
||||
def _detect_output_count(source: str) -> int:
|
||||
@ -159,163 +227,8 @@ def _detect_pass_count(source: str) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
def _init_glfw():
|
||||
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
|
||||
logger.debug("_init_glfw: starting")
|
||||
# On macOS, glfw.init() must be called from main thread or it hangs forever
|
||||
if sys.platform == "darwin":
|
||||
logger.debug("_init_glfw: skipping on macOS")
|
||||
raise RuntimeError("GLFW backend not supported on macOS")
|
||||
|
||||
logger.debug("_init_glfw: importing glfw module")
|
||||
import glfw as _glfw
|
||||
|
||||
logger.debug("_init_glfw: calling glfw.init()")
|
||||
if not _glfw.init():
|
||||
raise RuntimeError("glfw.init() failed")
|
||||
|
||||
try:
|
||||
logger.debug("_init_glfw: setting window hints")
|
||||
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
|
||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
|
||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
|
||||
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
|
||||
|
||||
logger.debug("_init_glfw: calling create_window()")
|
||||
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
|
||||
if not window:
|
||||
raise RuntimeError("glfw.create_window() failed")
|
||||
|
||||
logger.debug("_init_glfw: calling make_context_current()")
|
||||
_glfw.make_context_current(window)
|
||||
logger.debug("_init_glfw: completed successfully")
|
||||
return window, _glfw
|
||||
except Exception:
|
||||
logger.debug("_init_glfw: failed, terminating glfw")
|
||||
_glfw.terminate()
|
||||
raise
|
||||
|
||||
|
||||
def _init_egl():
|
||||
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
|
||||
logger.debug("_init_egl: starting")
|
||||
from OpenGL import EGL as _EGL
|
||||
from OpenGL.EGL import (
|
||||
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
|
||||
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
|
||||
eglTerminate, eglDestroyContext, eglDestroySurface,
|
||||
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
|
||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
||||
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
|
||||
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
|
||||
)
|
||||
logger.debug("_init_egl: imports completed")
|
||||
|
||||
display = None
|
||||
context = None
|
||||
surface = None
|
||||
|
||||
try:
|
||||
logger.debug("_init_egl: calling eglGetDisplay()")
|
||||
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
|
||||
if display == _EGL.EGL_NO_DISPLAY:
|
||||
raise RuntimeError("eglGetDisplay() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglInitialize()")
|
||||
major, minor = _EGL.EGLint(), _EGL.EGLint()
|
||||
if not eglInitialize(display, major, minor):
|
||||
display = None # Not initialized, don't terminate
|
||||
raise RuntimeError("eglInitialize() failed")
|
||||
logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
|
||||
|
||||
config_attribs = [
|
||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
|
||||
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
||||
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
|
||||
EGL_DEPTH_SIZE, 0, EGL_NONE
|
||||
]
|
||||
configs = (_EGL.EGLConfig * 1)()
|
||||
num_configs = _EGL.EGLint()
|
||||
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
|
||||
raise RuntimeError("eglChooseConfig() failed")
|
||||
config = configs[0]
|
||||
logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
|
||||
|
||||
if not eglBindAPI(EGL_OPENGL_API):
|
||||
raise RuntimeError("eglBindAPI() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglCreateContext()")
|
||||
context_attribs = [
|
||||
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
|
||||
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
|
||||
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
|
||||
EGL_NONE
|
||||
]
|
||||
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
|
||||
if context == EGL_NO_CONTEXT:
|
||||
raise RuntimeError("eglCreateContext() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglCreatePbufferSurface()")
|
||||
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
|
||||
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
|
||||
if surface == _EGL.EGL_NO_SURFACE:
|
||||
raise RuntimeError("eglCreatePbufferSurface() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglMakeCurrent()")
|
||||
if not eglMakeCurrent(display, surface, surface, context):
|
||||
raise RuntimeError("eglMakeCurrent() failed")
|
||||
|
||||
logger.debug("_init_egl: completed successfully")
|
||||
return display, context, surface, _EGL
|
||||
|
||||
except Exception:
|
||||
logger.debug("_init_egl: failed, cleaning up")
|
||||
# Clean up any resources on failure
|
||||
if surface is not None:
|
||||
eglDestroySurface(display, surface)
|
||||
if context is not None:
|
||||
eglDestroyContext(display, context)
|
||||
if display is not None:
|
||||
eglTerminate(display)
|
||||
raise
|
||||
|
||||
|
||||
def _init_osmesa():
|
||||
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
|
||||
import ctypes
|
||||
|
||||
logger.debug("_init_osmesa: starting")
|
||||
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
|
||||
|
||||
logger.debug("_init_osmesa: importing OpenGL.osmesa")
|
||||
from OpenGL import GL as _gl
|
||||
from OpenGL.osmesa import (
|
||||
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
|
||||
OSMESA_RGBA,
|
||||
)
|
||||
logger.debug("_init_osmesa: imports completed")
|
||||
|
||||
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
|
||||
if not ctx:
|
||||
raise RuntimeError("OSMesaCreateContextExt() failed")
|
||||
|
||||
width, height = 64, 64
|
||||
buffer = (ctypes.c_ubyte * (width * height * 4))()
|
||||
|
||||
logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
|
||||
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
|
||||
OSMesaDestroyContext(ctx)
|
||||
raise RuntimeError("OSMesaMakeCurrent() failed")
|
||||
|
||||
logger.debug("_init_osmesa: completed successfully")
|
||||
return ctx, buffer
|
||||
|
||||
|
||||
class GLContext:
|
||||
"""Manages OpenGL context and resources for shader execution.
|
||||
|
||||
Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software).
|
||||
"""
|
||||
"""Manages an OpenGL ES 3.0 context via EGL/ANGLE (singleton)."""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
@ -327,131 +240,105 @@ class GLContext:
|
||||
|
||||
def __init__(self):
|
||||
if GLContext._initialized:
|
||||
logger.debug("GLContext.__init__: already initialized, skipping")
|
||||
return
|
||||
|
||||
logger.debug("GLContext.__init__: starting initialization")
|
||||
|
||||
global glfw, EGL
|
||||
|
||||
import time
|
||||
start = time.perf_counter()
|
||||
|
||||
self._backend = None
|
||||
self._window = None
|
||||
self._egl_display = None
|
||||
self._egl_context = None
|
||||
self._egl_surface = None
|
||||
self._osmesa_ctx = None
|
||||
self._osmesa_buffer = None
|
||||
self._display = None
|
||||
self._surface = None
|
||||
self._context = None
|
||||
self._vao = None
|
||||
|
||||
# Try backends in order: GLFW → EGL → OSMesa
|
||||
errors = []
|
||||
|
||||
logger.debug("GLContext.__init__: trying GLFW backend")
|
||||
try:
|
||||
self._window, glfw = _init_glfw()
|
||||
self._backend = "glfw"
|
||||
logger.debug("GLContext.__init__: GLFW backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
|
||||
errors.append(("GLFW", e))
|
||||
self._display, self._egl_major, self._egl_minor = _get_egl_display()
|
||||
|
||||
if self._backend is None:
|
||||
logger.debug("GLContext.__init__: trying EGL backend")
|
||||
try:
|
||||
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
|
||||
self._backend = "egl"
|
||||
logger.debug("GLContext.__init__: EGL backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
|
||||
errors.append(("EGL", e))
|
||||
if not EGL.eglBindAPI(EGL.EGL_OPENGL_ES_API):
|
||||
raise RuntimeError("eglBindAPI(EGL_OPENGL_ES_API) failed")
|
||||
|
||||
if self._backend is None:
|
||||
logger.debug("GLContext.__init__: trying OSMesa backend")
|
||||
try:
|
||||
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
|
||||
self._backend = "osmesa"
|
||||
logger.debug("GLContext.__init__: OSMesa backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
|
||||
errors.append(("OSMesa", e))
|
||||
config = EGL.EGLConfig()
|
||||
n_configs = ctypes.c_int32(0)
|
||||
if not EGL.eglChooseConfig(
|
||||
self._display,
|
||||
_egl_attribs(
|
||||
EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_ES3_BIT,
|
||||
EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT,
|
||||
EGL.EGL_RED_SIZE, 8, EGL.EGL_GREEN_SIZE, 8,
|
||||
EGL.EGL_BLUE_SIZE, 8, EGL.EGL_ALPHA_SIZE, 8,
|
||||
),
|
||||
ctypes.byref(config), 1, ctypes.byref(n_configs),
|
||||
) or n_configs.value == 0:
|
||||
raise RuntimeError("eglChooseConfig() failed")
|
||||
|
||||
if self._backend is None:
|
||||
if sys.platform == "win32":
|
||||
platform_help = (
|
||||
"Windows: Ensure GPU drivers are installed and display is available.\n"
|
||||
" CPU-only/headless mode is not supported on Windows."
|
||||
)
|
||||
elif sys.platform == "darwin":
|
||||
platform_help = (
|
||||
"macOS: GLFW is not supported.\n"
|
||||
" Install OSMesa via Homebrew: brew install mesa\n"
|
||||
" Then: pip install PyOpenGL PyOpenGL-accelerate"
|
||||
)
|
||||
else:
|
||||
platform_help = (
|
||||
"Linux: Install one of these backends:\n"
|
||||
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
|
||||
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
|
||||
" Headless (CPU): sudo apt install libosmesa6"
|
||||
)
|
||||
|
||||
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
|
||||
raise RuntimeError(
|
||||
f"Failed to create OpenGL context.\n\n"
|
||||
f"Backend errors:\n{error_details}\n\n"
|
||||
f"{platform_help}"
|
||||
self._surface = EGL.eglCreatePbufferSurface(
|
||||
self._display, config,
|
||||
_egl_attribs(EGL.EGL_WIDTH, 64, EGL.EGL_HEIGHT, 64),
|
||||
)
|
||||
if not self._surface:
|
||||
raise RuntimeError("eglCreatePbufferSurface() failed")
|
||||
|
||||
# Now import OpenGL.GL (after context is current)
|
||||
logger.debug("GLContext.__init__: importing OpenGL.GL")
|
||||
_import_opengl()
|
||||
self._context = EGL.eglCreateContext(
|
||||
self._display, config, EGL.EGL_NO_CONTEXT,
|
||||
_egl_attribs(EGL.EGL_CONTEXT_CLIENT_VERSION, 3),
|
||||
)
|
||||
if not self._context:
|
||||
raise RuntimeError("eglCreateContext() failed")
|
||||
|
||||
# Create VAO (required for core profile, but OSMesa may use compat profile)
|
||||
logger.debug("GLContext.__init__: creating VAO")
|
||||
try:
|
||||
vao = gl.glGenVertexArrays(1)
|
||||
gl.glBindVertexArray(vao)
|
||||
self._vao = vao # Only store after successful bind
|
||||
logger.debug("GLContext.__init__: VAO created successfully")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
|
||||
# OSMesa with older Mesa may not support VAOs
|
||||
# Clean up if we created but couldn't bind
|
||||
if vao:
|
||||
try:
|
||||
gl.glDeleteVertexArrays(1, [vao])
|
||||
except Exception:
|
||||
pass
|
||||
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
|
||||
raise RuntimeError("eglMakeCurrent() failed")
|
||||
|
||||
self._vao = gl.glGenVertexArrays(1)
|
||||
gl.glBindVertexArray(self._vao)
|
||||
|
||||
except Exception:
|
||||
self._cleanup()
|
||||
raise
|
||||
|
||||
elapsed = (time.perf_counter() - start) * 1000
|
||||
|
||||
# Log device info
|
||||
renderer = gl.glGetString(gl.GL_RENDERER)
|
||||
vendor = gl.glGetString(gl.GL_VENDOR)
|
||||
version = gl.glGetString(gl.GL_VERSION)
|
||||
renderer = renderer.decode() if renderer else "Unknown"
|
||||
vendor = vendor.decode() if vendor else "Unknown"
|
||||
version = version.decode() if version else "Unknown"
|
||||
renderer = _gl_str(gl.GL_RENDERER)
|
||||
vendor = _gl_str(gl.GL_VENDOR)
|
||||
version = _gl_str(gl.GL_VERSION)
|
||||
|
||||
GLContext._initialized = True
|
||||
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
|
||||
logger.info(f"GLSL context initialized in {elapsed:.1f}ms - EGL {self._egl_major}.{self._egl_minor}, {renderer} ({vendor}), GL {version}")
|
||||
|
||||
def make_current(self):
|
||||
if self._backend == "glfw":
|
||||
glfw.make_context_current(self._window)
|
||||
elif self._backend == "egl":
|
||||
from OpenGL.EGL import eglMakeCurrent
|
||||
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
|
||||
elif self._backend == "osmesa":
|
||||
from OpenGL.osmesa import OSMesaMakeCurrent
|
||||
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
|
||||
|
||||
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
|
||||
err = EGL.eglGetError()
|
||||
raise RuntimeError(f"eglMakeCurrent() failed (EGL error: 0x{err:04X})")
|
||||
if self._vao is not None:
|
||||
gl.glBindVertexArray(self._vao)
|
||||
|
||||
def _cleanup(self):
|
||||
if not self._display:
|
||||
return
|
||||
try:
|
||||
if self._vao is not None:
|
||||
gl.glDeleteVertexArrays(1, [self._vao])
|
||||
self._vao = None
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
EGL.eglMakeCurrent(self._display, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self._context:
|
||||
EGL.eglDestroyContext(self._display, self._context)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self._surface:
|
||||
EGL.eglDestroySurface(self._display, self._surface)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
EGL.eglTerminate(self._display)
|
||||
except Exception:
|
||||
pass
|
||||
self._display = None
|
||||
|
||||
|
||||
def _compile_shader(source: str, shader_type: int) -> int:
|
||||
"""Compile a shader and return its ID."""
|
||||
@ -459,8 +346,10 @@ def _compile_shader(source: str, shader_type: int) -> int:
|
||||
gl.glShaderSource(shader, source)
|
||||
gl.glCompileShader(shader)
|
||||
|
||||
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
|
||||
error = gl.glGetShaderInfoLog(shader).decode()
|
||||
if not gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS):
|
||||
error = gl.glGetShaderInfoLog(shader)
|
||||
if isinstance(error, bytes):
|
||||
error = error.decode(errors="replace")
|
||||
gl.glDeleteShader(shader)
|
||||
raise RuntimeError(f"Shader compilation failed:\n{error}")
|
||||
|
||||
@ -484,8 +373,10 @@ def _create_program(vertex_source: str, fragment_source: str) -> int:
|
||||
gl.glDeleteShader(vertex_shader)
|
||||
gl.glDeleteShader(fragment_shader)
|
||||
|
||||
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
|
||||
error = gl.glGetProgramInfoLog(program).decode()
|
||||
if not gl.glGetProgramiv(program, gl.GL_LINK_STATUS):
|
||||
error = gl.glGetProgramInfoLog(program)
|
||||
if isinstance(error, bytes):
|
||||
error = error.decode(errors="replace")
|
||||
gl.glDeleteProgram(program)
|
||||
raise RuntimeError(f"Program linking failed:\n{error}")
|
||||
|
||||
@ -530,9 +421,6 @@ def _render_shader_batch(
|
||||
ctx = GLContext()
|
||||
ctx.make_current()
|
||||
|
||||
# Convert from GLSL ES to desktop GLSL 330
|
||||
fragment_source = _convert_es_to_desktop(fragment_code)
|
||||
|
||||
# Detect how many outputs the shader actually uses
|
||||
num_outputs = _detect_output_count(fragment_code)
|
||||
|
||||
@ -558,9 +446,9 @@ def _render_shader_batch(
|
||||
try:
|
||||
# Compile shaders (once for all batches)
|
||||
try:
|
||||
program = _create_program(VERTEX_SHADER, fragment_source)
|
||||
program = _create_program(VERTEX_SHADER, fragment_code)
|
||||
except RuntimeError:
|
||||
logger.error(f"Fragment shader:\n{fragment_source}")
|
||||
logger.error(f"Fragment shader:\n{fragment_code}")
|
||||
raise
|
||||
|
||||
gl.glUseProgram(program)
|
||||
@ -723,13 +611,13 @@ def _render_shader_batch(
|
||||
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
|
||||
|
||||
# Read back outputs for this batch
|
||||
# (glGetTexImage is synchronous, implicitly waits for rendering)
|
||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
|
||||
batch_outputs = []
|
||||
for tex in output_textures:
|
||||
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
||||
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
|
||||
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
|
||||
batch_outputs.append(img[::-1, :, :].copy())
|
||||
for i in range(num_outputs):
|
||||
gl.glReadBuffer(gl.GL_COLOR_ATTACHMENT0 + i)
|
||||
buf = np.empty((height, width, 4), dtype=np.float32)
|
||||
gl.glReadPixels(0, 0, width, height, gl.GL_RGBA, gl.GL_FLOAT, buf)
|
||||
batch_outputs.append(buf[::-1, :, :].copy())
|
||||
|
||||
# Pad with black images for unused outputs
|
||||
black_img = np.zeros((height, width, 4), dtype=np.float32)
|
||||
@ -750,18 +638,18 @@ def _render_shader_batch(
|
||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
|
||||
gl.glUseProgram(0)
|
||||
|
||||
for tex in input_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in curve_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in output_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in ping_pong_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
if input_textures:
|
||||
gl.glDeleteTextures(len(input_textures), input_textures)
|
||||
if curve_textures:
|
||||
gl.glDeleteTextures(len(curve_textures), curve_textures)
|
||||
if output_textures:
|
||||
gl.glDeleteTextures(len(output_textures), output_textures)
|
||||
if ping_pong_textures:
|
||||
gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures)
|
||||
if fbo is not None:
|
||||
gl.glDeleteFramebuffers(1, [fbo])
|
||||
for pp_fbo in ping_pong_fbos:
|
||||
gl.glDeleteFramebuffers(1, [pp_fbo])
|
||||
if ping_pong_fbos:
|
||||
gl.glDeleteFramebuffers(len(ping_pong_fbos), ping_pong_fbos)
|
||||
if program is not None:
|
||||
gl.glDeleteProgram(program)
|
||||
|
||||
|
||||
@ -123,7 +123,8 @@ class PhotoMakerLoader(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PhotoMakerLoader",
|
||||
category="experimental/photomaker",
|
||||
display_name="Load PhotoMaker Model",
|
||||
category="model/loaders",
|
||||
inputs=[
|
||||
io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")),
|
||||
],
|
||||
@ -149,7 +150,8 @@ class PhotoMakerEncode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PhotoMakerEncode",
|
||||
category="experimental/photomaker",
|
||||
display_name="PhotoMaker Encode",
|
||||
category="model/conditioning/photomaker",
|
||||
inputs=[
|
||||
io.Photomaker.Input("photomaker"),
|
||||
io.Image.Input("image"),
|
||||
|
||||
@ -119,7 +119,7 @@ class StableCascade_SuperResolutionControlnet(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StableCascade_SuperResolutionControlnet",
|
||||
category="experimental/stable_cascade",
|
||||
category="experimental/stable cascade",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
|
||||
@ -143,7 +143,7 @@ class VAEDecodeTripoSplat(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="VAEDecodeTripoSplat",
|
||||
display_name="TripoSplat Decode",
|
||||
category="3d/latent",
|
||||
category="model/latent/triposplat",
|
||||
description="Decode the sampled TripoSplat latent into a 3D gaussian splat. "
|
||||
"Modify the number of gaussians to vary the density.",
|
||||
inputs=[
|
||||
@ -188,7 +188,7 @@ class TripoSplatSamplingPreview(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoSplatSamplingPreview",
|
||||
display_name="TripoSplat Sampling Preview",
|
||||
category="3d/latent",
|
||||
category="model/latent/triposplat",
|
||||
description="Patch the TripoSplat model for the standard Ksampler node to show a live decoded "
|
||||
"gaussian splat preview at each step.",
|
||||
inputs=[
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.26.0"
|
||||
__version__ = "0.27.0"
|
||||
|
||||
55
execution.py
55
execution.py
@ -1113,32 +1113,6 @@ def full_type_name(klass):
|
||||
return klass.__qualname__
|
||||
return module + '.' + klass.__qualname__
|
||||
|
||||
def node_not_executable_reason(class_def, class_type):
|
||||
"""Return a human-readable reason the node cannot be executed, or None if it's fine.
|
||||
|
||||
Catches a node whose declared entry point doesn't resolve to a real method
|
||||
(e.g. a V1 ``FUNCTION = "invert"`` where the method is misspelled, or a V3 node
|
||||
missing its ``execute`` override). Running this during validation surfaces the
|
||||
problem before execution starts, instead of after upstream nodes have run.
|
||||
|
||||
Only the class is inspected; the node is never instantiated here, so a node's
|
||||
``__init__`` side effects cannot run (or fail) during validation.
|
||||
"""
|
||||
try:
|
||||
if issubclass(class_def, _ComfyNodeInternal):
|
||||
# V3: validates that execute()/define_schema() overrides exist.
|
||||
class_def.VALIDATE_CLASS()
|
||||
return None
|
||||
# V1: FUNCTION names the method to call; it must exist on the class.
|
||||
function_name = getattr(class_def, "FUNCTION", None)
|
||||
if function_name is None:
|
||||
return f"'{class_type}' does not define FUNCTION"
|
||||
if not callable(getattr(class_def, function_name, None)):
|
||||
return f"'{class_type}' has no method '{function_name}' (declared in FUNCTION)"
|
||||
return None
|
||||
except Exception as ex:
|
||||
return str(ex)
|
||||
|
||||
async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]):
|
||||
outputs = set()
|
||||
for x in prompt:
|
||||
@ -1174,35 +1148,6 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[
|
||||
}
|
||||
return (False, error, [], {})
|
||||
|
||||
# Make sure the node is actually executable (its FUNCTION/execute entry
|
||||
# point resolves to a real method) before we touch any schema-derived
|
||||
# attributes below or start execution. Catches code typos up front and
|
||||
# attributes the error to the offending node.
|
||||
not_executable = node_not_executable_reason(class_, class_type)
|
||||
if not_executable is not None:
|
||||
node_title = prompt[x].get('_meta', {}).get('title', class_type)
|
||||
error = {
|
||||
"type": "invalid_node_definition",
|
||||
"message": "Node is not executable",
|
||||
"details": f"{not_executable} (Node ID '#{x}')",
|
||||
"extra_info": {
|
||||
"node_id": x,
|
||||
"class_type": class_type,
|
||||
"node_title": node_title,
|
||||
}
|
||||
}
|
||||
node_errors = {x: {
|
||||
"errors": [{
|
||||
"type": "invalid_node_definition",
|
||||
"message": "Node is not executable",
|
||||
"details": not_executable,
|
||||
"extra_info": {},
|
||||
}],
|
||||
"dependent_outputs": [],
|
||||
"class_type": class_type,
|
||||
}}
|
||||
return (False, error, [], node_errors)
|
||||
|
||||
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
|
||||
if partial_execution_list is None or x in partial_execution_list:
|
||||
outputs.add(x)
|
||||
|
||||
4
main.py
4
main.py
@ -403,7 +403,7 @@ def prompt_worker(q, server_instance):
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
if not asset_seeder.is_disabled():
|
||||
asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||
asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=args.enable_asset_hashing)
|
||||
asset_seeder.resume()
|
||||
|
||||
|
||||
@ -458,7 +458,7 @@ def setup_database():
|
||||
if dependencies_available():
|
||||
init_db()
|
||||
if args.enable_assets:
|
||||
if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=True):
|
||||
if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=args.enable_asset_hashing):
|
||||
logging.info("Background asset scan initiated for models, input, output")
|
||||
except Exception as e:
|
||||
if "database is locked" in str(e):
|
||||
|
||||
36
nodes.py
36
nodes.py
@ -159,6 +159,29 @@ class ConditioningConcat:
|
||||
|
||||
return (out, )
|
||||
|
||||
class ConditioningMultiply:
|
||||
SEARCH_ALIASES = ["scale conditioning", "scale prompt", "multiply conditioning", "multiply prompt"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {"conditioning": ("CONDITIONING", ),
|
||||
"multiplier": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01})
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "multiply"
|
||||
CATEGORY = "model/conditioning/transform"
|
||||
|
||||
def multiply(self, conditioning, multiplier):
|
||||
c = []
|
||||
for t in conditioning:
|
||||
values = {}
|
||||
pooled_output = t[1].get("pooled_output", None)
|
||||
if pooled_output is not None:
|
||||
values["pooled_output"] = pooled_output * multiplier
|
||||
scaled = node_helpers.conditioning_set_values([[t[0] * multiplier, t[1]]], values)[0]
|
||||
c.append(scaled)
|
||||
return (c,)
|
||||
|
||||
class ConditioningSetArea:
|
||||
SEARCH_ALIASES = ["regional prompt", "area prompt", "spatial conditioning", "localized prompt"]
|
||||
|
||||
@ -326,7 +349,7 @@ class VAEDecodeTiled:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "decode"
|
||||
|
||||
CATEGORY = "experimental"
|
||||
CATEGORY = "model/latent"
|
||||
|
||||
def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
|
||||
if tile_size < overlap * 4:
|
||||
@ -373,7 +396,7 @@ class VAEEncodeTiled:
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "experimental"
|
||||
CATEGORY = "model/latent"
|
||||
|
||||
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
|
||||
t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
|
||||
@ -491,7 +514,7 @@ class SaveLatent:
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "experimental"
|
||||
CATEGORY = "model/latent"
|
||||
|
||||
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
@ -536,7 +559,7 @@ class LoadLatent:
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
|
||||
return {"required": {"latent": [sorted(files), ]}, }
|
||||
|
||||
CATEGORY = "experimental"
|
||||
CATEGORY = "model/latent"
|
||||
|
||||
RETURN_TYPES = ("LATENT", )
|
||||
FUNCTION = "load"
|
||||
@ -2050,6 +2073,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ConditioningAverage": ConditioningAverage,
|
||||
"ConditioningCombine": ConditioningCombine,
|
||||
"ConditioningConcat": ConditioningConcat,
|
||||
"ConditioningMultiply": ConditioningMultiply,
|
||||
"ConditioningSetArea": ConditioningSetArea,
|
||||
"ConditioningSetAreaPercentage": ConditioningSetAreaPercentage,
|
||||
"ConditioningSetAreaStrength": ConditioningSetAreaStrength,
|
||||
@ -2121,6 +2145,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ConditioningAverage ": "Conditioning (Average)",
|
||||
"ConditioningAverage": "Conditioning (Average)",
|
||||
"ConditioningConcat": "Conditioning (Concat)",
|
||||
"ConditioningMultiply": "Conditioning (Multiply)",
|
||||
"ConditioningSetArea": "Conditioning (Set Area)",
|
||||
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
|
||||
"ConditioningSetAreaStrength": "Conditioning (Set Area Strength)",
|
||||
@ -2130,6 +2155,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"GLIGENTextBoxApply": "Apply GLIGEN Text Box",
|
||||
"ConditioningZeroOut": "Conditioning Zero Out",
|
||||
# Latent
|
||||
"LoadLatent": "Load Latent",
|
||||
"SaveLatent": "Save Latent",
|
||||
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
|
||||
"SetLatentNoiseMask": "Set Latent Noise Mask",
|
||||
"VAEDecode": "VAE Decode",
|
||||
@ -2164,7 +2191,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ImageSharpen": "Sharpen Image",
|
||||
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
|
||||
"GetImageSize": "Get Image Size",
|
||||
# experimental
|
||||
"VAEDecodeTiled": "VAE Decode (Tiled)",
|
||||
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.26.0"
|
||||
version = "0.27.0"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
comfyui-frontend-package==1.45.19
|
||||
comfyui-workflow-templates==0.10.7
|
||||
comfyui-embedded-docs==0.5.5
|
||||
comfyui-frontend-package==1.45.20
|
||||
comfyui-workflow-templates==0.11.1
|
||||
comfyui-embedded-docs==0.5.6
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
@ -22,7 +22,7 @@ alembic
|
||||
SQLAlchemy>=2.0.0
|
||||
filelock
|
||||
av>=16.0.0
|
||||
comfy-kitchen==0.2.12
|
||||
comfy-kitchen==0.2.16
|
||||
comfy-aimdo==0.4.10
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
@ -33,5 +33,5 @@ kornia>=0.7.1
|
||||
spandrel
|
||||
pydantic~=2.0
|
||||
pydantic-settings~=2.0
|
||||
PyOpenGL
|
||||
glfw
|
||||
PyOpenGL>=3.1.8
|
||||
comfy-angle
|
||||
|
||||
15
server.py
15
server.py
@ -16,6 +16,8 @@ from comfy_execution.jobs import (
|
||||
cancel_job,
|
||||
CANCEL_PENDING,
|
||||
CANCEL_RUNNING,
|
||||
parse_ids_filter,
|
||||
JobIdsFilterError,
|
||||
)
|
||||
import uuid
|
||||
import urllib
|
||||
@ -791,6 +793,7 @@ class PromptServer():
|
||||
Query parameters:
|
||||
status: Filter by status (comma-separated): pending, in_progress, completed, failed
|
||||
workflow_id: Filter by workflow ID
|
||||
ids: Filter by job id (comma-separated UUIDs, max 100)
|
||||
sort_by: Sort field: created_at (default), execution_duration
|
||||
sort_order: Sort direction: asc, desc (default)
|
||||
limit: Max items to return (positive integer)
|
||||
@ -800,6 +803,7 @@ class PromptServer():
|
||||
|
||||
status_param = query.get('status')
|
||||
workflow_id = query.get('workflow_id')
|
||||
ids_param = query.get('ids')
|
||||
sort_by = query.get('sort_by', 'created_at').lower()
|
||||
sort_order = query.get('sort_order', 'desc').lower()
|
||||
|
||||
@ -813,6 +817,16 @@ class PromptServer():
|
||||
status=400
|
||||
)
|
||||
|
||||
# Optional batch filter: narrow the result to a known set of job ids
|
||||
# (e.g. polling a submitted batch in one request). Parsing/validation
|
||||
# lives in parse_ids_filter so this handler and its tests share one
|
||||
# implementation. Absent => no filter; present-but-empty (`?ids=`,
|
||||
# `?ids=,,`) => zero matches, not "everything".
|
||||
try:
|
||||
ids_filter = parse_ids_filter(ids_param)
|
||||
except JobIdsFilterError as e:
|
||||
return web.json_response(e.payload, status=400)
|
||||
|
||||
if sort_by not in {'created_at', 'execution_duration'}:
|
||||
return web.json_response(
|
||||
{"error": "sort_by must be 'created_at' or 'execution_duration'"},
|
||||
@ -864,6 +878,7 @@ class PromptServer():
|
||||
running, queued, history,
|
||||
status_filter=status_filter,
|
||||
workflow_id=workflow_id,
|
||||
ids=ids_filter,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
limit=limit,
|
||||
|
||||
@ -1,137 +0,0 @@
|
||||
"""Tests for pre-execution validation that a node is actually executable.
|
||||
|
||||
validate_prompt rejects a node whose declared entry point does not resolve to a
|
||||
real method (a V1 FUNCTION typo, or a V3 node missing its execute override) before
|
||||
any node runs, attributing the error to the offending node.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
import nodes
|
||||
from comfy_api.latest import io
|
||||
from execution import node_not_executable_reason, validate_prompt
|
||||
|
||||
|
||||
class _GoodV1Node:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "run"
|
||||
OUTPUT_NODE = True
|
||||
CATEGORY = "Test"
|
||||
|
||||
def run(self):
|
||||
return (None,)
|
||||
|
||||
|
||||
class _TypoV1Node:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "invert" # method below is misspelled
|
||||
OUTPUT_NODE = True
|
||||
CATEGORY = "Test"
|
||||
|
||||
def invvert(self):
|
||||
return (None,)
|
||||
|
||||
|
||||
class _SideEffectInitV1Node:
|
||||
"""Valid class-level method, but a constructor that must never run in validation."""
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "run"
|
||||
OUTPUT_NODE = True
|
||||
CATEGORY = "Test"
|
||||
|
||||
def __init__(self):
|
||||
raise RuntimeError("__init__ must not run during validation")
|
||||
|
||||
def run(self):
|
||||
return (None,)
|
||||
|
||||
|
||||
def _v3_schema(node_id):
|
||||
return io.Schema(
|
||||
node_id=node_id,
|
||||
display_name=node_id,
|
||||
category="Test",
|
||||
inputs=[],
|
||||
outputs=[io.Image.Output()],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
|
||||
class _GoodV3Node(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return _v3_schema("GoodV3Node")
|
||||
|
||||
@classmethod
|
||||
def execute(cls):
|
||||
return io.NodeOutput(None)
|
||||
|
||||
|
||||
class _TypoV3Node(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return _v3_schema("TypoV3Node")
|
||||
|
||||
@classmethod
|
||||
def exicute(cls): # typo: should be "execute"
|
||||
return io.NodeOutput(None)
|
||||
|
||||
|
||||
def _register(class_type, class_def):
|
||||
nodes.NODE_CLASS_MAPPINGS[class_type] = class_def
|
||||
|
||||
|
||||
def _validate(class_type):
|
||||
prompt = {"1": {"class_type": class_type, "inputs": {}}}
|
||||
return asyncio.run(validate_prompt("pid", prompt, None))
|
||||
|
||||
|
||||
def test_good_node_passes():
|
||||
_register("GoodV1Node", _GoodV1Node)
|
||||
assert node_not_executable_reason(_GoodV1Node, "GoodV1Node") is None
|
||||
valid, _, _, _ = _validate("GoodV1Node")
|
||||
assert valid is True
|
||||
|
||||
|
||||
def test_typo_node_rejected_with_node_error():
|
||||
_register("TypoV1Node", _TypoV1Node)
|
||||
valid, error, _, node_errors = _validate("TypoV1Node")
|
||||
assert valid is False
|
||||
assert error["type"] == "invalid_node_definition"
|
||||
assert node_errors["1"]["class_type"] == "TypoV1Node"
|
||||
assert node_errors["1"]["errors"][0]["type"] == "invalid_node_definition"
|
||||
assert "invert" in node_errors["1"]["errors"][0]["details"]
|
||||
|
||||
|
||||
def test_validation_does_not_instantiate_node():
|
||||
"""A valid node is not constructed during validation, so __init__ never runs."""
|
||||
_register("SideEffectInitV1Node", _SideEffectInitV1Node)
|
||||
assert node_not_executable_reason(_SideEffectInitV1Node, "SideEffectInitV1Node") is None
|
||||
valid, _, _, _ = _validate("SideEffectInitV1Node")
|
||||
assert valid is True
|
||||
|
||||
|
||||
def test_good_v3_node_passes():
|
||||
_register("GoodV3Node", _GoodV3Node)
|
||||
assert node_not_executable_reason(_GoodV3Node, "GoodV3Node") is None
|
||||
valid, _, _, _ = _validate("GoodV3Node")
|
||||
assert valid is True
|
||||
|
||||
|
||||
def test_typo_v3_node_rejected_with_node_error():
|
||||
_register("TypoV3Node", _TypoV3Node)
|
||||
valid, error, _, node_errors = _validate("TypoV3Node")
|
||||
assert valid is False
|
||||
assert error["type"] == "invalid_node_definition"
|
||||
assert node_errors["1"]["errors"][0]["type"] == "invalid_node_definition"
|
||||
0
tests-unit/jobs_list_test/__init__.py
Normal file
0
tests-unit/jobs_list_test/__init__.py
Normal file
277
tests-unit/jobs_list_test/jobs_list_test.py
Normal file
277
tests-unit/jobs_list_test/jobs_list_test.py
Normal file
@ -0,0 +1,277 @@
|
||||
"""Tests for the ``ids`` batch filter on the jobs listing endpoint.
|
||||
|
||||
Covers both layers:
|
||||
|
||||
* the pure ``comfy_execution.jobs.get_all_jobs`` filtering logic (the ``ids``
|
||||
argument narrows the result, composes with ``status_filter``, and silently
|
||||
ignores ids that match nothing), and
|
||||
|
||||
* the HTTP contract of ``GET /api/jobs`` for the ``ids`` query parameter
|
||||
(a valid set narrows the response, an oversized set or a malformed id is
|
||||
rejected with 400).
|
||||
|
||||
The HTTP layer is exercised against a small aiohttp app whose handler calls the
|
||||
SAME ``parse_ids_filter`` that ``server.py`` uses (no hand-copied wiring, so it
|
||||
cannot drift), driven by a fake queue. This keeps the test free of the heavy
|
||||
ComfyUI runtime (torch, nodes, ...) while still testing the real parsing
|
||||
contract.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from comfy_execution.jobs import (
|
||||
JobStatus,
|
||||
JobIdsFilterError,
|
||||
MAX_JOB_IDS_FILTER,
|
||||
get_all_jobs,
|
||||
parse_ids_filter,
|
||||
)
|
||||
|
||||
# Canonical UUID ids (the endpoint validates UUID format).
|
||||
_UUID_A = "aaaaaaaa-aaaa-4aaa-aaaa-aaaaaaaaaaaa"
|
||||
_UUID_B = "bbbbbbbb-bbbb-4bbb-bbbb-bbbbbbbbbbbb"
|
||||
_UUID_C = "cccccccc-cccc-4ccc-cccc-cccccccccccc"
|
||||
_UUID_MISSING = "ffffffff-ffff-4fff-ffff-ffffffffffff"
|
||||
|
||||
|
||||
def make_queue_item(prompt_id, priority=0):
|
||||
"""Build a queue tuple shaped like the real ones (5 elements, id at index 1)."""
|
||||
return (priority, prompt_id, {}, {}, [])
|
||||
|
||||
|
||||
def make_history_item(status_str="success"):
|
||||
"""Build a history item dict shaped like the real ones."""
|
||||
return {
|
||||
"prompt": (0, "", {}, {}, []),
|
||||
"status": {"status_str": status_str, "messages": []},
|
||||
"outputs": {},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure get_all_jobs filtering logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_ids_filter_returns_only_requested():
|
||||
running = [make_queue_item(_UUID_A)]
|
||||
queued = [make_queue_item(_UUID_B)]
|
||||
history = {_UUID_C: make_history_item()}
|
||||
|
||||
jobs, total = get_all_jobs(running, queued, history, ids=[_UUID_A, _UUID_C])
|
||||
|
||||
returned = {j["id"] for j in jobs}
|
||||
assert returned == {_UUID_A, _UUID_C}
|
||||
assert total == 2
|
||||
assert _UUID_B not in returned
|
||||
|
||||
|
||||
def test_ids_filter_absent_returns_all():
|
||||
running = [make_queue_item(_UUID_A)]
|
||||
queued = [make_queue_item(_UUID_B)]
|
||||
history = {_UUID_C: make_history_item()}
|
||||
|
||||
jobs, total = get_all_jobs(running, queued, history)
|
||||
|
||||
assert {j["id"] for j in jobs} == {_UUID_A, _UUID_B, _UUID_C}
|
||||
assert total == 3
|
||||
|
||||
|
||||
def test_ids_filter_empty_list_returns_none():
|
||||
"""A present-but-empty ids list is a zero-match filter, not "no filter".
|
||||
|
||||
``None`` means "no id filter"; ``[]`` means "restrict to nothing".
|
||||
"""
|
||||
running = [make_queue_item(_UUID_A)]
|
||||
queued = [make_queue_item(_UUID_B)]
|
||||
|
||||
jobs, total = get_all_jobs(running, queued, {}, ids=[])
|
||||
|
||||
assert jobs == []
|
||||
assert total == 0
|
||||
|
||||
|
||||
def test_ids_filter_unknown_id_silently_absent():
|
||||
"""An id that matches nothing is simply not present (no error)."""
|
||||
running = [make_queue_item(_UUID_A)]
|
||||
|
||||
jobs, total = get_all_jobs(running, [], {}, ids=[_UUID_A, _UUID_MISSING])
|
||||
|
||||
assert {j["id"] for j in jobs} == {_UUID_A}
|
||||
assert total == 1
|
||||
|
||||
|
||||
def test_ids_filter_composes_with_status():
|
||||
"""ids only narrows; it composes with the status filter."""
|
||||
running = [make_queue_item(_UUID_A)]
|
||||
queued = [make_queue_item(_UUID_B)]
|
||||
history = {_UUID_C: make_history_item()}
|
||||
|
||||
# Request A and C by id, but restrict to in_progress only -> just A.
|
||||
jobs, total = get_all_jobs(
|
||||
running, queued, history,
|
||||
status_filter=[JobStatus.IN_PROGRESS],
|
||||
ids=[_UUID_A, _UUID_C],
|
||||
)
|
||||
|
||||
assert {j["id"] for j in jobs} == {_UUID_A}
|
||||
assert total == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_ids_filter -- the shared parsing/validation (server.py + these tests)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_parse_ids_absent_is_none():
|
||||
assert parse_ids_filter(None) is None
|
||||
|
||||
|
||||
def test_parse_ids_present_but_empty_is_empty_list():
|
||||
# `?ids=` and `?ids=,,` parse to [] -> zero-match filter, not None.
|
||||
assert parse_ids_filter("") == []
|
||||
assert parse_ids_filter(",,") == []
|
||||
|
||||
|
||||
def test_parse_ids_dedupes_preserving_order():
|
||||
assert parse_ids_filter(f"{_UUID_A},{_UUID_B},{_UUID_A}") == [_UUID_A, _UUID_B]
|
||||
|
||||
|
||||
def test_parse_ids_cap_counts_distinct_not_duplicates():
|
||||
# A small distinct set repeated far past the cap is still under it.
|
||||
repeated = ",".join([_UUID_A, _UUID_B] * MAX_JOB_IDS_FILTER)
|
||||
assert parse_ids_filter(repeated) == [_UUID_A, _UUID_B]
|
||||
# But more than MAX distinct ids is rejected.
|
||||
distinct = ",".join(
|
||||
f"{i:08d}-0000-4000-8000-000000000000" for i in range(MAX_JOB_IDS_FILTER + 1)
|
||||
)
|
||||
with pytest.raises(JobIdsFilterError):
|
||||
parse_ids_filter(distinct)
|
||||
|
||||
|
||||
def test_parse_ids_invalid_raises_with_payload():
|
||||
with pytest.raises(JobIdsFilterError) as exc:
|
||||
parse_ids_filter(f"{_UUID_A},not-a-uuid")
|
||||
assert "not-a-uuid" in exc.value.payload["invalid_ids"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTTP contract for the ids query parameter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class FakePromptQueue:
|
||||
"""Minimal stand-in exposing the accessors get_jobs uses."""
|
||||
|
||||
def __init__(self, running=None, queued=None, history=None):
|
||||
self._running = list(running or [])
|
||||
self._queued = list(queued or [])
|
||||
self._history = dict(history or {})
|
||||
|
||||
def get_current_queue_volatile(self):
|
||||
return (list(self._running), list(self._queued))
|
||||
|
||||
def get_history(self):
|
||||
return dict(self._history)
|
||||
|
||||
|
||||
def make_app(prompt_queue):
|
||||
"""Build an aiohttp app whose handler calls the REAL parse_ids_filter.
|
||||
|
||||
No hand-copied parsing wiring, so this test cannot stay green while the
|
||||
shipped parsing in server.py regresses -- both go through parse_ids_filter.
|
||||
"""
|
||||
|
||||
async def get_jobs(request):
|
||||
try:
|
||||
ids_filter = parse_ids_filter(request.rel_url.query.get('ids'))
|
||||
except JobIdsFilterError as e:
|
||||
return web.json_response(e.payload, status=400)
|
||||
|
||||
running, queued = prompt_queue.get_current_queue_volatile()
|
||||
history = prompt_queue.get_history()
|
||||
|
||||
jobs, total = get_all_jobs(running, queued, history, ids=ids_filter)
|
||||
|
||||
return web.json_response({
|
||||
'jobs': jobs,
|
||||
'pagination': {'total': total},
|
||||
})
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get('/api/jobs', get_jobs)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def queue():
|
||||
return FakePromptQueue(
|
||||
running=[make_queue_item(_UUID_A)],
|
||||
queued=[make_queue_item(_UUID_B)],
|
||||
history={_UUID_C: make_history_item()},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_ids_filter_narrows(aiohttp_client, queue):
|
||||
client = await aiohttp_client(make_app(queue))
|
||||
|
||||
resp = await client.get(f"/api/jobs?ids={_UUID_A},{_UUID_C}")
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert {j["id"] for j in body["jobs"]} == {_UUID_A, _UUID_C}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_ids_unknown_id_is_not_an_error(aiohttp_client, queue):
|
||||
client = await aiohttp_client(make_app(queue))
|
||||
|
||||
resp = await client.get(f"/api/jobs?ids={_UUID_A},{_UUID_MISSING}")
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert {j["id"] for j in body["jobs"]} == {_UUID_A}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_ids_over_limit_returns_400(aiohttp_client, queue):
|
||||
client = await aiohttp_client(make_app(queue))
|
||||
|
||||
# Distinct ids past the cap. (Repeats of one id are de-duped and would NOT
|
||||
# trip the cap -- see test_parse_ids_cap_counts_distinct_not_duplicates.)
|
||||
too_many = ",".join(
|
||||
f"{i:08d}-0000-4000-8000-000000000000" for i in range(MAX_JOB_IDS_FILTER + 1)
|
||||
)
|
||||
resp = await client.get(f"/api/jobs?ids={too_many}")
|
||||
assert resp.status == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_ids_invalid_id_returns_400(aiohttp_client, queue):
|
||||
client = await aiohttp_client(make_app(queue))
|
||||
|
||||
resp = await client.get(f"/api/jobs?ids={_UUID_A},not-a-uuid")
|
||||
assert resp.status == 400
|
||||
body = await resp.json()
|
||||
assert "not-a-uuid" in body["invalid_ids"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_ids_absent_returns_all(aiohttp_client, queue):
|
||||
client = await aiohttp_client(make_app(queue))
|
||||
|
||||
resp = await client.get("/api/jobs")
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert {j["id"] for j in body["jobs"]} == {_UUID_A, _UUID_B, _UUID_C}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_ids_present_but_empty_returns_none(aiohttp_client, queue):
|
||||
"""`?ids=` (present but empty) is a zero-match filter, not "return all"."""
|
||||
client = await aiohttp_client(make_app(queue))
|
||||
|
||||
resp = await client.get("/api/jobs?ids=")
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["jobs"] == []
|
||||
Reference in New Issue
Block a user