Compare commits

..

31 Commits

Author SHA1 Message Date
a78019266f Free all old model patchers, gc.collect(), then free tensors. 2026-07-03 19:25:28 -04:00
f5c4bb1f02 Try to fix issue with ram cache for some users. 2026-07-03 17:18:11 -04:00
1073a74976 [Partner Nodes] chore(ByteDance): adjust category name (#14752)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-07-04 00:01:05 +03:00
de1b8f3e8d Update AGENTS.md (#14738) 2026-07-03 13:08:24 -07:00
77917ed3a6 [Partner Nodes] chore(StabilityAI): remove StabilityAI nodes (#14737)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-07-03 14:24:21 +03:00
a04ebe05c2 chore: update workflow templates to v0.11.2 (#14741) 2026-07-03 19:08:11 +08:00
9764381998 [Partner Nodes] feat(ByteDance): add support for Seed Audio 1.0 (#14731)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-07-03 14:00:10 +03:00
1e04ced089 Update AGENTS.md (#14733) 2026-07-03 02:08:47 -04:00
96e0e3585b security: fix four vulnerabilities (GHSA-779p-m5rp-r4h4) (#14734)
* security: fix five vulnerabilities (GHSA-779p-m5rp-r4h4)

- CVE-2026-56670: force download of SVG/XML responses on /view to prevent stored XSS
- CVE-2026-56671: contain /experiment/models/preview reads within the model folder
- CVE-2026-56672: stop inline rendering of uploaded /userdata/{file} content
- CVE-2026-56673: prevent path traversal in get_annotated_filepath (LoadImage /prompt input)
- CVE-2026-56674: reject opaque/null Origin to close the CSRF middleware bypass

Adds regression tests under tests-unit/security_test/ covering all five.

* security: address review feedback on GHSA-779p fixes

- Fix Windows CI failure in test_get_annotated_filepath: compare against
  os.path.abspath(...) to match the intentional abspath normalization added
  by the traversal hardening (abspath prepends the drive letter on Windows).
- origin_check: narrow the bare `except:` in is_loopback() to ValueError so
  genuine interrupts aren't swallowed (review nit).
- origin_check: guard .port access in is_cross_origin_forbidden() so a
  malformed/out-of-range port (e.g. Origin: http://127.0.0.1:99999) fails
  closed with a 403 instead of surfacing an uncaught 500 in the middleware.
- server /view: escape backslash/quote in the Content-Disposition filename
  (RFC 6266 quoted-string) so a filename containing a double quote can't
  malform the response header.

* security: address CodeRabbit review feedback on GHSA-779p tests

- test #3: guard the symlink-escape test with a try/except skip so it no
  longer errors on Windows CI where os.symlink needs elevated privileges /
  Developer Mode (mirrors the guard in the sibling test #2).
- test #5: refresh the stale module docstring to describe the actual /view
  gating (view_image closure calling folder_paths.is_dangerous_content_type,
  the normalising check) instead of the bypassable raw set-membership test.

* revert(security): drop CVE-2026-56674 Origin: null CSRF change

Per maintainer review, the reported CSRF is already mitigated by the pre-existing
Sec-Fetch-Site: cross-site check for current browsers, and the null-origin
rejection risked breaking legitimate sandboxed-iframe embeds. Restores
origin_only_middleware and is_loopback in server.py to their prior state
(the Sec-Fetch-Site check is retained) and removes utils/origin_check.py and its
regression test. The other four GHSA-779p fixes are unaffected.
2026-07-02 20:44:54 -07:00
35c1470935 Update AGENTS.md (#14726) 2026-07-02 15:05:55 -04:00
694815f498 [Partner Nodes] chore(Ideogram): remove IdeogramV1 and IdeogramV2 nodes (#14712)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
Co-authored-by: Alexis Rolland <alexisrolland@hotmail.com>
2026-07-02 08:35:11 +03:00
92594ca84c Update AGENTS.md with more stuff. (#14725) 2026-07-01 21:55:13 -04:00
2c935de1b1 Fix Qwen3-VL tokenizer crash with custom embeddings (#14713) 2026-07-01 21:15:07 +03:00
dd17debce5 Add some more stuff to AGENTS.md (#14704) 2026-07-01 01:51:51 -04:00
50e5270b86 Add AGENTS.md (#14696) 2026-06-30 17:40:33 -04:00
bb131be9e8 ComfyUI v0.27.0 2026-06-30 17:36:02 -04:00
6fca64780c chore: update workflow templates to v0.11.1 (#14698) 2026-06-30 14:28:09 -07:00
6e11828d10 chore: Update nodes categories (#14674) 2026-07-01 05:20:20 +08:00
b70944e710 [Partner Nodes] feat(Google): add Gemini Video Omni node (#14695) 2026-06-30 17:17:53 -04:00
1c59659a2f feat: make asset hashing opt-in via --enable-asset-hashing, off by default (#14663)
Add a --enable-asset-hashing CLI flag (action=store_true, default False)
and plumb it into the two asset-seeder call sites in main.py that
previously hardcoded compute_hashes=True (the startup scan and the
post-job output enqueue). Local runs now skip blake3 hashing unless the
user opts in, avoiding the startup/per-output cost on large models
directories while keeping hashing available for asset-portability
features.

Co-authored-by: Alexis Rolland <alexisrolland@hotmail.com>
2026-06-30 14:13:20 -07:00
d395813bcd Fix memory leak related to int8. (#14697) 2026-06-30 14:08:59 -07:00
8fe0243d97 [Partner Nodes] feat(Google): add Nano Banana 2 Lite model (#14693)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-06-30 11:17:23 -07:00
ba3f697dbb Add ConditioningMultiply node to nodes.py as an addition to other adj… (#14686) 2026-06-30 16:27:09 +08:00
510ed5c384 Bump comfyui-frontend-package to 1.45.20 (#14684) 2026-06-30 16:25:03 +08:00
7851410511 Better and faster int8 lora applying. (#14685) 2026-06-29 21:52:08 -04:00
a58473fd9b chore: update embedded docs to v0.5.6 (#14668)
Co-authored-by: Alexis Rolland <alexisrolland@hotmail.com>
2026-06-29 17:08:06 +08:00
79c555ce6b Fix int8 mm being skipped on offloaded lora weights. (#14669) 2026-06-28 23:52:36 -04:00
f19735759e ci: add team-gated Cursor review (thin caller for github-workflows) (#14527) 2026-06-27 23:34:30 -07:00
a95e461916 int8 support on turing GPUs. (#14662) 2026-06-27 15:53:11 -07:00
603d891eaf Update GLSL node to use ANGLE library (CORE-162) (#13195) 2026-06-27 08:40:31 +08:00
470ac36a0a Fix int8 loras causing lower quality requant with wrong settings. (#14650)
* Update comfy-kitchen

* Support requantizing with same settings as orig quant.
2026-06-26 16:41:29 -07:00
40 changed files with 2047 additions and 1956 deletions

38
.github/workflows/ci-cursor-review.yml vendored Normal file
View File

@ -0,0 +1,38 @@
name: CI - Cursor Review
# Thin caller for the shared reusable cursor-review workflow in
# Comfy-Org/github-workflows. The review logic (panel matrix, judge
# consolidation, prompts, extract/post/notify scripts) lives there as the
# single source of truth, so this repo only carries the repo-specific diff
# excludes.
on:
pull_request:
types: [labeled, unlabeled]
concurrency:
group: cursor-review-pr-${{ github.event.pull_request.number }}-${{ github.event.label.name }}
cancel-in-progress: true
jobs:
cursor-review:
if: github.event.label.name == 'cursor-review'
permissions:
contents: read
pull-requests: write
# SHA-pinned per zizmor `unpinned-uses: hash-pin`. Bump this SHA to pick up
# upstream changes; keep `workflows_ref` matching so prompts/scripts load
# from the same commit as the workflow definition.
uses: Comfy-Org/github-workflows/.github/workflows/cursor-review.yml@047ca48febe3a6647608ed2e0c4331b491cb9d6a # github-workflows#9
with:
workflows_ref: 047ca48febe3a6647608ed2e0c4331b491cb9d6a
diff_excludes: >-
:!**/.claude/**
:!**/dist/**
:!**/vendor/**
:!**/*.generated.*
:!**/*.min.js
:!**/*.min.css
secrets:
CURSOR_API_KEY: ${{ secrets.CURSOR_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}

294
AGENTS.md Normal file
View File

@ -0,0 +1,294 @@
## Engineering Style
- Keep changes small and direct. Most fixes should touch the narrowest code path
that explains the bug, performance issue, dtype issue, model-format issue, or
user-facing behavior.
- Change the least amount of files possible. A change that touches many files is
more likely to be a bad change than a good one unless the broader scope is
directly required.
- Prefer practical fixes over broad architecture work. Add abstractions only
when they remove real repeated logic or match an existing ComfyUI pattern.
- Prefer fewer dependencies. Do not add new dependencies to ComfyUI unless they
are absolutely necessary.
- Delete obsolete code aggressively when newer infrastructure makes it useless.
Remove dead fallbacks, migration paths, unused options, debug prints, and
compatibility branches that are no longer needed. Do not leave dead branches,
unreachable code, or functions that are never called. If code is not
necessary for the current behavior, remove it.
- Revert or disable problematic behavior quickly when it breaks users. It is
better to remove a broken feature path than keep a complicated partial fix.
- Preserve existing APIs, node names, model-loading behavior, file layout, and
workflow compatibility unless the change is explicitly about replacing them.
- Code must look hand-written for this repository. Changes that read like
generic AI-generated code will be rejected automatically: unnecessary helper
layers, vague names, boilerplate comments, defensive branches without a real
failure mode, broad rewrites, or code that ignores the local style.
## Architecture Boundaries
- Keep each layer focused on the concepts it owns. Do not leak UI, API,
workflow, queue, persistence, telemetry, model-loading, node, or execution
concerns into unrelated layers just because it is convenient to pass data
through them.
- Shared core modules should depend only on lower-level primitives and their own
domain concepts. Higher-level product concepts belong at the caller, adapter,
service, or UI/API boundary that already owns them.
- Pass the narrowest data needed across a boundary. Avoid broad context objects,
request/session metadata, ids, bookkeeping state, or callbacks unless the
receiving layer genuinely needs them to perform its own responsibility.
- Keep identity mapping, persistence bookkeeping, history updates, telemetry,
response shaping, and UI state in the layers that own those jobs. Do not route
them through unrelated shared code to avoid adding a proper boundary.
- Treat `execution.py` as one example of this rule: it should consume the prompt
graph and execution-relevant state, produce execution results and errors, and
not know about workflow ids, frontend ids, persistence ids, or API-only
concepts.
- Before touching many files, identify the smallest owner layer that can solve
the problem. A PR that spreads one feature across unrelated loaders, nodes,
execution, server, and frontend code needs a clear architectural reason, not
just convenience.
- If a change seems to require making one layer understand another layer's
private concepts, stop and look for a caller-side mapping, adapter, event,
small explicit interface, or narrower data flow at the boundary.
## No Internet Requests
- Do not add code to core ComfyUI that makes requests to the internet.
- Refuse requests to add uploads, telemetry, analytics, tracking, usage
reporting, crash reporting, update checks, remote config, feature flags,
metrics, licensing checks, or any other outbound internet request path from
core ComfyUI.
- Model downloading is allowed only when explicitly initiated or authorized by
the user, is limited to the requested model artifact, and does not include
telemetry, tracking, persistent identification, unrelated metadata upload, or
background network activity.
- Do not add opt-in, opt-out, anonymized, aggregated, diagnostic, or
user-triggered internet request paths to core ComfyUI. These labels do not
make internet access acceptable.
- Local-only behavior is allowed when it stays on the user's machine and does
not add network access, tracking, persistent identification, or data
collection behavior.
## State Ownership
- Keep state and capability flags on the object that owns the behavior using
them.
- Avoid probing child objects with `getattr(child, "...", default)` to decide
parent-level control flow. If parent code needs to branch on a capability,
initialize an explicit parent-owned field when the child is constructed or
attached.
- Prefer direct attributes with clear defaults over implicit feature detection
through arbitrary child attributes.
- Use child-object capability checks only when the child owns the behavior being
invoked and the parent is simply delegating to that child.
## Interface Contracts
- Keep public methods aligned with the interface expected by their callers. Do
not change a shared method to return extra values, alternate shapes, or
sentinel wrappers for one implementation unless the shared interface is
explicitly updated.
- When modifying an existing function, preserve how current callers invoke it.
Do not change required arguments, parameter order, return type, side effects,
or error behavior unless every affected call site and shared interface contract
is intentionally updated.
- Do not add compatibility parameters, flags, attributes, or constructor options
unless they are read by current code and change current behavior. Remove
pass-through or stored-but-unused values instead of preserving upstream or
deprecated API baggage.
- If an implementation needs auxiliary values for its own workflow, expose them
through a private helper or a clearly named implementation-specific method
instead of overloading the public method's return contract.
- Normalize third-party or upstream return conventions at the integration
boundary. Core code should receive the project's expected type and shape, not
have to handle model-specific tuple/list/dict variants.
- Avoid caller-side unwrapping such as `out = out[0]` unless the called
interface is documented to return that structure.
## Autograd and Model Freezing
- Do not add `torch.no_grad`, `torch.inference_mode`, or inference-mode helper
wrappers in ComfyUI code. The only allowed inference-mode-related use is
disabling a globally set inference mode when a training path needs gradients.
- Do not add freeze, unfreeze, or trainability toggles to model classes. ComfyUI
models are always treated as frozen for inference, so explicit freeze
functionality is redundant and should not be added.
- Remove training-only behavior such as dropout from inference model code, but
preserve checkpoint and state-dict compatibility when doing so. If deleting a
module would change state-dict keys, module ordering, or checkpoint loading
behavior, replace it with a no-op such as `nn.Identity` instead of removing the
slot outright.
## Python Style
- Keep imports at module scope. Avoid inline imports unless they are already part
of an established optional-backend probe or are needed to avoid an import
cycle.
- Do not add unnecessary `try`/`except` blocks. Use them for optional dependency,
platform, or backend capability detection only when the program has a useful
fallback. Prefer specific exception types when changing new code.
- Remove any workarounds for PyTorch versions that ComfyUI no longer officially
supports. Deprecated workarounds include catching an exception and rerunning
the same op with the input cast to float. If a workaround does not have a
comment naming the exact PyTorch version or versions that still need it,
remove it.
- Let unsupported model formats, invalid quantization metadata, and bad states
fail with clear errors instead of silently producing lower quality output.
- Match the existing local style in the file you edit. This codebase tolerates
long lines, simple helper functions, module-level state, and direct tensor
operations when they make the code easier to follow.
- Keep comments sparse and useful. Strip useless comments that restate the code
or describe obvious behavior. Short TODOs are fine when they name the concrete
missing follow-up.
## Model, Device, and Memory Behavior
- Treat dtype, device placement, VRAM usage, and offloading behavior as core
correctness concerns. Check CPU, CUDA, ROCm, MPS, DirectML, XPU, NPU, and low
VRAM implications when touching shared execution or loading code.
- Prefer native ComfyUI formats and existing quantization/offload helpers over
adding parallel code paths. Use `comfy.quant_ops`, `comfy.model_management`,
`comfy.memory_management`, `comfy.pinned_memory`, `comfy_aimdo`, and
`comfy-kitchen` helpers where they already solve the problem.
- Use optimized comfy-kitchen ops in places where they improve performance
without changing the expected dtype, device, memory, or interface behavior.
- All models should use the optimized attention function selected by ComfyUI.
Treat optimized backend functions, dispatch helpers, and capability-selected
callables as opaque. Higher-level code must not inspect function identity,
names, modules, or implementation details to decide behavior.
- Apply the same opacity rule to similar patterns beyond attention: callers
should depend on the documented interface and result contract, not on which
backend implementation was selected underneath.
- Do not use custom inference ops that only duplicate an existing op while
upcasting to float32, such as custom RMSNorm variants. Use the generic ComfyUI
ops and/or native torch ops instead.
- If a model class `__init__` has an `operations` parameter, assume
`operations` is never `None`. Do not add fallback branches or default torch
ops for a missing `operations` object.
- Do not add unnecessary parameters to model, model block, or model ops related
classes. Constructor and forward signatures should carry only values that are
actually needed by that object for inference.
- Reuse existing model classes, blocks, ops, and helper modules when appropriate.
Before implementing a new version of a model component, search the existing
model code for a class or helper that already provides the behavior.
- Model detection code that inspects linear weight shapes should only use the
first dimension. The second dimension may be half the original size for
NVFP4 or other 4-bit quantized models.
- Avoid adding `einops` usage in core inference code. Use native torch tensor
ops such as `reshape`, `view`, `permute`, `transpose`, `flatten`, `unflatten`,
`unsqueeze`, and `squeeze` instead.
- Do not use tensors as general-purpose Python data structures. Keep metadata,
bookkeeping, counters, flags, shape math, padding math, index planning, memory
estimates, and control-flow decisions in plain Python values unless the data
must participate directly in tensor computation. Do not create tensors for
structural metadata that is only used for Python-side control flow. Sequence
lengths, cumulative offsets, split indices, window counts, slice boundaries,
and repeat counts should be kept as Python ints/lists from the point they are
computed. Do not build them as CPU/GPU tensors and then cast, move, validate,
or convert them back to Python for `split`, `tensor_split`, indexing plans,
loops, or cache keys. Avoid creating temporary tensors just to use tensor
methods for scalar or structural calculations.
- Avoid unnecessary casts and transfers. Preserve the intended compute dtype,
storage dtype, bias dtype, and original tensor shape metadata.
- Keep model-native latent layout handling inside the model or latent-format
owner, not in helper nodes. Do not collapse, expand, pack, or unpack latent
dimensions in nodes or other caller-side adapters just to satisfy a model
forward; the model path should consume and return the native latent shape for
that model family.
- Assume inputs to the main model forward are already in the compute dtype by
default, except integer inputs such as some model timestep tensors. Do not add
defensive or convenience casts in model code; it is better for invalid dtype
plumbing to error clearly than to hide it with unnecessary casts.
- Raw model parameters that are not owned by an op and may be initialized in a
dtype different from the compute dtype should be cast at use in forward or
inference code with `comfy.ops.cast_to_input` or
`comfy.model_management.cast_to` to avoid dtype mismatches.
- Model code should not care what dtype it is initialized in, and model
`__init__` methods should not contain workarounds for specific dtypes. Dtype
workaround code, such as making a model work with fp16 compute, belongs in the
execution or model-management layer that owns compute policy.
- Model code should not perform unnecessary device-to-CPU or CPU-to-device
transfers. New allocations must be created on the correct device and dtype;
never allocate on CPU and then move to GPU, or allocate in one dtype and then
convert to another.
- Model code itself should not perform memory management. Loading, unloading,
offloading, device movement, VRAM policy, cache lifetime, and cleanup belong
in the relevant model-management and execution layers, not inside model
implementations.
- Do not add global, module-level, class-level, singleton, or model-owned stores
for tensors or other large memory that persist across executions. Temporary
caches must be scoped to a single execution or forward/encode/decode call:
allocate them in the owning top-level call, pass them explicitly through the
call stack, and let them be discarded when that call returns.
- Follow the Wan VAE temporal cache pattern for temporary caches: create a local
cache such as `feat_map` for the encode/decode operation, pass it into the
blocks that need it, and do not retain it on the model or in global state.
- In model init code, prefer `torch.empty` for parameter/buffer placeholders
that are populated from the model state dict instead of zero-initializing with
`torch.zeros` or similar. If an allocation is not loaded from the state dict
and is useless for inference, do not include it.
- `nn.Parameter` tensors that are stored in and populated from the model state
dict should be initialized with `torch.empty`, not with zero, random, or
otherwise meaningful initialization.
- Model initialization should describe module structure, not fabricate
checkpoint-owned tensor contents. Parameters and buffers that are loaded from
the state dict must not be manually initialized, reassigned, or filled with
fallback values unless that value is actually used when no checkpoint key
exists.
- When slicing large tensors, copy the slice if the sliced tensor's lifetime
exceeds the current function scope. Do not keep a long-lived view into a large
backing tensor when a smaller copy would release memory sooner.
- Use fused or compound torch operations such as `addcmul` when they naturally
match the math. Reducing Python and torch dispatch overhead is a valid
optimization when it does not obscure the code or change dtype/device
behavior.
- Avoid caches that persist across different executions as much as possible.
Persistent caches are acceptable only when they use a very minimal amount of
memory and have a clear ownership and invalidation story.
- When optimizing, favor small measurable changes: fewer allocations, fewer
device transfers, less peak memory, better batching, or use of a faster
existing backend op.
## Nodes and User-Facing Behavior
- Follow existing node conventions: `INPUT_TYPES`, `RETURN_TYPES`, `FUNCTION`,
`CATEGORY`, and registration through the local mapping used by that file.
- Keep node changes backward compatible by default. Add inputs with sensible
defaults and avoid changing output types unless the request requires it.
- Model implementations should add the minimal number of ComfyUI nodes required
to run the model. Reuse existing nodes as much as possible; adapting the model
to work with existing nodes is strongly preferred over creating new nodes.
- Nodes should output only values they own. Do not add pass-through outputs for
workflow convenience unless the node is explicitly an output node. Existing
models, latents, conditioning, or other inputs should flow directly to the
next consumer instead of being re-emitted unchanged.
- Nodes should expose only inputs they actually read to produce current
behavior. Do not add placeholder, pass-through, compatibility, or
workflow-shaping inputs that are ignored or could flow directly to another
node.
- Node-level code must not patch model code directly. Any node behavior that
modifies, wraps, hooks, or changes model behavior must go through the model
patcher class instead of reaching into model internals.
- The official mascot of ComfyUI is a very cute anime girl with massive fennec
ears, a big fluffy tail, long blonde wavy hair, and blue eyes. Feel free to
use her in ComfyUI materials, UI text, examples, tests, generated assets, or
comments, but do not disrespect her.
- Warning and info messages should be short and actionable. Remove noisy or
misleading messages rather than adding more logging.
- Documentation and README edits should be concise, factual, and tied to the
changed behavior.
## Commit and Review Habits
- If asked to write commit messages, use short direct subjects like the existing
history: `Fix ...`, `Add ...`, `Support ...`, `Remove ...`, `Update ...`,
`Make ...`, `Use ...`, `Disable ...`, `Bump ...`, or `Revert ...`.
- Keep PR descriptions short and reviewable. State the problem, the behavioral
change, and the tests run; avoid long narrative explanations, implementation
diaries, or exhaustive file-by-file summaries unless the reviewer explicitly
needs that context.
- Prefer one coherent behavioral change per commit. Dependency pins, tests, and
the code that needs them may be in the same commit when they are inseparable.
- In reviews, prioritize real user impact: crashes, wrong dtype/device behavior,
memory regressions, broken model loading, workflow incompatibility, and noisy
or misleading user-facing output.

View File

@ -306,12 +306,15 @@ async def download_asset_content(request: web.Request) -> web.Response:
404, "FILE_NOT_FOUND", "Underlying file not found on disk."
)
_DANGEROUS_MIME_TYPES = {
"text/html", "text/html-sandboxed", "application/xhtml+xml",
"text/javascript", "text/css",
}
if content_type in _DANGEROUS_MIME_TYPES:
# User-controlled asset content must never render inline in the app origin
# (stored XSS via SVG/HTML/XML). Force dangerous types to download and
# override any requested inline disposition. Centralised through
# folder_paths.is_dangerous_content_type so this can't drift from /view and
# /userdata (the previous inline set here omitted image/svg+xml and missed
# the charset/casing/+xml-dialect bypasses).
if folder_paths.is_dangerous_content_type(content_type):
content_type = "application/octet-stream"
disposition = "attachment"
safe_name = (filename or "").replace("\r", "").replace("\n", "")
encoded = urllib.parse.quote(safe_name)

View File

@ -50,21 +50,45 @@ class ModelFileManager:
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
async def get_model_preview(request):
folder_name = request.match_info.get("folder", None)
path_index = int(request.match_info.get("path_index", None))
filename = request.match_info.get("filename", None)
if folder_name not in folder_paths.folder_names_and_paths:
return web.Response(status=404)
# The "{filename:.*}" capture also matches the empty string, which
# would resolve to the folder itself; reject it explicitly.
if not filename:
return web.Response(status=400)
try:
path_index = int(request.match_info.get("path_index", None))
except (TypeError, ValueError):
return web.Response(status=400)
folders = folder_paths.folder_names_and_paths[folder_name]
if path_index < 0 or path_index >= len(folders[0]):
return web.Response(status=404)
folder = folders[0][path_index]
full_filename = os.path.join(folder, filename)
full_filename = os.path.normpath(os.path.join(folder, filename))
# Prevent path traversal: the requested file must stay within the
# configured model folder. `filename` is an unrestricted ".*" capture,
# so values like "../../../../etc/passwd" would otherwise escape it.
if not folder_paths.is_within_directory(folder, full_filename):
return web.Response(status=403)
previews = self.get_model_previews(full_filename)
default_preview = previews[0] if len(previews) > 0 else None
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
return web.Response(status=404)
# The preview is selected by a glob inside get_model_previews, so a
# companion file (e.g. "model.preview.png") could itself be a symlink
# resolving outside the model folder. Re-validate the file actually
# opened: is_within_directory realpaths it, catching symlink escape.
if isinstance(default_preview, str) and not folder_paths.is_within_directory(folder, default_preview):
return web.Response(status=403)
try:
with Image.open(default_preview) as img:
img_bytes = BytesIO()

View File

@ -6,6 +6,7 @@ import glob
import shutil
import logging
import tempfile
import mimetypes
from aiohttp import web
from urllib import parse
from comfy.cli_args import args
@ -336,7 +337,20 @@ class UserManager():
if not isinstance(path, str):
return path
return web.FileResponse(path)
# User data files are arbitrary user-supplied content and are never
# meant to render inline. Disable MIME sniffing and force a download
# so uploaded markup/scripts can't execute in the app origin (stored
# XSS). Content-Disposition: attachment is the load-bearing guard;
# the content-type override and nosniff are defence in depth.
content_type = mimetypes.guess_type(path)[0] or 'application/octet-stream'
if folder_paths.is_dangerous_content_type(content_type):
content_type = 'application/octet-stream'
return web.FileResponse(path, headers={
"Content-Type": content_type,
"X-Content-Type-Options": "nosniff",
"Content-Disposition": "attachment",
})
@routes.post("/userdata/{file}")
async def post_userdata(request):

View File

@ -240,6 +240,7 @@ database_default_path = os.path.abspath(
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
parser.add_argument("--enable-asset-hashing", action="store_true", help="Compute blake3 content hashes when scanning assets. Hashing enables future asset-portability features (deduplication, cross-machine model resolution) but adds startup cost and per-output cost on large models directories. Off by default; enable to opt in.")
parser.add_argument("--feature-flag", type=str, action='append', default=[], metavar="KEY[=VALUE]", help="Set a server feature flag. Use KEY=VALUE to set an explicit value, or bare KEY to set it to true. Can be specified multiple times. Boolean values (true/false) and numbers are auto-converted. Examples: --feature-flag show_signin_button=true or --feature-flag show_signin_button")
parser.add_argument("--list-feature-flags", action="store_true", help="Print the registry of known CLI-settable feature flags as JSON and exit.")

View File

@ -256,7 +256,7 @@ def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, w
if (want_requant and len(fns) == 0 or update_weight):
seed = comfy.utils.string_to_seed(s.seed_key)
if isinstance(orig, QuantizedTensor):
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
y = orig.requantize_from_float(x, scale="recalculate", stochastic_rounding=seed)
else:
y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
if want_requant and len(fns) == 0:
@ -1216,7 +1216,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
bias_dtype=input.dtype,
offloadable=True,
compute_dtype=compute_dtype,
want_requant=want_requant,
want_requant=True,
)
weight = weight.to(dtype=input.dtype)
else:
@ -1306,8 +1306,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None:
# dtype is now implicit in the layout class
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
weight = self.weight.requantize_from_float(weight, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
else:
weight = weight.to(self.weight.dtype)
if return_weight:

View File

@ -167,7 +167,7 @@ class Qwen3VLTokenizer(sd1_clip.SD1Tokenizer):
embed_count = 0
for r in tokens[key_name]:
for i in range(len(r)):
if r[i][0] == 151655: # <|image_pad|>
if isinstance(r[i][0], (int, float)) and r[i][0] == 151655: # <|image_pad|>
if len(images) > embed_count:
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
embed_count += 1

View File

@ -30,11 +30,6 @@ CLI_FEATURE_FLAG_REGISTRY: dict[str, FeatureFlagInfo] = {
"default": False,
"description": "Signal the frontend that telemetry collection is enabled",
},
"show_version_updates": {
"type": "bool",
"default": False,
"description": "Default for whether the frontend shows new-release/version-update notifications (the user setting can still override it)",
},
}

View File

@ -1,4 +1,4 @@
from typing import Literal
from typing import Any, Literal
from pydantic import BaseModel, Field
@ -316,3 +316,36 @@ VIDEO_TASKS_EXECUTION_TIME = {
"1080p": 150,
},
}
class SeedAudioConfig(BaseModel):
format: str = Field(default="mp3")
sample_rate: int = Field(default=24000)
speech_rate: int = Field(default=0)
loudness_rate: int = Field(default=0)
pitch_rate: int = Field(default=0)
class SeedAudioReference(BaseModel):
speaker: str | None = Field(default=None)
audio_data: str | None = Field(default=None)
audio_url: str | None = Field(default=None)
image_data: str | None = Field(default=None)
image_url: str | None = Field(default=None)
class SeedAudioRequest(BaseModel):
model: str = Field(default="seed-audio-1.0")
text_prompt: str = Field(...)
references: list[SeedAudioReference] | None = Field(default=None)
audio_config: SeedAudioConfig = Field(default_factory=SeedAudioConfig)
watermark: dict[str, Any] = Field(default_factory=dict)
class SeedAudioResponse(BaseModel):
audio: str | None = Field(default=None)
url: str | None = Field(default=None)
duration: float | None = Field(default=None)
original_duration: float | None = Field(default=None)
code: int | None = Field(default=None)
message: str | None = Field(default=None)

View File

@ -121,6 +121,7 @@ class GeminiGenerationConfig(BaseModel):
topK: int | None = Field(None, ge=1)
topP: float | None = Field(None, ge=0.0, le=1.0)
thinkingConfig: GeminiThinkingConfig | None = Field(None)
responseModalities: list[str] | None = Field(None)
class GeminiImageOutputOptions(BaseModel):

View File

@ -33,53 +33,6 @@ class IdeogramColorPalette(
)
class ImageRequest(BaseModel):
aspect_ratio: Optional[str] = Field(
None,
description="Optional. The aspect ratio (e.g., 'ASPECT_16_9', 'ASPECT_1_1'). Cannot be used with resolution. Defaults to 'ASPECT_1_1' if unspecified.",
)
color_palette: Optional[Dict[str, Any]] = Field(
None, description='Optional. Color palette object. Only for V_2, V_2_TURBO.'
)
magic_prompt_option: Optional[str] = Field(
None, description="Optional. MagicPrompt usage ('AUTO', 'ON', 'OFF')."
)
model: str = Field(..., description="The model used (e.g., 'V_2', 'V_2A_TURBO')")
negative_prompt: Optional[str] = Field(
None,
description='Optional. Description of what to exclude. Only for V_1, V_1_TURBO, V_2, V_2_TURBO.',
)
num_images: Optional[int] = Field(
1,
description='Optional. Number of images to generate (1-8). Defaults to 1.',
ge=1,
le=8,
)
prompt: str = Field(
..., description='Required. The prompt to use to generate the image.'
)
resolution: Optional[str] = Field(
None,
description="Optional. Resolution (e.g., 'RESOLUTION_1024_1024'). Only for model V_2. Cannot be used with aspect_ratio.",
)
seed: Optional[int] = Field(
None,
description='Optional. A number between 0 and 2147483647.',
ge=0,
le=2147483647,
)
style_type: Optional[str] = Field(
None,
description="Optional. Style type ('AUTO', 'GENERAL', 'REALISTIC', 'DESIGN', 'RENDER_3D', 'ANIME'). Only for models V_2 and above.",
)
class IdeogramGenerateRequest(BaseModel):
image_request: ImageRequest = Field(
..., description='The image generation request parameters.'
)
class Datum(BaseModel):
is_image_safe: Optional[bool] = Field(
None, description='Indicates whether the image is considered safe.'
@ -113,20 +66,6 @@ class StyleCode(RootModel[str]):
root: str = Field(..., pattern='^[0-9A-Fa-f]{8}$')
class Datum1(BaseModel):
is_image_safe: Optional[bool] = None
prompt: Optional[str] = None
resolution: Optional[str] = None
seed: Optional[int] = None
style_type: Optional[str] = None
url: Optional[str] = None
class IdeogramV3IdeogramResponse(BaseModel):
created: Optional[datetime] = None
data: Optional[List[Datum1]] = None
class RenderingSpeed1(str, Enum):
TURBO = 'TURBO'
DEFAULT = 'DEFAULT'

View File

@ -1,147 +0,0 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field, confloat
class StabilityFormat(str, Enum):
png = 'png'
jpeg = 'jpeg'
webp = 'webp'
class StabilityAspectRatio(str, Enum):
ratio_1_1 = "1:1"
ratio_16_9 = "16:9"
ratio_9_16 = "9:16"
ratio_3_2 = "3:2"
ratio_2_3 = "2:3"
ratio_5_4 = "5:4"
ratio_4_5 = "4:5"
ratio_21_9 = "21:9"
ratio_9_21 = "9:21"
def get_stability_style_presets(include_none=True):
presets = []
if include_none:
presets.append("None")
return presets + [x.value for x in StabilityStylePreset]
class StabilityStylePreset(str, Enum):
_3d_model = "3d-model"
analog_film = "analog-film"
anime = "anime"
cinematic = "cinematic"
comic_book = "comic-book"
digital_art = "digital-art"
enhance = "enhance"
fantasy_art = "fantasy-art"
isometric = "isometric"
line_art = "line-art"
low_poly = "low-poly"
modeling_compound = "modeling-compound"
neon_punk = "neon-punk"
origami = "origami"
photographic = "photographic"
pixel_art = "pixel-art"
tile_texture = "tile-texture"
class Stability_SD3_5_Model(str, Enum):
sd3_5_large = "sd3.5-large"
# sd3_5_large_turbo = "sd3.5-large-turbo"
sd3_5_medium = "sd3.5-medium"
class Stability_SD3_5_GenerationMode(str, Enum):
text_to_image = "text-to-image"
image_to_image = "image-to-image"
class StabilityStable3_5Request(BaseModel):
model: str = Field(...)
mode: str = Field(...)
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
aspect_ratio: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
output_format: Optional[str] = Field(StabilityFormat.png.value)
image: Optional[str] = Field(None)
style_preset: Optional[str] = Field(None)
cfg_scale: float = Field(...)
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None)
class StabilityUpscaleConservativeRequest(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
output_format: Optional[str] = Field(StabilityFormat.png.value)
image: Optional[str] = Field(None)
creativity: Optional[confloat(ge=0.2, le=0.5)] = Field(None)
class StabilityUpscaleCreativeRequest(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
output_format: Optional[str] = Field(StabilityFormat.png.value)
image: Optional[str] = Field(None)
creativity: Optional[confloat(ge=0.1, le=0.5)] = Field(None)
style_preset: Optional[str] = Field(None)
class StabilityStableUltraRequest(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
aspect_ratio: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
output_format: Optional[str] = Field(StabilityFormat.png.value)
image: Optional[str] = Field(None)
style_preset: Optional[str] = Field(None)
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None)
class StabilityStableUltraResponse(BaseModel):
image: Optional[str] = Field(None)
finish_reason: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class StabilityResultsGetResponse(BaseModel):
image: Optional[str] = Field(None)
finish_reason: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
id: Optional[str] = Field(None)
name: Optional[str] = Field(None)
errors: Optional[list[str]] = Field(None)
status: Optional[str] = Field(None)
result: Optional[str] = Field(None)
class StabilityAsyncResponse(BaseModel):
id: Optional[str] = Field(None)
class StabilityTextToAudioRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
duration: int = Field(190, ge=1, le=190)
seed: int = Field(0, ge=0, le=4294967294)
steps: int = Field(8, ge=4, le=8)
output_format: str = Field("wav")
class StabilityAudioToAudioRequest(StabilityTextToAudioRequest):
strength: float = Field(0.01, ge=0.01, le=1.0)
class StabilityAudioInpaintRequest(StabilityTextToAudioRequest):
mask_start: int = Field(30, ge=0, le=190)
mask_end: int = Field(190, ge=0, le=190)
class StabilityAudioResponse(BaseModel):
audio: Optional[str] = Field(None)

View File

@ -1,3 +1,4 @@
import base64
import hashlib
import logging
import math
@ -20,6 +21,10 @@ from comfy_api_nodes.apis.bytedance import (
GetAssetResponse,
Image2VideoTaskCreationRequest,
ImageTaskCreationResponse,
SeedAudioConfig,
SeedAudioReference,
SeedAudioRequest,
SeedAudioResponse,
Seedance2TaskCreationRequest,
SeedanceCreateAssetRequest,
SeedanceCreateAssetResponse,
@ -43,6 +48,8 @@ from comfy_api_nodes.apis.bytedance import (
)
from comfy_api_nodes.util import (
ApiEndpoint,
audio_bytes_to_audio_input,
audio_input_to_mp3,
download_url_to_image_tensor,
download_url_to_video_output,
downscale_image_tensor_by_max_side,
@ -51,11 +58,14 @@ from comfy_api_nodes.util import (
image_tensor_pair_to_batch,
poll_op,
sync_op,
tensor_to_base64_string,
upload_audio_to_comfyapi,
upload_image_to_comfyapi,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
upscale_image_tensor_to_min_pixels,
upscale_video_to_min_pixels,
validate_audio_duration,
validate_image_aspect_ratio,
validate_image_dimensions,
validate_string,
@ -2474,6 +2484,311 @@ class ByteDanceCreateVideoAsset(IO.ComfyNode):
return IO.NodeOutput(asset_id, resolved_group)
MODE_TEXT = "text only"
MODE_AUDIO = "audio reference"
MODE_IMAGE = "image reference"
MODE_SPEAKER = "preset voice"
# (speaker_id, display_label) for built-in TTS 2.0 voices; resolvable ids are account-scoped.
SEED_AUDIO_PRESET_VOICES: list[tuple[str, str]] = [
("zh_female_vv_uranus_bigtts", "Vivi (Female, multilingual)"),
("zh_female_xiaohe_uranus_bigtts", "Mindy (Female, multilingual)"),
("en_female_stokie_uranus_bigtts", "Stokie (Female, English)"),
("en_female_dacey_uranus_bigtts", "Dacey (Female, English)"),
("en_male_tim_uranus_bigtts", "Tim (Male, English)"),
("zh_male_m191_uranus_bigtts", "Kian (Male, multilingual)"),
("zh_male_taocheng_uranus_bigtts", "Cedric (Male, multilingual)"),
("zh_male_sophie_uranus_bigtts", "Sophie (Female, multilingual)"),
("zh_female_yingyujiaoxue_uranus_bigtts", "Jean (Female, multilingual)"),
("zh_male_dayi_uranus_bigtts", "Magnus (Male, multilingual)"),
("zh_female_mizai_uranus_bigtts", "Mabel (Female, multilingual)"),
("zh_female_jitangnv_uranus_bigtts", "Nadia (Female, multilingual)"),
("zh_female_meilinvyou_uranus_bigtts", "Opal (Female, multilingual)"),
("zh_female_liuchangnv_uranus_bigtts", "Pearl (Female, multilingual)"),
("zh_male_ruyayichen_uranus_bigtts", "Quentin (Male, multilingual)"),
("zh_female_vivo_uranus_bigtts", "Vienna (Female, multilingual)"),
("zh_female_xiaoai_uranus_bigtts", "Alina (Female, multilingual)"),
("zh_female_cancan_uranus_bigtts", "Corinne (Female, multilingual)"),
("zh_female_tianmeixiaoyuan_uranus_bigtts", "Esther (Female, multilingual)"),
("zh_female_tianmeitaozi_uranus_bigtts", "Freya (Female, multilingual)"),
("zh_female_shuangkuaisisi_uranus_bigtts", "Gigi (Female, multilingual)"),
("zh_female_peiqi_uranus_bigtts", "Holly (Female, multilingual)"),
("zh_female_xiaoxue_uranus_bigtts", "Lyla (Female, multilingual)"),
("zh_female_yuanqi_uranus_bigtts", "Daisy (Female, multilingual)"),
("zh_female_kefunvsheng_uranus_bigtts", "Tracy (Female, multilingual)"),
("zh_male_shaonianzixin_uranus_bigtts", "Jess (Male, multilingual)"),
("zh_female_linjianvhai_uranus_bigtts", "Pinky (Female, multilingual)"),
("zh_female_kiwi_uranus_bigtts", "Sweety (Female, multilingual)"),
("zh_female_sajiaoxuemei_uranus_bigtts", "Sandy (Female, multilingual)"),
("de_male_seven_uranus_bigtts", "Sven (Male, German)"),
("jp_female_minimi_uranus_bigtts", "Minimi (Female, Japanese)"),
("fr_male_usseau_uranus_bigtts", "Usseau (Male, French)"),
("es_male_felipe_uranus_bigtts", "Felipe (Male, Spanish)"),
("id_male_han_uranus_bigtts", "Han (Male, Indonesian)"),
("pt_male_martins_uranus_bigtts", "Martins (Male, Portuguese)"),
("it_male_enzo_uranus_bigtts", "Enzo (Male, Italian)"),
("kr_male_shane_uranus_bigtts", "Shane (Male, Korean)"),
("zh_male_liufei_uranus_bigtts", "Felix (Male, Chinese)"),
("zh_female_qingxinnvsheng_uranus_bigtts", "Celeste (Female, Chinese)"),
("zh_male_sunwukong_uranus_bigtts", "Monkey King (Male, Chinese)"),
]
SEED_AUDIO_VOICE_OPTIONS = [label for _, label in SEED_AUDIO_PRESET_VOICES]
SEED_AUDIO_VOICE_MAP = {label: speaker_id for speaker_id, label in SEED_AUDIO_PRESET_VOICES}
_AUDIO_TAG_RE = re.compile(r"@Audio(\d+)", re.IGNORECASE)
def max_audio_tag(prompt: str) -> int:
"""Highest N referenced as @AudioN in the prompt (0 if none)."""
nums = [int(m) for m in _AUDIO_TAG_RE.findall(prompt or "")]
return max(nums) if nums else 0
def connected_audio_indices(reference_mode: dict) -> list[int]:
"""Indices (1-based) of connected reference_audio sockets, in order."""
return [
i
for i in range(1, 3 + 1)
if reference_mode.get(f"reference_audio_{i}") is not None
]
def validate_seed_audio_inputs(
text_prompt: str,
mode: str,
audio_indices: list[int],
has_image: bool,
preset_voice: str | None = None,
) -> None:
validate_string(text_prompt, field_name="text_prompt", min_length=1, max_length=3000)
max_tag = max_audio_tag(text_prompt)
if mode == MODE_TEXT:
if max_tag:
raise ValueError(
f"The prompt references @Audio{max_tag}, but reference mode is '{MODE_TEXT}'. "
f"Switch to '{MODE_AUDIO}' and connect the reference clip(s)."
)
elif mode == MODE_AUDIO:
if not audio_indices:
raise ValueError(
f"Reference mode '{MODE_AUDIO}' requires at least one reference_audio input "
f"(or switch to '{MODE_TEXT}')."
)
if audio_indices != list(range(1, len(audio_indices) + 1)):
raise ValueError(
"Connect reference_audio inputs in order without gaps: reference_audio_1, then _2, then _3."
)
if max_tag > len(audio_indices):
raise ValueError(
f"The prompt references @Audio{max_tag}, but only {len(audio_indices)} "
f"reference audio(s) are connected."
)
elif mode == MODE_IMAGE:
if not has_image:
raise ValueError(f"Reference mode '{MODE_IMAGE}' requires a reference_image input.")
if max_tag:
raise ValueError(
f"@AudioN tags are not used in '{MODE_IMAGE}' mode; the prompt should contain "
f"only the text to synthesize."
)
elif mode == MODE_SPEAKER:
if not preset_voice or preset_voice not in SEED_AUDIO_VOICE_MAP:
raise ValueError(f"Reference mode '{MODE_SPEAKER}' requires selecting a preset voice.")
if max_tag > 1:
raise ValueError(
f"'{MODE_SPEAKER}' mode uses a single voice, so @Audio{max_tag} is out of range. "
f"Remove the @AudioN tags — the whole prompt is read in the selected voice."
)
else:
raise ValueError(f"Unknown reference mode: {mode!r}")
class ByteDanceSeedAudioNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="ByteDanceSeedAudio",
display_name="ByteDance Seed Audio 1.0",
category="partner/audio/ByteDance",
description=(
"Generate speech, music, sound effects and multi-speaker dialogue from a single prompt "
"with ByteDance Seed Audio 1.0. Describe the voice(s), emotion, ambience, background music "
"and sound effects in the prompt, and include the lines to speak. Optionally pick a built-in "
"preset voice, clone voices from up to 3 reference clips (tagged @Audio1-3 in the prompt), "
"or derive a voice from a character image. Up to 2 minutes of audio per run."
),
inputs=[
IO.String.Input(
"text_prompt",
multiline=True,
default="",
tooltip=(
"Describe the voice(s), emotion, pacing, ambience, background music and sound "
"effects, and include the lines to speak (name characters inline for dialogue). "
"In 'audio reference' mode, refer to connected clips by order as @Audio1, @Audio2, "
"@Audio3. Maximum 3000 characters."
),
),
IO.DynamicCombo.Input(
"reference_mode",
options=[
IO.DynamicCombo.Option(MODE_TEXT, []),
IO.DynamicCombo.Option(
MODE_AUDIO,
[
IO.Audio.Input(
"reference_audio_1",
optional=True,
tooltip="Reference clip for voice cloning, tagged @Audio1 in the prompt. "
"Up to 30s.",
),
IO.Audio.Input(
"reference_audio_2",
optional=True,
tooltip="Reference clip tagged @Audio2 in the prompt. Up to 30s.",
),
IO.Audio.Input(
"reference_audio_3",
optional=True,
tooltip="Reference clip tagged @Audio3 in the prompt. Up to 30s.",
),
],
),
IO.DynamicCombo.Option(
MODE_IMAGE,
[
IO.Image.Input(
"reference_image",
optional=True,
tooltip="A single character image; the model derives a voice from it. "
"Cannot be combined with reference audio.",
),
],
),
IO.DynamicCombo.Option(
MODE_SPEAKER,
[
IO.Combo.Input(
"preset_voice",
options=SEED_AUDIO_VOICE_OPTIONS,
default=SEED_AUDIO_VOICE_OPTIONS[0],
tooltip="A built-in TTS 2.0 voice that reads the prompt. No reference "
"clip needed, and @AudioN tags are not used in this mode.",
),
],
),
],
tooltip=(
"How to condition the voice: 'text only' (describe everything in the prompt), "
"'audio reference' (clone up to 3 voices, tagged @Audio1-3), 'image reference' "
"(derive a voice from one character image), or 'preset voice' (pick a built-in "
"named voice that reads the prompt)."
),
),
IO.Combo.Input(
"sample_rate",
options=["8000", "16000", "24000", "32000", "44100", "48000"],
default="24000",
tooltip="Output sample rate in Hz.",
),
IO.Int.Input(
"speech_rate",
default=0,
min=-50,
max=100,
tooltip="Speaking speed. 0 = normal, 100 = 2.0x, -50 = 0.5x.",
),
IO.Int.Input(
"loudness_rate",
default=0,
min=-50,
max=100,
tooltip="Loudness. 0 = normal, 100 = 2.0x, -50 = 0.5x.",
),
IO.Int.Input(
"pitch_rate",
default=0,
min=-12,
max=12,
tooltip="Pitch shift in semitones (-12 to 12).",
),
IO.Int.Input(
"seed",
default=42,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Audio.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd": 0.2145, "format":{"suffix":"/minute","approximate":true}}""",
),
)
@classmethod
async def execute(
cls,
text_prompt: str,
reference_mode: dict,
sample_rate: str,
speech_rate: int,
loudness_rate: int,
pitch_rate: int,
seed: int,
) -> IO.NodeOutput:
mode = reference_mode["reference_mode"]
audio_indices = connected_audio_indices(reference_mode)
image = reference_mode.get("reference_image")
preset_voice = reference_mode.get("preset_voice")
validate_seed_audio_inputs(text_prompt, mode, audio_indices, image is not None, preset_voice)
references: list[SeedAudioReference] | None = None
if mode == MODE_AUDIO:
references = []
for i in audio_indices:
clip = reference_mode[f"reference_audio_{i}"]
validate_audio_duration(clip, max_duration=30.0)
mp3_bytes = audio_input_to_mp3(clip).getvalue()
references.append(SeedAudioReference(audio_data=base64.b64encode(mp3_bytes).decode("utf-8")))
elif mode == MODE_IMAGE:
image = upscale_image_tensor_to_min_pixels(image, 160_000)
references = [SeedAudioReference(image_data=tensor_to_base64_string(image, mime_type="image/png"))]
elif mode == MODE_SPEAKER:
references = [SeedAudioReference(speaker=SEED_AUDIO_VOICE_MAP[preset_voice])]
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/byteplus/api/v3/tts/create", method="POST"),
response_model=SeedAudioResponse,
data=SeedAudioRequest(
text_prompt=text_prompt,
references=references,
audio_config=SeedAudioConfig(
sample_rate=int(sample_rate),
speech_rate=speech_rate,
loudness_rate=loudness_rate,
pitch_rate=pitch_rate,
),
),
)
if not response.audio:
raise Exception(
f"Seed Audio returned no audio (code={response.code}): {response.message}"
)
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response.audio)))
class ByteDanceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -2490,6 +2805,7 @@ class ByteDanceExtension(ComfyExtension):
ByteDance2ReferenceNode,
ByteDanceCreateImageAsset,
ByteDanceCreateVideoAsset,
ByteDanceSeedAudioNode,
]

View File

@ -13,7 +13,7 @@ import torch
from typing_extensions import override
import folder_paths
from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl, Types
from comfy_api_nodes.apis.gemini import (
GeminiContent,
GeminiFileData,
@ -37,6 +37,7 @@ from comfy_api_nodes.util import (
audio_to_base64_string,
bytesio_to_image_tensor,
download_url_to_image_tensor,
download_url_to_video_output,
get_number_of_images,
sync_op,
tensor_to_base64_string,
@ -45,6 +46,7 @@ from comfy_api_nodes.util import (
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_string,
validate_video_duration,
video_to_base64_string,
)
@ -229,10 +231,29 @@ async def get_image_from_response(response: GeminiGenerateContentResponse, thoug
return torch.cat(image_tensors, dim=0)
async def get_video_from_response(
response: GeminiGenerateContentResponse, cls: type[IO.ComfyNode] | None = None
) -> InputImpl.VideoFromFile:
parts = get_parts_by_type(response, "video/*")
for part in parts:
if part.inlineData and part.inlineData.data:
return InputImpl.VideoFromFile(BytesIO(base64.b64decode(part.inlineData.data)))
if part.fileData and part.fileData.fileUri:
return await download_url_to_video_output(part.fileData.fileUri, cls=cls)
model_message = get_text_from_response(response).strip()
if model_message:
raise ValueError(f"Gemini did not generate a video. Model response: {model_message}")
raise ValueError(
"Gemini did not generate a video. Try rephrasing your prompt, "
"shortening the requested duration, or reducing the number of input images/videos."
)
def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | None:
if not response.modelVersion:
return None
# Define prices (Cost per 1,000,000 tokens), see https://cloud.google.com/vertex-ai/generative-ai/pricing
output_video_tokens_price = 0.0
if response.modelVersion == "gemini-2.5-pro":
input_tokens_price = 1.25
output_text_tokens_price = 10.0
@ -249,18 +270,27 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
input_tokens_price = 2
output_text_tokens_price = 12.0
output_image_tokens_price = 0.0
elif response.modelVersion == "gemini-3.1-flash-lite-preview":
elif response.modelVersion in ("gemini-3.1-flash-lite-preview", "gemini-3.1-flash-lite"):
input_tokens_price = 0.25
output_text_tokens_price = 1.50
output_image_tokens_price = 0.0
elif response.modelVersion == "gemini-3-pro-image-preview":
elif response.modelVersion in ("gemini-3-pro-image-preview", "gemini-3-pro-image"):
input_tokens_price = 2
output_text_tokens_price = 12.0
output_image_tokens_price = 120.0
elif response.modelVersion == "gemini-3.1-flash-image-preview":
elif response.modelVersion in ("gemini-3.1-flash-image-preview", "gemini-3.1-flash-image"):
input_tokens_price = 0.5
output_text_tokens_price = 3.0
output_image_tokens_price = 60.0
elif response.modelVersion == "gemini-3.1-flash-lite-image":
input_tokens_price = 0.25
output_text_tokens_price = 1.50
output_image_tokens_price = 30.0
elif response.modelVersion == "gemini-omni-flash-preview":
input_tokens_price = 2.145
output_text_tokens_price = 12.87
output_image_tokens_price = 0.0
output_video_tokens_price = 25.025
else:
return None
final_price = response.usageMetadata.promptTokenCount * input_tokens_price
@ -268,6 +298,8 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
for i in response.usageMetadata.candidatesTokensDetails:
if i.modality == Modality.IMAGE:
final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models
elif i.modality == Modality.VIDEO:
final_price += output_video_tokens_price * i.tokenCount # for Omni Flash
else:
final_price += output_text_tokens_price * i.tokenCount
if response.usageMetadata.thoughtsTokenCount:
@ -1302,7 +1334,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
)
def _nano_banana_2_v2_model_inputs():
def _nano_banana_2_v2_model_inputs(resolutions: list[str]):
return [
IO.Combo.Input(
"aspect_ratio",
@ -1329,8 +1361,8 @@ def _nano_banana_2_v2_model_inputs():
),
IO.Combo.Input(
"resolution",
options=["1K", "2K", "4K"],
tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
options=resolutions,
tooltip="Target output resolution.",
),
IO.Combo.Input(
"thinking_level",
@ -1376,7 +1408,11 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
options=[
IO.DynamicCombo.Option(
"Nano Banana 2 (Gemini 3.1 Flash Image)",
_nano_banana_2_v2_model_inputs(),
_nano_banana_2_v2_model_inputs(resolutions=["1K", "2K", "4K"]),
),
IO.DynamicCombo.Option(
"Nano Banana 2 Lite",
_nano_banana_2_v2_model_inputs(resolutions=["1K"]),
),
],
),
@ -1445,9 +1481,13 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]),
expr="""
(
$r := $lookup(widgets, "model.resolution");
$prices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
$contains(widgets.model, "lite")
? {"type":"usd","usd": 0.034, "format":{"suffix":"/Image","approximate":true}}
: (
$r := $lookup(widgets, "model.resolution");
$prices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
)
)
""",
),
@ -1468,6 +1508,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
model_choice = model["model"]
if model_choice == "Nano Banana 2 (Gemini 3.1 Flash Image)":
model_id = "gemini-3.1-flash-image-preview"
elif model_choice == "Nano Banana 2 Lite":
model_id = "gemini-3.1-flash-lite-image"
else:
model_id = model_choice
@ -1517,6 +1559,149 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
)
OMNI_MAX_IMAGES = 14
OMNI_MAX_VIDEOS = 3
OMNI_MODELS: dict[str, str] = {
"Omni Flash": "gemini-omni-flash-preview",
}
def _omni_flash_inputs() -> list[Input]:
"""Per-model inputs for the Omni video DynamicCombo (prompt + reference media + sampling)."""
return [
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Describe the video to generate. Specify the length and aspect ratio directly in the "
'prompt, e.g. "a 6-second clip in 16:9". Length may be 3-10 seconds; the aspect ratio must be '
"16:9 (landscape) or 9:16 (portrait). The output is 720p, 24 FPS, with audio.",
),
IO.Autogrow.Input(
"images",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("image"),
names=[f"image_{i}" for i in range(1, OMNI_MAX_IMAGES + 1)],
min=0,
),
tooltip=f"Optional reference image(s) to guide or animate the video. Up to {OMNI_MAX_IMAGES} images.",
),
IO.Autogrow.Input(
"videos",
template=IO.Autogrow.TemplateNames(
IO.Video.Input("video"),
names=[f"video_{i}" for i in range(1, OMNI_MAX_VIDEOS + 1)],
min=0,
),
tooltip=f"Optional reference video(s) to guide or edit. Up to {OMNI_MAX_VIDEOS} videos, "
f"each up to 10 seconds long.",
),
IO.Float.Input(
"temperature",
default=1.0,
min=0.0,
max=2.0,
step=0.01,
tooltip="Controls randomness. Lower is more focused/deterministic, higher is more varied.",
advanced=True,
),
IO.Float.Input(
"top_p",
default=0.95,
min=0.0,
max=1.0,
step=0.01,
tooltip="Nucleus sampling: sample from the smallest token set whose cumulative probability reaches top_p.",
advanced=True,
),
]
class GeminiVideoOmni(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GeminiVideoOmni",
display_name="Google Gemini Omni (Video)",
category="partner/video/Gemini",
essentials_category="Video Generation",
description="Generate a video with audio from a text prompt using Google's Gemini Omni Flash model. "
"Optionally provide reference images and/or videos to guide or edit the result. Describe the desired "
"length (3-10s) and aspect ratio (16:9 or 9:16) directly in the prompt.",
inputs=[
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option("Omni Flash", _omni_flash_inputs()),
],
tooltip="The Gemini video model used to generate the video.",
),
IO.Int.Input(
"seed",
default=42,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[
IO.Video.Output(),
IO.String.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr='{"type":"usd","usd":0.146,"format":{"suffix":"/second","approximate":true}}'
),
)
@classmethod
async def execute(cls, model: dict, seed: int) -> IO.NodeOutput:
prompt = model.get("prompt") or ""
validate_string(prompt, strip_whitespace=True, min_length=1)
model_id = OMNI_MODELS[model["model"]]
images = [t for t in (model.get("images") or {}).values() if t is not None]
videos = [v for v in (model.get("videos") or {}).values() if v is not None]
if sum(get_number_of_images(t) for t in images) > OMNI_MAX_IMAGES:
raise ValueError(f"The current maximum number of supported images is {OMNI_MAX_IMAGES}.")
if len(videos) > OMNI_MAX_VIDEOS:
raise ValueError(f"The current maximum number of supported videos is {OMNI_MAX_VIDEOS}.")
for video in videos:
validate_video_duration(video, max_duration=10)
parts: list[GeminiPart] = []
if images or videos:
parts.extend(await build_gemini_media_parts(cls, images, [], videos))
parts.append(GeminiPart(text=prompt))
response = await sync_op(
cls,
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model_id}", method="POST"),
data=GeminiGenerateContentRequest(
contents=[GeminiContent(role=GeminiRole.user, parts=parts)],
generationConfig=GeminiGenerationConfig(
responseModalities=["TEXT", "VIDEO"],
temperature=model.get("temperature", 1.0),
topP=model.get("top_p", 0.95),
),
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(
await get_video_from_response(response, cls=cls),
get_text_from_response(response),
)
class GeminiExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -1527,6 +1712,7 @@ class GeminiExtension(ComfyExtension):
GeminiImage2,
GeminiNanoBanana2,
GeminiNanoBanana2V2,
GeminiVideoOmni,
GeminiInputFiles,
]

View File

@ -5,9 +5,7 @@ from PIL import Image
import numpy as np
import torch
from comfy_api_nodes.apis.ideogram import (
IdeogramGenerateRequest,
IdeogramGenerateResponse,
ImageRequest,
IdeogramV3Request,
IdeogramV3EditRequest,
IdeogramV4Request,
@ -21,101 +19,6 @@ from comfy_api_nodes.util import (
validate_string,
)
V1_V1_RES_MAP = {
"Auto":"AUTO",
"512 x 1536":"RESOLUTION_512_1536",
"576 x 1408":"RESOLUTION_576_1408",
"576 x 1472":"RESOLUTION_576_1472",
"576 x 1536":"RESOLUTION_576_1536",
"640 x 1024":"RESOLUTION_640_1024",
"640 x 1344":"RESOLUTION_640_1344",
"640 x 1408":"RESOLUTION_640_1408",
"640 x 1472":"RESOLUTION_640_1472",
"640 x 1536":"RESOLUTION_640_1536",
"704 x 1152":"RESOLUTION_704_1152",
"704 x 1216":"RESOLUTION_704_1216",
"704 x 1280":"RESOLUTION_704_1280",
"704 x 1344":"RESOLUTION_704_1344",
"704 x 1408":"RESOLUTION_704_1408",
"704 x 1472":"RESOLUTION_704_1472",
"720 x 1280":"RESOLUTION_720_1280",
"736 x 1312":"RESOLUTION_736_1312",
"768 x 1024":"RESOLUTION_768_1024",
"768 x 1088":"RESOLUTION_768_1088",
"768 x 1152":"RESOLUTION_768_1152",
"768 x 1216":"RESOLUTION_768_1216",
"768 x 1232":"RESOLUTION_768_1232",
"768 x 1280":"RESOLUTION_768_1280",
"768 x 1344":"RESOLUTION_768_1344",
"832 x 960":"RESOLUTION_832_960",
"832 x 1024":"RESOLUTION_832_1024",
"832 x 1088":"RESOLUTION_832_1088",
"832 x 1152":"RESOLUTION_832_1152",
"832 x 1216":"RESOLUTION_832_1216",
"832 x 1248":"RESOLUTION_832_1248",
"864 x 1152":"RESOLUTION_864_1152",
"896 x 960":"RESOLUTION_896_960",
"896 x 1024":"RESOLUTION_896_1024",
"896 x 1088":"RESOLUTION_896_1088",
"896 x 1120":"RESOLUTION_896_1120",
"896 x 1152":"RESOLUTION_896_1152",
"960 x 832":"RESOLUTION_960_832",
"960 x 896":"RESOLUTION_960_896",
"960 x 1024":"RESOLUTION_960_1024",
"960 x 1088":"RESOLUTION_960_1088",
"1024 x 640":"RESOLUTION_1024_640",
"1024 x 768":"RESOLUTION_1024_768",
"1024 x 832":"RESOLUTION_1024_832",
"1024 x 896":"RESOLUTION_1024_896",
"1024 x 960":"RESOLUTION_1024_960",
"1024 x 1024":"RESOLUTION_1024_1024",
"1088 x 768":"RESOLUTION_1088_768",
"1088 x 832":"RESOLUTION_1088_832",
"1088 x 896":"RESOLUTION_1088_896",
"1088 x 960":"RESOLUTION_1088_960",
"1120 x 896":"RESOLUTION_1120_896",
"1152 x 704":"RESOLUTION_1152_704",
"1152 x 768":"RESOLUTION_1152_768",
"1152 x 832":"RESOLUTION_1152_832",
"1152 x 864":"RESOLUTION_1152_864",
"1152 x 896":"RESOLUTION_1152_896",
"1216 x 704":"RESOLUTION_1216_704",
"1216 x 768":"RESOLUTION_1216_768",
"1216 x 832":"RESOLUTION_1216_832",
"1232 x 768":"RESOLUTION_1232_768",
"1248 x 832":"RESOLUTION_1248_832",
"1280 x 704":"RESOLUTION_1280_704",
"1280 x 720":"RESOLUTION_1280_720",
"1280 x 768":"RESOLUTION_1280_768",
"1280 x 800":"RESOLUTION_1280_800",
"1312 x 736":"RESOLUTION_1312_736",
"1344 x 640":"RESOLUTION_1344_640",
"1344 x 704":"RESOLUTION_1344_704",
"1344 x 768":"RESOLUTION_1344_768",
"1408 x 576":"RESOLUTION_1408_576",
"1408 x 640":"RESOLUTION_1408_640",
"1408 x 704":"RESOLUTION_1408_704",
"1472 x 576":"RESOLUTION_1472_576",
"1472 x 640":"RESOLUTION_1472_640",
"1472 x 704":"RESOLUTION_1472_704",
"1536 x 512":"RESOLUTION_1536_512",
"1536 x 576":"RESOLUTION_1536_576",
"1536 x 640":"RESOLUTION_1536_640",
}
V1_V2_RATIO_MAP = {
"1:1":"ASPECT_1_1",
"4:3":"ASPECT_4_3",
"3:4":"ASPECT_3_4",
"16:9":"ASPECT_16_9",
"9:16":"ASPECT_9_16",
"2:1":"ASPECT_2_1",
"1:2":"ASPECT_1_2",
"3:2":"ASPECT_3_2",
"2:3":"ASPECT_2_3",
"4:5":"ASPECT_4_5",
"5:4":"ASPECT_5_4",
}
V3_RATIO_MAP = {
"1:3":"1x3",
@ -229,298 +132,6 @@ async def download_and_process_images(image_urls):
return stacked_tensors
class IdeogramV1(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="IdeogramV1",
display_name="Ideogram V1",
category="partner/image/Ideogram",
description="Generates images using the Ideogram V1 model.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
IO.Boolean.Input(
"turbo",
default=False,
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
),
IO.Combo.Input(
"aspect_ratio",
options=list(V1_V2_RATIO_MAP.keys()),
default="1:1",
tooltip="The aspect ratio for image generation.",
optional=True,
),
IO.Combo.Input(
"magic_prompt_option",
options=["AUTO", "ON", "OFF"],
default="AUTO",
tooltip="Determine if MagicPrompt should be used in generation",
optional=True,
advanced=True,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
control_after_generate=True,
display_mode=IO.NumberDisplay.number,
optional=True,
),
IO.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Description of what to exclude from the image",
optional=True,
),
IO.Int.Input(
"num_images",
default=1,
min=1,
max=8,
step=1,
display_mode=IO.NumberDisplay.number,
optional=True,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]),
expr="""
(
$n := widgets.num_images;
$base := (widgets.turbo = true) ? 0.0286 : 0.0858;
{"type":"usd","usd": $round($base * $n, 2)}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt,
turbo=False,
aspect_ratio="1:1",
magic_prompt_option="AUTO",
seed=0,
negative_prompt="",
num_images=1,
):
# Determine the model based on turbo setting
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
model = "V_1_TURBO" if turbo else "V_1"
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
response_model=IdeogramGenerateResponse,
data=IdeogramGenerateRequest(
image_request=ImageRequest(
prompt=prompt,
model=model,
num_images=num_images,
seed=seed,
aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
negative_prompt=negative_prompt if negative_prompt else None,
)
),
max_retries=1,
)
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
return IO.NodeOutput(await download_and_process_images(image_urls))
class IdeogramV2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="IdeogramV2",
display_name="Ideogram V2",
category="partner/image/Ideogram",
description="Generates images using the Ideogram V2 model.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
IO.Boolean.Input(
"turbo",
default=False,
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
),
IO.Combo.Input(
"aspect_ratio",
options=list(V1_V2_RATIO_MAP.keys()),
default="1:1",
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
optional=True,
),
IO.Combo.Input(
"resolution",
options=list(V1_V1_RES_MAP.keys()),
default="Auto",
tooltip="The resolution for image generation. "
"If not set to AUTO, this overrides the aspect_ratio setting.",
optional=True,
),
IO.Combo.Input(
"magic_prompt_option",
options=["AUTO", "ON", "OFF"],
default="AUTO",
tooltip="Determine if MagicPrompt should be used in generation",
optional=True,
advanced=True,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
control_after_generate=True,
display_mode=IO.NumberDisplay.number,
optional=True,
),
IO.Combo.Input(
"style_type",
options=["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
default="NONE",
tooltip="Style type for generation (V2 only)",
optional=True,
advanced=True,
),
IO.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Description of what to exclude from the image",
optional=True,
),
IO.Int.Input(
"num_images",
default=1,
min=1,
max=8,
step=1,
display_mode=IO.NumberDisplay.number,
optional=True,
),
#"color_palette": (
# IO.STRING,
# {
# "multiline": False,
# "default": "",
# "tooltip": "Color palette preset name or hex colors with weights",
# },
#),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]),
expr="""
(
$n := widgets.num_images;
$base := (widgets.turbo = true) ? 0.0715 : 0.1144;
{"type":"usd","usd": $round($base * $n, 2)}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt,
turbo=False,
aspect_ratio="1:1",
resolution="Auto",
magic_prompt_option="AUTO",
seed=0,
style_type="NONE",
negative_prompt="",
num_images=1,
color_palette="",
):
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
resolution = V1_V1_RES_MAP.get(resolution, None)
# Determine the model based on turbo setting
model = "V_2_TURBO" if turbo else "V_2"
# Handle resolution vs aspect_ratio logic
# If resolution is not AUTO, it overrides aspect_ratio
final_resolution = None
final_aspect_ratio = None
if resolution != "AUTO":
final_resolution = resolution
else:
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
response_model=IdeogramGenerateResponse,
data=IdeogramGenerateRequest(
image_request=ImageRequest(
prompt=prompt,
model=model,
num_images=num_images,
seed=seed,
aspect_ratio=final_aspect_ratio,
resolution=final_resolution,
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
style_type=style_type if style_type != "NONE" else None,
negative_prompt=negative_prompt if negative_prompt else None,
color_palette=color_palette if color_palette else None,
)
),
max_retries=1,
)
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
return IO.NodeOutput(await download_and_process_images(image_urls))
class IdeogramV3(IO.ComfyNode):
@classmethod
@ -917,8 +528,6 @@ class IdeogramExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
IdeogramV1,
IdeogramV2,
IdeogramV3,
IdeogramV4,
]

View File

@ -1,932 +0,0 @@
from inspect import cleandoc
from typing import Optional
from typing_extensions import override
from comfy_api.latest import ComfyExtension, Input, IO
from comfy_api_nodes.apis.stability import (
StabilityUpscaleConservativeRequest,
StabilityUpscaleCreativeRequest,
StabilityAsyncResponse,
StabilityResultsGetResponse,
StabilityStable3_5Request,
StabilityStableUltraRequest,
StabilityStableUltraResponse,
StabilityAspectRatio,
Stability_SD3_5_Model,
Stability_SD3_5_GenerationMode,
get_stability_style_presets,
StabilityTextToAudioRequest,
StabilityAudioToAudioRequest,
StabilityAudioInpaintRequest,
StabilityAudioResponse,
)
from comfy_api_nodes.util import (
validate_audio_duration,
validate_string,
audio_input_to_mp3,
bytesio_to_image_tensor,
tensor_to_bytesio,
audio_bytes_to_audio_input,
sync_op,
poll_op,
ApiEndpoint,
)
import torch
import base64
from io import BytesIO
from enum import Enum
class StabilityPollStatus(str, Enum):
finished = "finished"
in_progress = "in_progress"
failed = "failed"
def get_async_dummy_status(x: StabilityResultsGetResponse):
if x.name is not None or x.errors is not None:
return StabilityPollStatus.failed
elif x.finish_reason is not None:
return StabilityPollStatus.finished
return StabilityPollStatus.in_progress
class StabilityStableImageUltraNode(IO.ComfyNode):
"""
Generates images synchronously based on prompt and resolution.
"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityStableImageUltraNode",
display_name="Stability AI Stable Image Ultra",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
"elements, colors, and subjects will lead to better results. " +
"To control the weight of a given word use the format `(word:weight)`," +
"where `word` is the word you'd like to control the weight of and `weight`" +
"is a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`" +
"would convey a sky that was blue and green, but more green than blue.",
),
IO.Combo.Input(
"aspect_ratio",
options=StabilityAspectRatio,
default=StabilityAspectRatio.ratio_1_1,
tooltip="Aspect ratio of generated image.",
),
IO.Combo.Input(
"style_preset",
options=get_stability_style_presets(),
tooltip="Optional desired style of generated image.",
advanced=True,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
IO.Image.Input(
"image",
optional=True,
),
IO.String.Input(
"negative_prompt",
default="",
tooltip="A blurb of text describing what you do not wish to see in the output image. This is an advanced feature.",
force_input=True,
optional=True,
advanced=True,
),
IO.Float.Input(
"image_denoise",
default=0.5,
min=0.0,
max=1.0,
step=0.01,
tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
optional=True,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.08}""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
aspect_ratio: str,
style_preset: str,
seed: int,
image: Optional[torch.Tensor] = None,
negative_prompt: str = "",
image_denoise: Optional[float] = 0.5,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
# prepare image binary if image present
image_binary = None
if image is not None:
image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
else:
image_denoise = None
if not negative_prompt:
negative_prompt = None
if style_preset == "None":
style_preset = None
files = {
"image": image_binary
}
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"),
response_model=StabilityStableUltraResponse,
data=StabilityStableUltraRequest(
prompt=prompt,
negative_prompt=negative_prompt,
aspect_ratio=aspect_ratio,
seed=seed,
strength=image_denoise,
style_preset=style_preset,
),
files=files,
content_type="multipart/form-data",
)
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return IO.NodeOutput(returned_image)
class StabilityStableImageSD_3_5Node(IO.ComfyNode):
"""
Generates images synchronously based on prompt and resolution.
"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityStableImageSD_3_5Node",
display_name="Stability AI Stable Diffusion 3.5 Image",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
),
IO.Combo.Input(
"model",
options=Stability_SD3_5_Model,
),
IO.Combo.Input(
"aspect_ratio",
options=StabilityAspectRatio,
default=StabilityAspectRatio.ratio_1_1,
tooltip="Aspect ratio of generated image.",
),
IO.Combo.Input(
"style_preset",
options=get_stability_style_presets(),
tooltip="Optional desired style of generated image.",
advanced=True,
),
IO.Float.Input(
"cfg_scale",
default=4.0,
min=1.0,
max=10.0,
step=0.1,
tooltip="How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
IO.Image.Input(
"image",
optional=True,
),
IO.String.Input(
"negative_prompt",
default="",
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
force_input=True,
optional=True,
advanced=True,
),
IO.Float.Input(
"image_denoise",
default=0.5,
min=0.0,
max=1.0,
step=0.01,
tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
optional=True,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$contains(widgets.model,"large")
? {"type":"usd","usd":0.065}
: {"type":"usd","usd":0.035}
)
""",
),
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
aspect_ratio: str,
style_preset: str,
seed: int,
cfg_scale: float,
image: Optional[torch.Tensor] = None,
negative_prompt: str = "",
image_denoise: Optional[float] = 0.5,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
# prepare image binary if image present
image_binary = None
mode = Stability_SD3_5_GenerationMode.text_to_image
if image is not None:
image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
mode = Stability_SD3_5_GenerationMode.image_to_image
aspect_ratio = None
else:
image_denoise = None
if not negative_prompt:
negative_prompt = None
if style_preset == "None":
style_preset = None
files = {
"image": image_binary
}
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"),
response_model=StabilityStableUltraResponse,
data=StabilityStable3_5Request(
prompt=prompt,
negative_prompt=negative_prompt,
aspect_ratio=aspect_ratio,
seed=seed,
strength=image_denoise,
style_preset=style_preset,
cfg_scale=cfg_scale,
model=model,
mode=mode,
),
files=files,
content_type="multipart/form-data",
)
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return IO.NodeOutput(returned_image)
class StabilityUpscaleConservativeNode(IO.ComfyNode):
"""
Upscale image with minimal alterations to 4K resolution.
"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityUpscaleConservativeNode",
display_name="Stability AI Upscale Conservative",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("image"),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
),
IO.Float.Input(
"creativity",
default=0.35,
min=0.2,
max=0.5,
step=0.01,
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
IO.String.Input(
"negative_prompt",
default="",
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
force_input=True,
optional=True,
advanced=True,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.4}""",
),
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
prompt: str,
creativity: float,
seed: int,
negative_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
if not negative_prompt:
negative_prompt = None
files = {
"image": image_binary
}
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"),
response_model=StabilityStableUltraResponse,
data=StabilityUpscaleConservativeRequest(
prompt=prompt,
negative_prompt=negative_prompt,
creativity=round(creativity,2),
seed=seed,
),
files=files,
content_type="multipart/form-data",
)
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return IO.NodeOutput(returned_image)
class StabilityUpscaleCreativeNode(IO.ComfyNode):
"""
Upscale image with minimal alterations to 4K resolution.
"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityUpscaleCreativeNode",
display_name="Stability AI Upscale Creative",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("image"),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
),
IO.Float.Input(
"creativity",
default=0.3,
min=0.1,
max=0.5,
step=0.01,
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
),
IO.Combo.Input(
"style_preset",
options=get_stability_style_presets(),
tooltip="Optional desired style of generated image.",
advanced=True,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
IO.String.Input(
"negative_prompt",
default="",
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
force_input=True,
optional=True,
advanced=True,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.6}""",
),
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
prompt: str,
creativity: float,
style_preset: str,
seed: int,
negative_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
if not negative_prompt:
negative_prompt = None
if style_preset == "None":
style_preset = None
files = {
"image": image_binary
}
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"),
response_model=StabilityAsyncResponse,
data=StabilityUpscaleCreativeRequest(
prompt=prompt,
negative_prompt=negative_prompt,
creativity=round(creativity,2),
style_preset=style_preset,
seed=seed,
),
files=files,
content_type="multipart/form-data",
)
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"),
response_model=StabilityResultsGetResponse,
poll_interval=3,
status_extractor=lambda x: get_async_dummy_status(x),
)
if response_poll.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
image_data = base64.b64decode(response_poll.result)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return IO.NodeOutput(returned_image)
class StabilityUpscaleFastNode(IO.ComfyNode):
"""
Quickly upscales an image via Stability API call to 4x its original size; intended for upscaling low-quality/compressed images.
"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityUpscaleFastNode",
display_name="Stability AI Upscale Fast",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("image"),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.02}""",
),
)
@classmethod
async def execute(cls, image: torch.Tensor) -> IO.NodeOutput:
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
files = {
"image": image_binary
}
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"),
response_model=StabilityStableUltraResponse,
files=files,
content_type="multipart/form-data",
)
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return IO.NodeOutput(returned_image)
class StabilityTextToAudio(IO.ComfyNode):
"""Generates high-quality music and sound effects from text descriptions."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityTextToAudio",
display_name="Stability AI Text To Audio",
category="partner/audio/Stability AI",
essentials_category="Audio",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Combo.Input(
"model",
options=["stable-audio-2.5"],
),
IO.String.Input("prompt", multiline=True, default=""),
IO.Int.Input(
"duration",
default=190,
min=1,
max=190,
step=1,
tooltip="Controls the duration in seconds of the generated audio.",
optional=True,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for generation.",
optional=True,
),
IO.Int.Input(
"steps",
default=8,
min=4,
max=8,
step=1,
tooltip="Controls the number of sampling steps.",
optional=True,
advanced=True,
),
],
outputs=[
IO.Audio.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.2}""",
),
)
@classmethod
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput:
validate_string(prompt, max_length=10000)
payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"),
response_model=StabilityAudioResponse,
data=payload,
content_type="multipart/form-data",
)
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
class StabilityAudioToAudio(IO.ComfyNode):
"""Transforms existing audio samples into new high-quality compositions using text instructions."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityAudioToAudio",
display_name="Stability AI Audio To Audio",
category="partner/audio/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Combo.Input(
"model",
options=["stable-audio-2.5"],
),
IO.String.Input("prompt", multiline=True, default=""),
IO.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
IO.Int.Input(
"duration",
default=190,
min=1,
max=190,
step=1,
tooltip="Controls the duration in seconds of the generated audio.",
optional=True,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for generation.",
optional=True,
),
IO.Int.Input(
"steps",
default=8,
min=4,
max=8,
step=1,
tooltip="Controls the number of sampling steps.",
optional=True,
advanced=True,
),
IO.Float.Input(
"strength",
default=1,
min=0.01,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Parameter controls how much influence the audio parameter has on the generated audio.",
optional=True,
),
],
outputs=[
IO.Audio.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.2}""",
),
)
@classmethod
async def execute(
cls, model: str, prompt: str, audio: Input.Audio, duration: int, seed: int, steps: int, strength: float
) -> IO.NodeOutput:
validate_string(prompt, max_length=10000)
validate_audio_duration(audio, 6, 190)
payload = StabilityAudioToAudioRequest(
prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
)
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"),
response_model=StabilityAudioResponse,
data=payload,
content_type="multipart/form-data",
files={"audio": audio_input_to_mp3(audio)},
)
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
class StabilityAudioInpaint(IO.ComfyNode):
"""Transforms part of existing audio sample using text instructions."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityAudioInpaint",
display_name="Stability AI Audio Inpaint",
category="partner/audio/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Combo.Input(
"model",
options=["stable-audio-2.5"],
),
IO.String.Input("prompt", multiline=True, default=""),
IO.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
IO.Int.Input(
"duration",
default=190,
min=1,
max=190,
step=1,
tooltip="Controls the duration in seconds of the generated audio.",
optional=True,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for generation.",
optional=True,
),
IO.Int.Input(
"steps",
default=8,
min=4,
max=8,
step=1,
tooltip="Controls the number of sampling steps.",
optional=True,
advanced=True,
),
IO.Int.Input(
"mask_start",
default=30,
min=0,
max=190,
step=1,
optional=True,
advanced=True,
),
IO.Int.Input(
"mask_end",
default=190,
min=0,
max=190,
step=1,
optional=True,
advanced=True,
),
],
outputs=[
IO.Audio.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.2}""",
),
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
audio: Input.Audio,
duration: int,
seed: int,
steps: int,
mask_start: int,
mask_end: int,
) -> IO.NodeOutput:
validate_string(prompt, max_length=10000)
if mask_end <= mask_start:
raise ValueError(f"Value of mask_end({mask_end}) should be greater then mask_start({mask_start})")
validate_audio_duration(audio, 6, 190)
payload = StabilityAudioInpaintRequest(
prompt=prompt,
model=model,
duration=duration,
seed=seed,
steps=steps,
mask_start=mask_start,
mask_end=mask_end,
)
response_api = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"),
response_model=StabilityAudioResponse,
data=payload,
content_type="multipart/form-data",
files={"audio": audio_input_to_mp3(audio)},
)
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
class StabilityExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
StabilityStableImageUltraNode,
StabilityStableImageSD_3_5Node,
StabilityUpscaleConservativeNode,
StabilityUpscaleCreativeNode,
StabilityUpscaleFastNode,
StabilityTextToAudio,
StabilityAudioToAudio,
StabilityAudioInpaint,
]
async def comfy_entrypoint() -> StabilityExtension:
return StabilityExtension()

View File

@ -26,6 +26,7 @@ from .conversions import (
text_filepath_to_base64_string,
text_filepath_to_data_uri,
trim_video,
upscale_image_tensor_to_min_pixels,
upscale_video_to_min_pixels,
video_to_base64_string,
)
@ -99,6 +100,7 @@ __all__ = [
"text_filepath_to_base64_string",
"text_filepath_to_data_uri",
"trim_video",
"upscale_image_tensor_to_min_pixels",
"upscale_video_to_min_pixels",
"video_to_base64_string",
# Validation utilities

View File

@ -448,6 +448,15 @@ def _compute_upscale_dims(src_w: int, src_h: int, total_pixels: int) -> tuple[in
return new_w, new_h
def upscale_image_tensor_to_min_pixels(image: torch.Tensor, total_pixels: int) -> torch.Tensor:
samples = image.movedim(-1, 1)
dims = _compute_upscale_dims(samples.shape[3], samples.shape[2], int(total_pixels))
if dims is None:
return image
new_w, new_h = dims
return common_upscale(samples, new_w, new_h, "lanczos", "disabled").movedim(1, -1)
def upscale_video_to_min_pixels(video: Input.Video, min_pixels: int) -> Input.Video:
"""Upscale a video to meet at least ``min_pixels`` (w * h), preserving aspect ratio.

View File

@ -1,5 +1,6 @@
import asyncio
import bisect
import gc
import itertools
import psutil
import time
@ -528,6 +529,38 @@ class RAMPressureCache(LRUCache):
if psutil.virtual_memory().available >= target:
return
def remove_cache_key(key):
del self.cache[key]
self.used_generation.pop(key, None)
self.timestamps.pop(key, None)
self.children.pop(key, None)
def has_old_model_patcher(outputs):
if outputs is None:
return False
for output in outputs:
if isinstance(output, (list, tuple)):
if has_old_model_patcher(output):
return True
elif isinstance(output, ModelPatcher):
return True
return False
old_modelpatcher_keys = []
for key, cache_entry in self.cache.items():
if self.used_generation[key] == self.generation:
continue
if has_old_model_patcher(cache_entry.outputs):
old_modelpatcher_keys.append(key)
for key in old_modelpatcher_keys:
remove_cache_key(key)
if old_modelpatcher_keys:
gc.collect()
if psutil.virtual_memory().available >= target:
return
clean_list = []
for key, cache_entry in self.cache.items():
@ -545,19 +578,17 @@ class RAMPressureCache(LRUCache):
scan_list_for_ram_usage(output)
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
ram_usage += output.numel() * output.element_size()
elif isinstance(output, ModelPatcher) and self.used_generation[key] != self.generation:
#old ModelPatchers are the first to go
ram_usage = 1e30
scan_list_for_ram_usage(cache_entry.outputs)
oom_score *= ram_usage
#In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
bisect.insort(clean_list, (oom_score, self.timestamps[key], ram_usage, key))
while psutil.virtual_memory().available < target and clean_list:
_, _, key = clean_list.pop()
del self.cache[key]
self.used_generation.pop(key, None)
self.timestamps.pop(key, None)
self.children.pop(key, None)
to_free = target - psutil.virtual_memory().available
while to_free > 0 and clean_list:
_, _, ram_usage, key = clean_list.pop()
remove_cache_key(key)
to_free -= ram_usage
gc.collect()

View File

@ -8,7 +8,8 @@ class CLIPTextEncodeControlnet(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="CLIPTextEncodeControlnet",
category="experimental/conditioning",
display_name="CLIP Text Encode (Controlnet)",
category="model/conditioning",
inputs=[
io.Clip.Input("clip"),
io.Conditioning.Input("conditioning"),
@ -35,11 +36,12 @@ class T5TokenizerOptions(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="T5TokenizerOptions",
category="experimental/conditioning",
display_name="T5 Tokenizer Options",
category="model/conditioning",
inputs=[
io.Clip.Input("clip"),
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True),
io.Int.Input("min_length", default=0, min=0, max=10000, step=1, advanced=True),
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1),
io.Int.Input("min_length", default=0, min=0, max=10000, step=1),
],
outputs=[io.Clip.Output()],
is_experimental=True,

View File

@ -1070,7 +1070,7 @@ class AddNoise(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="AddNoise",
category="experimental/custom_sampling/noise",
category="model/sampling/noise",
is_experimental=True,
inputs=[
io.Model.Input("model"),
@ -1120,7 +1120,7 @@ class ManualSigmas(io.ComfyNode):
return io.Schema(
node_id="ManualSigmas",
search_aliases=["custom noise schedule", "define sigmas"],
category="experimental/custom_sampling",
category="model/sampling/sigmas",
is_experimental=True,
inputs=[
io.String.Input("sigmas", default="1, 0.5", multiline=False)

View File

@ -1,85 +1,68 @@
import os
import sys
import re
import ctypes
import logging
import ctypes.util
import importlib.util
from typing import TypedDict
import numpy as np
import torch
import nodes
import comfy_angle
from comfy_api.latest import ComfyExtension, io, ui
from typing_extensions import override
from utils.install_util import get_missing_requirements_message
logger = logging.getLogger(__name__)
def _check_opengl_availability():
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
logger.debug("_check_opengl_availability: starting")
missing = []
def _preload_angle():
egl_path = comfy_angle.get_egl_path()
gles_path = comfy_angle.get_glesv2_path()
# Check Python packages (using find_spec to avoid importing)
logger.debug("_check_opengl_availability: checking for glfw package")
if importlib.util.find_spec("glfw") is None:
missing.append("glfw")
if sys.platform == "win32":
angle_dir = comfy_angle.get_lib_dir()
os.add_dll_directory(angle_dir)
os.environ["PATH"] = angle_dir + os.pathsep + os.environ.get("PATH", "")
logger.debug("_check_opengl_availability: checking for OpenGL package")
if importlib.util.find_spec("OpenGL") is None:
missing.append("PyOpenGL")
if missing:
raise RuntimeError(
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
)
# On Linux without display, check if headless backends are available
logger.debug(f"_check_opengl_availability: platform={sys.platform}")
if sys.platform.startswith("linux"):
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
if not has_display:
# Check for EGL or OSMesa libraries
logger.debug("_check_opengl_availability: checking for EGL library")
has_egl = ctypes.util.find_library("EGL")
logger.debug("_check_opengl_availability: checking for OSMesa library")
has_osmesa = ctypes.util.find_library("OSMesa")
# Error disabled for CI as it fails this check
# if not has_egl and not has_osmesa:
# raise RuntimeError(
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
# "See error below for installation instructions."
# )
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
logger.debug("_check_opengl_availability: completed")
mode = 0 if sys.platform == "win32" else ctypes.RTLD_GLOBAL
ctypes.CDLL(str(egl_path), mode=mode)
ctypes.CDLL(str(gles_path), mode=mode)
# Run early check at import time
logger.debug("nodes_glsl: running _check_opengl_availability at import time")
_check_opengl_availability()
# OpenGL modules - initialized lazily when context is created
gl = None
glfw = None
EGL = None
# Pre-load ANGLE *before* any PyOpenGL import so that the EGL platform
# plugin picks up ANGLE's libEGL / libGLESv2 instead of system libs.
_preload_angle()
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
def _import_opengl():
"""Import OpenGL module. Called after context is created."""
global gl
if gl is None:
logger.debug("_import_opengl: importing OpenGL.GL")
import OpenGL.GL as _gl
gl = _gl
logger.debug("_import_opengl: import completed")
return gl
import OpenGL
OpenGL.USE_ACCELERATE = False
def _patch_find_library():
"""PyOpenGL's EGL platform looks for 'EGL' and 'GLESv2' by short name
via ctypes.util.find_library, but ANGLE ships as 'libEGL' and
'libGLESv2'. Patch find_library to return the full ANGLE paths so
PyOpenGL loads the same libraries we pre-loaded."""
if sys.platform == "linux":
return
import ctypes.util
_orig = ctypes.util.find_library
def _patched(name):
if name == 'EGL':
return comfy_angle.get_egl_path()
if name == 'GLESv2':
return comfy_angle.get_glesv2_path()
return _orig(name)
ctypes.util.find_library = _patched
_patch_find_library()
from OpenGL import EGL
from OpenGL import GLES3 as gl
class SizeModeInput(TypedDict):
size_mode: str
width: int
@ -102,7 +85,7 @@ MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
# (-1,-1)---(3,-1)
#
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
VERTEX_SHADER = """#version 330 core
VERTEX_SHADER = """#version 300 es
out vec2 v_texCoord;
void main() {
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
@ -126,14 +109,99 @@ void main() {
"""
def _convert_es_to_desktop(source: str) -> str:
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
# Remove any existing #version directive
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
# Remove precision qualifiers (not needed in desktop GLSL)
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
# Prepend desktop GLSL version
return "#version 330 core\n" + source
def _egl_attribs(*values):
"""Build an EGL_NONE-terminated EGLint attribute array."""
vals = list(values) + [EGL.EGL_NONE]
return (ctypes.c_int32 * len(vals))(*vals)
# EGL platform extension constants
EGL_PLATFORM_ANGLE_ANGLE = 0x3202
EGL_PLATFORM_ANGLE_TYPE_ANGLE = 0x3203
EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE = 0x3450
EGL_MESA_PLATFORM_SURFACELESS = 0x31DD
_eglGetPlatformDisplayEXT = None
def _get_egl_platform_display_ext(platform, native_display, attribs):
"""Call eglGetPlatformDisplayEXT via ctypes (extension, not in PyOpenGL)."""
global _eglGetPlatformDisplayEXT
if _eglGetPlatformDisplayEXT is None:
from OpenGL import platform as _plat
egl_lib = _plat.PLATFORM.EGL
_get_proc = egl_lib.eglGetProcAddress
_get_proc.restype = ctypes.c_void_p
_get_proc.argtypes = [ctypes.c_char_p]
ptr = _get_proc(b"eglGetPlatformDisplayEXT")
if not ptr:
return None
func_type = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_uint32, ctypes.c_void_p, ctypes.c_void_p)
_eglGetPlatformDisplayEXT = func_type(ptr)
raw = _eglGetPlatformDisplayEXT(platform, native_display, attribs)
if not raw:
return None
return ctypes.cast(raw, EGL.EGLDisplay)
def _get_egl_display():
"""Get an EGL display, trying the default first then ANGLE's Vulkan
platform for headless environments without a display server."""
failures = []
# Try the default display first (works when X11/Wayland is available)
display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY)
if display:
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
try:
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
return display, major.value, minor.value
except Exception as e:
failures.append(f"default: {e}")
logger.info("Default EGL display unavailable, trying headless fallbacks")
# Headless fallback strategies, tried in order:
headless_strategies = [
("surfaceless", EGL_MESA_PLATFORM_SURFACELESS, None, None),
("ANGLE Vulkan", EGL_PLATFORM_ANGLE_ANGLE, None,
_egl_attribs(EGL_PLATFORM_ANGLE_TYPE_ANGLE, EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE)),
]
for name, platform, native_display, attribs in headless_strategies:
display = _get_egl_platform_display_ext(platform, native_display, attribs)
if not display:
failures.append(f"{name}: eglGetPlatformDisplayEXT returned no display")
continue
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
try:
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
logger.info(f"Using EGL {name} platform (headless)")
return display, major.value, minor.value
failures.append(f"{name}: eglInitialize returned false")
except Exception as e:
failures.append(f"{name}: {e}")
continue
details = "\n".join(f" - {f}" for f in failures)
raise RuntimeError(
"Failed to initialize EGL display.\n"
"No display server and no headless EGL platform available.\n"
f"Tried:\n{details}\n"
"Ensure GPU drivers are installed or set DISPLAY for a virtual framebuffer."
)
def _gl_str(name):
"""Get an OpenGL string parameter."""
v = gl.glGetString(name)
if not v:
return "Unknown"
if isinstance(v, bytes):
return v.decode(errors="replace")
return ctypes.string_at(v).decode(errors="replace")
def _detect_output_count(source: str) -> int:
@ -159,163 +227,8 @@ def _detect_pass_count(source: str) -> int:
return 1
def _init_glfw():
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
logger.debug("_init_glfw: starting")
# On macOS, glfw.init() must be called from main thread or it hangs forever
if sys.platform == "darwin":
logger.debug("_init_glfw: skipping on macOS")
raise RuntimeError("GLFW backend not supported on macOS")
logger.debug("_init_glfw: importing glfw module")
import glfw as _glfw
logger.debug("_init_glfw: calling glfw.init()")
if not _glfw.init():
raise RuntimeError("glfw.init() failed")
try:
logger.debug("_init_glfw: setting window hints")
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
logger.debug("_init_glfw: calling create_window()")
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
if not window:
raise RuntimeError("glfw.create_window() failed")
logger.debug("_init_glfw: calling make_context_current()")
_glfw.make_context_current(window)
logger.debug("_init_glfw: completed successfully")
return window, _glfw
except Exception:
logger.debug("_init_glfw: failed, terminating glfw")
_glfw.terminate()
raise
def _init_egl():
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
logger.debug("_init_egl: starting")
from OpenGL import EGL as _EGL
from OpenGL.EGL import (
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
eglTerminate, eglDestroyContext, eglDestroySurface,
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
)
logger.debug("_init_egl: imports completed")
display = None
context = None
surface = None
try:
logger.debug("_init_egl: calling eglGetDisplay()")
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
if display == _EGL.EGL_NO_DISPLAY:
raise RuntimeError("eglGetDisplay() failed")
logger.debug("_init_egl: calling eglInitialize()")
major, minor = _EGL.EGLint(), _EGL.EGLint()
if not eglInitialize(display, major, minor):
display = None # Not initialized, don't terminate
raise RuntimeError("eglInitialize() failed")
logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
config_attribs = [
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
EGL_DEPTH_SIZE, 0, EGL_NONE
]
configs = (_EGL.EGLConfig * 1)()
num_configs = _EGL.EGLint()
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
raise RuntimeError("eglChooseConfig() failed")
config = configs[0]
logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
if not eglBindAPI(EGL_OPENGL_API):
raise RuntimeError("eglBindAPI() failed")
logger.debug("_init_egl: calling eglCreateContext()")
context_attribs = [
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
EGL_NONE
]
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
if context == EGL_NO_CONTEXT:
raise RuntimeError("eglCreateContext() failed")
logger.debug("_init_egl: calling eglCreatePbufferSurface()")
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
if surface == _EGL.EGL_NO_SURFACE:
raise RuntimeError("eglCreatePbufferSurface() failed")
logger.debug("_init_egl: calling eglMakeCurrent()")
if not eglMakeCurrent(display, surface, surface, context):
raise RuntimeError("eglMakeCurrent() failed")
logger.debug("_init_egl: completed successfully")
return display, context, surface, _EGL
except Exception:
logger.debug("_init_egl: failed, cleaning up")
# Clean up any resources on failure
if surface is not None:
eglDestroySurface(display, surface)
if context is not None:
eglDestroyContext(display, context)
if display is not None:
eglTerminate(display)
raise
def _init_osmesa():
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
import ctypes
logger.debug("_init_osmesa: starting")
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
logger.debug("_init_osmesa: importing OpenGL.osmesa")
from OpenGL import GL as _gl
from OpenGL.osmesa import (
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
OSMESA_RGBA,
)
logger.debug("_init_osmesa: imports completed")
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
if not ctx:
raise RuntimeError("OSMesaCreateContextExt() failed")
width, height = 64, 64
buffer = (ctypes.c_ubyte * (width * height * 4))()
logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
OSMesaDestroyContext(ctx)
raise RuntimeError("OSMesaMakeCurrent() failed")
logger.debug("_init_osmesa: completed successfully")
return ctx, buffer
class GLContext:
"""Manages OpenGL context and resources for shader execution.
Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software).
"""
"""Manages an OpenGL ES 3.0 context via EGL/ANGLE (singleton)."""
_instance = None
_initialized = False
@ -327,131 +240,105 @@ class GLContext:
def __init__(self):
if GLContext._initialized:
logger.debug("GLContext.__init__: already initialized, skipping")
return
logger.debug("GLContext.__init__: starting initialization")
global glfw, EGL
import time
start = time.perf_counter()
self._backend = None
self._window = None
self._egl_display = None
self._egl_context = None
self._egl_surface = None
self._osmesa_ctx = None
self._osmesa_buffer = None
self._display = None
self._surface = None
self._context = None
self._vao = None
# Try backends in order: GLFW → EGL → OSMesa
errors = []
logger.debug("GLContext.__init__: trying GLFW backend")
try:
self._window, glfw = _init_glfw()
self._backend = "glfw"
logger.debug("GLContext.__init__: GLFW backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
errors.append(("GLFW", e))
self._display, self._egl_major, self._egl_minor = _get_egl_display()
if self._backend is None:
logger.debug("GLContext.__init__: trying EGL backend")
try:
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
self._backend = "egl"
logger.debug("GLContext.__init__: EGL backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
errors.append(("EGL", e))
if not EGL.eglBindAPI(EGL.EGL_OPENGL_ES_API):
raise RuntimeError("eglBindAPI(EGL_OPENGL_ES_API) failed")
if self._backend is None:
logger.debug("GLContext.__init__: trying OSMesa backend")
try:
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
self._backend = "osmesa"
logger.debug("GLContext.__init__: OSMesa backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
errors.append(("OSMesa", e))
config = EGL.EGLConfig()
n_configs = ctypes.c_int32(0)
if not EGL.eglChooseConfig(
self._display,
_egl_attribs(
EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_ES3_BIT,
EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT,
EGL.EGL_RED_SIZE, 8, EGL.EGL_GREEN_SIZE, 8,
EGL.EGL_BLUE_SIZE, 8, EGL.EGL_ALPHA_SIZE, 8,
),
ctypes.byref(config), 1, ctypes.byref(n_configs),
) or n_configs.value == 0:
raise RuntimeError("eglChooseConfig() failed")
if self._backend is None:
if sys.platform == "win32":
platform_help = (
"Windows: Ensure GPU drivers are installed and display is available.\n"
" CPU-only/headless mode is not supported on Windows."
)
elif sys.platform == "darwin":
platform_help = (
"macOS: GLFW is not supported.\n"
" Install OSMesa via Homebrew: brew install mesa\n"
" Then: pip install PyOpenGL PyOpenGL-accelerate"
)
else:
platform_help = (
"Linux: Install one of these backends:\n"
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
" Headless (CPU): sudo apt install libosmesa6"
)
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
raise RuntimeError(
f"Failed to create OpenGL context.\n\n"
f"Backend errors:\n{error_details}\n\n"
f"{platform_help}"
self._surface = EGL.eglCreatePbufferSurface(
self._display, config,
_egl_attribs(EGL.EGL_WIDTH, 64, EGL.EGL_HEIGHT, 64),
)
if not self._surface:
raise RuntimeError("eglCreatePbufferSurface() failed")
# Now import OpenGL.GL (after context is current)
logger.debug("GLContext.__init__: importing OpenGL.GL")
_import_opengl()
self._context = EGL.eglCreateContext(
self._display, config, EGL.EGL_NO_CONTEXT,
_egl_attribs(EGL.EGL_CONTEXT_CLIENT_VERSION, 3),
)
if not self._context:
raise RuntimeError("eglCreateContext() failed")
# Create VAO (required for core profile, but OSMesa may use compat profile)
logger.debug("GLContext.__init__: creating VAO")
try:
vao = gl.glGenVertexArrays(1)
gl.glBindVertexArray(vao)
self._vao = vao # Only store after successful bind
logger.debug("GLContext.__init__: VAO created successfully")
except Exception as e:
logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
# OSMesa with older Mesa may not support VAOs
# Clean up if we created but couldn't bind
if vao:
try:
gl.glDeleteVertexArrays(1, [vao])
except Exception:
pass
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
raise RuntimeError("eglMakeCurrent() failed")
self._vao = gl.glGenVertexArrays(1)
gl.glBindVertexArray(self._vao)
except Exception:
self._cleanup()
raise
elapsed = (time.perf_counter() - start) * 1000
# Log device info
renderer = gl.glGetString(gl.GL_RENDERER)
vendor = gl.glGetString(gl.GL_VENDOR)
version = gl.glGetString(gl.GL_VERSION)
renderer = renderer.decode() if renderer else "Unknown"
vendor = vendor.decode() if vendor else "Unknown"
version = version.decode() if version else "Unknown"
renderer = _gl_str(gl.GL_RENDERER)
vendor = _gl_str(gl.GL_VENDOR)
version = _gl_str(gl.GL_VERSION)
GLContext._initialized = True
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
logger.info(f"GLSL context initialized in {elapsed:.1f}ms - EGL {self._egl_major}.{self._egl_minor}, {renderer} ({vendor}), GL {version}")
def make_current(self):
if self._backend == "glfw":
glfw.make_context_current(self._window)
elif self._backend == "egl":
from OpenGL.EGL import eglMakeCurrent
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
elif self._backend == "osmesa":
from OpenGL.osmesa import OSMesaMakeCurrent
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
err = EGL.eglGetError()
raise RuntimeError(f"eglMakeCurrent() failed (EGL error: 0x{err:04X})")
if self._vao is not None:
gl.glBindVertexArray(self._vao)
def _cleanup(self):
if not self._display:
return
try:
if self._vao is not None:
gl.glDeleteVertexArrays(1, [self._vao])
self._vao = None
except Exception:
pass
try:
EGL.eglMakeCurrent(self._display, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
except Exception:
pass
try:
if self._context:
EGL.eglDestroyContext(self._display, self._context)
except Exception:
pass
try:
if self._surface:
EGL.eglDestroySurface(self._display, self._surface)
except Exception:
pass
try:
EGL.eglTerminate(self._display)
except Exception:
pass
self._display = None
def _compile_shader(source: str, shader_type: int) -> int:
"""Compile a shader and return its ID."""
@ -459,8 +346,10 @@ def _compile_shader(source: str, shader_type: int) -> int:
gl.glShaderSource(shader, source)
gl.glCompileShader(shader)
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
error = gl.glGetShaderInfoLog(shader).decode()
if not gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS):
error = gl.glGetShaderInfoLog(shader)
if isinstance(error, bytes):
error = error.decode(errors="replace")
gl.glDeleteShader(shader)
raise RuntimeError(f"Shader compilation failed:\n{error}")
@ -484,8 +373,10 @@ def _create_program(vertex_source: str, fragment_source: str) -> int:
gl.glDeleteShader(vertex_shader)
gl.glDeleteShader(fragment_shader)
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
error = gl.glGetProgramInfoLog(program).decode()
if not gl.glGetProgramiv(program, gl.GL_LINK_STATUS):
error = gl.glGetProgramInfoLog(program)
if isinstance(error, bytes):
error = error.decode(errors="replace")
gl.glDeleteProgram(program)
raise RuntimeError(f"Program linking failed:\n{error}")
@ -530,9 +421,6 @@ def _render_shader_batch(
ctx = GLContext()
ctx.make_current()
# Convert from GLSL ES to desktop GLSL 330
fragment_source = _convert_es_to_desktop(fragment_code)
# Detect how many outputs the shader actually uses
num_outputs = _detect_output_count(fragment_code)
@ -558,9 +446,9 @@ def _render_shader_batch(
try:
# Compile shaders (once for all batches)
try:
program = _create_program(VERTEX_SHADER, fragment_source)
program = _create_program(VERTEX_SHADER, fragment_code)
except RuntimeError:
logger.error(f"Fragment shader:\n{fragment_source}")
logger.error(f"Fragment shader:\n{fragment_code}")
raise
gl.glUseProgram(program)
@ -723,13 +611,13 @@ def _render_shader_batch(
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
# Read back outputs for this batch
# (glGetTexImage is synchronous, implicitly waits for rendering)
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
batch_outputs = []
for tex in output_textures:
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
batch_outputs.append(img[::-1, :, :].copy())
for i in range(num_outputs):
gl.glReadBuffer(gl.GL_COLOR_ATTACHMENT0 + i)
buf = np.empty((height, width, 4), dtype=np.float32)
gl.glReadPixels(0, 0, width, height, gl.GL_RGBA, gl.GL_FLOAT, buf)
batch_outputs.append(buf[::-1, :, :].copy())
# Pad with black images for unused outputs
black_img = np.zeros((height, width, 4), dtype=np.float32)
@ -750,18 +638,18 @@ def _render_shader_batch(
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
gl.glUseProgram(0)
for tex in input_textures:
gl.glDeleteTextures(int(tex))
for tex in curve_textures:
gl.glDeleteTextures(int(tex))
for tex in output_textures:
gl.glDeleteTextures(int(tex))
for tex in ping_pong_textures:
gl.glDeleteTextures(int(tex))
if input_textures:
gl.glDeleteTextures(len(input_textures), input_textures)
if curve_textures:
gl.glDeleteTextures(len(curve_textures), curve_textures)
if output_textures:
gl.glDeleteTextures(len(output_textures), output_textures)
if ping_pong_textures:
gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures)
if fbo is not None:
gl.glDeleteFramebuffers(1, [fbo])
for pp_fbo in ping_pong_fbos:
gl.glDeleteFramebuffers(1, [pp_fbo])
if ping_pong_fbos:
gl.glDeleteFramebuffers(len(ping_pong_fbos), ping_pong_fbos)
if program is not None:
gl.glDeleteProgram(program)

View File

@ -123,7 +123,8 @@ class PhotoMakerLoader(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="PhotoMakerLoader",
category="experimental/photomaker",
display_name="Load PhotoMaker Model",
category="model/loaders",
inputs=[
io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")),
],
@ -149,7 +150,8 @@ class PhotoMakerEncode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="PhotoMakerEncode",
category="experimental/photomaker",
display_name="PhotoMaker Encode",
category="model/conditioning/photomaker",
inputs=[
io.Photomaker.Input("photomaker"),
io.Image.Input("image"),

View File

@ -119,7 +119,7 @@ class StableCascade_SuperResolutionControlnet(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StableCascade_SuperResolutionControlnet",
category="experimental/stable_cascade",
category="experimental/stable cascade",
is_experimental=True,
inputs=[
io.Image.Input("image"),

View File

@ -143,7 +143,7 @@ class VAEDecodeTripoSplat(IO.ComfyNode):
return IO.Schema(
node_id="VAEDecodeTripoSplat",
display_name="TripoSplat Decode",
category="3d/latent",
category="model/latent/triposplat",
description="Decode the sampled TripoSplat latent into a 3D gaussian splat. "
"Modify the number of gaussians to vary the density.",
inputs=[
@ -188,7 +188,7 @@ class TripoSplatSamplingPreview(IO.ComfyNode):
return IO.Schema(
node_id="TripoSplatSamplingPreview",
display_name="TripoSplat Sampling Preview",
category="3d/latent",
category="model/latent/triposplat",
description="Patch the TripoSplat model for the standard Ksampler node to show a live decoded "
"gaussian splat preview at each step.",
inputs=[

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.26.0"
__version__ = "0.27.0"

View File

@ -264,6 +264,59 @@ def annotated_filepath(name: str) -> tuple[str, str | None]:
return name, base_dir
# Content types a browser may execute or render inline. File endpoints that
# serve user-controlled content must force these to download (and ideally set
# Content-Disposition: attachment) to avoid stored XSS. Centralised here so the
# /view and /userdata handlers can't drift apart. mimetypes.guess_type may
# return either the text/* or application/* spelling depending on platform, so
# both are listed.
DANGEROUS_CONTENT_TYPES = {
'text/html', 'text/html-sandboxed', 'application/xhtml+xml',
'text/javascript', 'application/javascript', 'application/x-javascript',
'application/ecmascript', 'text/css',
'image/svg+xml', 'application/xml', 'text/xml',
# message/rfc822 (.mht/.mhtml) can carry script in some browsers.
'message/rfc822',
}
def is_dangerous_content_type(content_type: str | None) -> bool:
"""Return True if a browser may execute or render `content_type` inline.
Normalises before matching so the check can't be slipped past with a
charset/boundary parameter (``text/html; charset=utf-8``) or casing
(``TEXT/HTML``). Any XML dialect (``*+xml`` or ``*/xml``) is treated as
dangerous because XML can carry inline script via stylesheet/entity tricks,
which also covers the ``application/{xslt,rss,atom,rdf}+xml`` family without
enumerating each one. Endpoints serving user-controlled content should route
a dangerous type to ``application/octet-stream`` + ``Content-Disposition:
attachment`` + ``X-Content-Type-Options: nosniff``.
"""
if not content_type:
return False
normalized = content_type.split(';', 1)[0].strip().lower()
if normalized in DANGEROUS_CONTENT_TYPES:
return True
return normalized.endswith('+xml') or normalized.endswith('/xml')
def is_within_directory(directory: str, target: str) -> bool:
"""Return True if `target` resolves to a path inside `directory`.
Uses realpath on both operands so that a symlink placed inside `directory`
that points elsewhere cannot escape the containment check at open time.
"""
try:
directory = os.path.realpath(directory)
target = os.path.realpath(target)
return os.path.commonpath((directory, target)) == directory
except ValueError:
# ValueError is raised by realpath() on a path with an embedded null
# byte, and by commonpath() on Windows when the paths are on different
# drives. In either case the target is not safely within the directory.
return False
def get_annotated_filepath(name: str, default_dir: str | None=None) -> str:
name, base_dir = annotated_filepath(name)
@ -273,7 +326,12 @@ def get_annotated_filepath(name: str, default_dir: str | None=None) -> str:
else:
base_dir = get_input_directory() # fallback path
return os.path.join(base_dir, name)
filepath = os.path.abspath(os.path.join(base_dir, name))
# Prevent path traversal: the resolved path must stay within base_dir.
# repr() the name in the message so a crafted value can't inject log lines.
if not is_within_directory(base_dir, filepath):
raise ValueError("Invalid file path: {!r}".format(name))
return filepath
def exists_annotated_filepath(name) -> bool:
@ -282,7 +340,10 @@ def exists_annotated_filepath(name) -> bool:
if base_dir is None:
base_dir = get_input_directory() # fallback path
filepath = os.path.join(base_dir, name)
filepath = os.path.abspath(os.path.join(base_dir, name))
# Treat traversal attempts as non-existent rather than probing the filesystem.
if not is_within_directory(base_dir, filepath):
return False
return os.path.exists(filepath)

View File

@ -314,7 +314,7 @@ def prompt_worker(q, server_instance):
cache_ram = 0
cache_ram_inactive = 0
if not args.cache_classic and not args.cache_none and args.cache_lru <= 0:
cache_ram = min(10.0, max(2.0, comfy.model_management.total_ram * 0.10 / 1024.0))
cache_ram = min(10.0, max(1.5, comfy.model_management.total_ram * 0.05 / 1024.0))
cache_ram_inactive = min(96.0, comfy.model_management.total_ram / 1024.0)
if len(args.cache_ram) > 0:
cache_ram = args.cache_ram[0]
@ -403,7 +403,7 @@ def prompt_worker(q, server_instance):
hook_breaker_ac10a0.restore_functions()
if not asset_seeder.is_disabled():
asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=args.enable_asset_hashing)
asset_seeder.resume()
@ -458,7 +458,7 @@ def setup_database():
if dependencies_available():
init_db()
if args.enable_assets:
if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=True):
if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=args.enable_asset_hashing):
logging.info("Background asset scan initiated for models, input, output")
except Exception as e:
if "database is locked" in str(e):

View File

@ -159,6 +159,29 @@ class ConditioningConcat:
return (out, )
class ConditioningMultiply:
SEARCH_ALIASES = ["scale conditioning", "scale prompt", "multiply conditioning", "multiply prompt"]
@classmethod
def INPUT_TYPES(cls):
return {"required": {"conditioning": ("CONDITIONING", ),
"multiplier": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01})
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "multiply"
CATEGORY = "model/conditioning/transform"
def multiply(self, conditioning, multiplier):
c = []
for t in conditioning:
values = {}
pooled_output = t[1].get("pooled_output", None)
if pooled_output is not None:
values["pooled_output"] = pooled_output * multiplier
scaled = node_helpers.conditioning_set_values([[t[0] * multiplier, t[1]]], values)[0]
c.append(scaled)
return (c,)
class ConditioningSetArea:
SEARCH_ALIASES = ["regional prompt", "area prompt", "spatial conditioning", "localized prompt"]
@ -326,7 +349,7 @@ class VAEDecodeTiled:
RETURN_TYPES = ("IMAGE",)
FUNCTION = "decode"
CATEGORY = "experimental"
CATEGORY = "model/latent"
def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
if tile_size < overlap * 4:
@ -373,7 +396,7 @@ class VAEEncodeTiled:
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "experimental"
CATEGORY = "model/latent"
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
@ -491,7 +514,7 @@ class SaveLatent:
OUTPUT_NODE = True
CATEGORY = "experimental"
CATEGORY = "model/latent"
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
@ -536,7 +559,7 @@ class LoadLatent:
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
return {"required": {"latent": [sorted(files), ]}, }
CATEGORY = "experimental"
CATEGORY = "model/latent"
RETURN_TYPES = ("LATENT", )
FUNCTION = "load"
@ -2050,6 +2073,7 @@ NODE_CLASS_MAPPINGS = {
"ConditioningAverage": ConditioningAverage,
"ConditioningCombine": ConditioningCombine,
"ConditioningConcat": ConditioningConcat,
"ConditioningMultiply": ConditioningMultiply,
"ConditioningSetArea": ConditioningSetArea,
"ConditioningSetAreaPercentage": ConditioningSetAreaPercentage,
"ConditioningSetAreaStrength": ConditioningSetAreaStrength,
@ -2121,6 +2145,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ConditioningAverage ": "Conditioning (Average)",
"ConditioningAverage": "Conditioning (Average)",
"ConditioningConcat": "Conditioning (Concat)",
"ConditioningMultiply": "Conditioning (Multiply)",
"ConditioningSetArea": "Conditioning (Set Area)",
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
"ConditioningSetAreaStrength": "Conditioning (Set Area Strength)",
@ -2130,6 +2155,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"GLIGENTextBoxApply": "Apply GLIGEN Text Box",
"ConditioningZeroOut": "Conditioning Zero Out",
# Latent
"LoadLatent": "Load Latent",
"SaveLatent": "Save Latent",
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
"SetLatentNoiseMask": "Set Latent Noise Mask",
"VAEDecode": "VAE Decode",
@ -2164,7 +2191,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ImageSharpen": "Sharpen Image",
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
"GetImageSize": "Get Image Size",
# experimental
"VAEDecodeTiled": "VAE Decode (Tiled)",
"VAEEncodeTiled": "VAE Encode (Tiled)",
}

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.26.0"
version = "0.27.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@ -1,6 +1,6 @@
comfyui-frontend-package==1.45.19
comfyui-workflow-templates==0.10.7
comfyui-embedded-docs==0.5.5
comfyui-frontend-package==1.45.20
comfyui-workflow-templates==0.11.2
comfyui-embedded-docs==0.5.6
torch
torchsde
torchvision
@ -22,7 +22,7 @@ alembic
SQLAlchemy>=2.0.0
filelock
av>=16.0.0
comfy-kitchen==0.2.12
comfy-kitchen==0.2.16
comfy-aimdo==0.4.10
requests
simpleeval>=1.0.0
@ -33,5 +33,5 @@ kornia>=0.7.1
spandrel
pydantic~=2.0
pydantic-settings~=2.0
PyOpenGL
glfw
PyOpenGL>=3.1.8
comfy-angle

View File

@ -127,6 +127,7 @@ def create_cors_middleware(allowed_origin: str):
return cors_middleware
def is_loopback(host):
if host is None:
return False
@ -616,15 +617,30 @@ class PromptServer():
or 'application/octet-stream'
)
# For security, force certain mimetypes to download instead of display
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
content_type = 'application/octet-stream' # Forces download
# For security, force renderable/active types (HTML, JS,
# CSS, SVG, XML — anything that can carry inline <script>
# and execute in the page origin) to download instead of
# displaying inline, preventing stored XSS. The
# attachment disposition is the load-bearing guard: a
# bare filename= hint does not force a download per
# RFC 6266, so we only attach it on the dangerous branch
# to avoid breaking inline display of legitimate images.
# Escape backslash/quote per RFC 6266 quoted-string so a
# filename containing a double quote (which passes the
# ".."/leading-slash filter above) can't break out of the
# header's quoted-string and malform the disposition.
safe_filename = filename.replace("\\", "\\\\").replace('"', '\\"')
disposition = f"filename=\"{safe_filename}\""
if folder_paths.is_dangerous_content_type(content_type):
content_type = 'application/octet-stream'
disposition = f"attachment; filename=\"{safe_filename}\""
return web.FileResponse(
file,
headers={
"Content-Disposition": f"filename=\"{filename}\"",
"Content-Type": content_type
"Content-Disposition": disposition,
"Content-Type": content_type,
"X-Content-Type-Options": "nosniff"
}
)

View File

@ -1,3 +1,5 @@
import contextlib
import json
import time
import uuid
from datetime import datetime
@ -9,6 +11,40 @@ import requests
from helpers import get_asset_filename, trigger_sync_seed_assets
def test_download_svg_forced_to_attachment(http: requests.Session, api_base: str):
"""GHSA-779p-m5rp-r4h4 CISA-5 (sibling route): an uploaded SVG must never be
served inline from GET /api/assets/{id}/content, or an inline <script> runs
in the app origin (stored XSS). Even with disposition=inline requested, a
dangerous content type must be forced to application/octet-stream +
Content-Disposition: attachment + nosniff. Regression guard for the stale
inline blocklist that previously omitted image/svg+xml and ignored the
centralized folder_paths.is_dangerous_content_type check.
"""
svg = b'<svg xmlns="http://www.w3.org/2000/svg"><script>alert(1)</script></svg>'
files = {"file": ("evil.svg", svg, "image/svg+xml")}
form_data = {
"tags": json.dumps(["models", "checkpoints", "unit-tests", "svgxss"]),
"name": "evil.svg",
}
up = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
body = up.json()
assert up.status_code in (200, 201), body
aid = body["id"]
try:
r = http.get(f"{api_base}/api/assets/{aid}/content?disposition=inline", timeout=120)
r.content
assert r.status_code == 200
ct = r.headers.get("Content-Type", "").lower()
cd = r.headers.get("Content-Disposition", "").lower()
assert "svg" not in ct, f"SVG served with a renderable content type: {ct!r}"
assert ct.startswith("application/octet-stream"), f"expected octet-stream, got {ct!r}"
assert "attachment" in cd, f"inline disposition not overridden to attachment: {cd!r}"
assert r.headers.get("X-Content-Type-Options", "").lower() == "nosniff"
finally:
with contextlib.suppress(Exception):
http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
def test_download_attachment_and_inline(http: requests.Session, api_base: str, seeded_asset: dict):
aid = seeded_asset["id"]

View File

@ -53,8 +53,11 @@ def test_annotated_filepath():
def test_get_annotated_filepath():
default_dir = "/default/dir"
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt")
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt")
# get_annotated_filepath now normalizes with os.path.abspath (part of the
# GHSA-779p traversal hardening), so compare against the normalized form —
# on Windows abspath also prepends the current drive letter.
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.abspath(os.path.join(default_dir, "test.txt"))
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.abspath(os.path.join(folder_paths.get_output_directory(), "test.txt"))
def test_add_model_folder_path_append(clear_folder_paths):
folder_paths.add_model_folder_path("test_folder", "/default/path", is_default=True)

View File

View File

@ -0,0 +1,192 @@
"""CI unit tests for FIX #2 of GHSA-779p-m5rp-r4h4.
Path traversal / hardening in app/model_manager.py get_model_preview
(route /experiment/models/preview/{folder}/{path_index}/{filename:.*}).
Reference: https://github.com/Comfy-Org/ComfyUI/security/advisories/GHSA-779p-m5rp-r4h4
"""
import pytest
import yarl
from io import BytesIO
from PIL import Image
from aiohttp import web
from unittest.mock import patch
from app.model_manager import ModelFileManager
pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module
@pytest.fixture
def model_manager():
return ModelFileManager()
@pytest.fixture
def app(model_manager):
app = web.Application()
routes = web.RouteTableDef()
model_manager.add_routes(routes)
app.add_routes(routes)
return app
async def test_legit_preview_returns_200(aiohttp_client, app, tmp_path):
"""Sanity: a real preview PNG inside the model folder is served as webp 200."""
img = Image.new('RGB', (16, 16), color=(255, 0, 128))
img.save(tmp_path / "test_model.png", format='PNG')
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(tmp_path)], None)
}):
client = await aiohttp_client(app)
response = await client.get('/experiment/models/preview/test_folder/0/test_model.png')
assert response.status == 200
assert response.content_type == 'image/webp'
img_bytes = BytesIO(await response.read())
served = Image.open(img_bytes)
assert served.format
assert served.format.lower() == 'webp'
served.close()
async def test_non_integer_path_index_returns_400(aiohttp_client, app, tmp_path):
"""A non-integer path_index segment must be rejected with 400."""
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(tmp_path)], None)
}):
client = await aiohttp_client(app)
response = await client.get('/experiment/models/preview/test_folder/abc/test_model.png')
assert response.status == 400
async def test_out_of_range_path_index_returns_404(aiohttp_client, app, tmp_path):
"""A path_index beyond the configured folder list must return 404."""
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(tmp_path)], None)
}):
client = await aiohttp_client(app)
response = await client.get('/experiment/models/preview/test_folder/99/test_model.png')
assert response.status == 404
async def test_empty_filename_returns_400(aiohttp_client, app, tmp_path):
"""The "{filename:.*}" capture also matches the empty string (trailing
slash). It would resolve to the folder itself and must be rejected with 400."""
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(tmp_path)], None)
}):
client = await aiohttp_client(app)
response = await client.get('/experiment/models/preview/test_folder/0/')
assert response.status == 400
async def test_path_traversal_in_filename_returns_403(aiohttp_client, app, tmp_path):
"""Path traversal in {filename} must be rejected with 403 and must NOT read
a file outside the configured model directory.
GOTCHA: aiohttp/yarl collapses literal ``../`` dot-segments out of the URL
path before it reaches the handler, which would make this test vacuously
pass (the request would hit a different/non-existent route). We percent-encode
the dots and slashes (``%2e%2e%2f``) and send the URL with
``yarl.URL(..., encoded=True)`` so the bytes survive client-side normalization
untouched; aiohttp's router then percent-decodes them into ``match_info``,
delivering the literal ``../`` traversal to the handler's ``{filename:.*}``
capture.
Without the fix the handler computes
``os.path.normpath(os.path.join(folder, "../../../../etc/hosts"))``, which
escapes ``tmp_path`` and would be passed straight to get_model_previews ->
Image.open, serving bytes from outside the model dir (200/served bytes). The
is_within_directory() containment check is the load-bearing fix that turns
that escape into a 403.
"""
# Sanity-anchor: a legit preview exists inside tmp_path, so a 200 path is
# genuinely reachable — proving the 403 below is the containment check
# firing, not an unrelated 404.
img = Image.new('RGB', (16, 16), color=(255, 0, 128))
img.save(tmp_path / "test_model.png", format='PNG')
# Percent-encoded "../../../../etc/hosts" so yarl does not collapse the
# dot-segments before the request leaves the client.
encoded_traversal = '%2e%2e%2f' * 4 + 'etc%2fhosts'
raw_path = '/experiment/models/preview/test_folder/0/' + encoded_traversal
url = yarl.URL(raw_path, encoded=True)
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(tmp_path)], None)
}):
client = await aiohttp_client(app)
response = await client.get(url)
# Confirm the traversal actually reached the handler intact: a 200 here
# would mean either normalization stripped the ``../`` (vacuous pass) or
# the containment check failed open and served outside-dir bytes.
assert response.status == 403, (
f"expected 403 from is_within_directory() containment check, "
f"got {response.status}; traversal may have been normalized away "
f"or the fix failed open"
)
body = await response.read()
assert body == b"", "403 response must not carry any file bytes"
async def test_symlink_companion_preview_returns_403(aiohttp_client, app, tmp_path):
"""A companion preview file is selected by a glob inside get_model_previews
and then opened. If that companion is a symlink whose path is in-dir but
whose target escapes the model folder, it must be rejected with 403 — not
served. The requested path itself stays in-dir (so the first containment
check passes); the load-bearing fix is the SECOND is_within_directory check
on the file actually opened.
"""
model_dir = tmp_path / "models"
model_dir.mkdir()
secret_dir = tmp_path / "secret"
secret_dir.mkdir()
# A real image OUTSIDE the model dir — valid, so without the fix Image.open
# would succeed and its bytes would be served (200).
secret = secret_dir / "secret.png"
Image.new('RGB', (8, 8), color=(0, 0, 0)).save(secret, format='PNG')
# Companion preview, in-dir by name but a symlink escaping the model dir.
# (No real model file is needed — get_model_previews globs companions by
# basename, and omitting a .safetensors avoids the metadata-header read.)
companion = model_dir / "model.preview.png"
try:
companion.symlink_to(secret)
except (OSError, NotImplementedError):
pytest.skip("symlinks not supported on this platform/filesystem")
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(model_dir)], None)
}):
client = await aiohttp_client(app)
response = await client.get('/experiment/models/preview/test_folder/0/model.safetensors')
assert response.status == 403, (
f"expected 403 — the globbed companion preview is a symlink resolving "
f"outside the model dir and must not be served; got {response.status}"
)
assert await response.read() == b""
async def test_null_byte_in_filename_no_500(aiohttp_client, app, tmp_path):
"""A NUL byte in the filename must yield a clean client rejection, not a 500
from an uncaught ValueError in is_within_directory's realpath() call."""
raw_path = '/experiment/models/preview/test_folder/0/' + 'a%00b'
url = yarl.URL(raw_path, encoded=True)
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(tmp_path)], None)
}):
client = await aiohttp_client(app)
response = await client.get(url)
assert response.status != 500, (
f"NUL byte produced a 500 (uncaught ValueError); expected a clean "
f"4xx rejection, got {response.status}"
)
assert 400 <= response.status < 500

View File

@ -0,0 +1,165 @@
"""Security tests for GHSA-779p-m5rp-r4h4 — FIX #3.
Path traversal in folder_paths.get_annotated_filepath / exists_annotated_filepath,
plus the shared is_within_directory() containment helper.
These are pure-function tests (no running server). The input/output/temp
directories are pointed at tmp_path via the folder_paths setters, so a crafted
name containing `../`, an absolute path, or a symlink that escapes the base
directory must be rejected.
Reference: https://github.com/Comfy-Org/ComfyUI/security/advisories/GHSA-779p-m5rp-r4h4
"""
import os
import pytest
import folder_paths
from comfy.options import enable_args_parsing
enable_args_parsing()
@pytest.fixture
def sandbox(tmp_path):
"""Point folder_paths' input/output/temp dirs at a real temp sandbox.
Yields the realpath'd base, input, output and temp directories. The original
directory values are restored afterward so tests stay isolated.
"""
base = os.path.realpath(str(tmp_path))
input_dir = os.path.join(base, "input")
output_dir = os.path.join(base, "output")
temp_dir = os.path.join(base, "temp")
for d in (input_dir, output_dir, temp_dir):
os.makedirs(d, exist_ok=True)
orig_input = folder_paths.get_input_directory()
orig_output = folder_paths.get_output_directory()
orig_temp = folder_paths.get_temp_directory()
folder_paths.set_input_directory(input_dir)
folder_paths.set_output_directory(output_dir)
folder_paths.set_temp_directory(temp_dir)
yield {
"base": base,
"input": input_dir,
"output": output_dir,
"temp": temp_dir,
}
folder_paths.set_input_directory(orig_input)
folder_paths.set_output_directory(orig_output)
folder_paths.set_temp_directory(orig_temp)
# ---------------------------------------------------------------------------
# is_within_directory() — the shared containment helper
# ---------------------------------------------------------------------------
def test_is_within_directory_legit_child(sandbox):
base = sandbox["input"]
child = os.path.join(base, "sub", "image.png")
assert folder_paths.is_within_directory(base, child) is True
def test_is_within_directory_dotdot_escape(sandbox):
base = sandbox["input"]
escape = os.path.join(base, "..", "..", "etc", "passwd")
assert folder_paths.is_within_directory(base, escape) is False
def test_is_within_directory_symlink_escape(sandbox):
"""A symlink created INSIDE base that points OUTSIDE base must not pass.
This is the key new hardening: is_within_directory realpath()s both operands,
so a symlink planted in the base directory can't be used to read files
elsewhere. We create a real on-disk symlink and a real secret target to
verify the check actually resolves the link.
"""
base = sandbox["input"]
# A directory living outside the base, holding a secret file.
outside = os.path.join(sandbox["base"], "outside_secret_dir")
os.makedirs(outside, exist_ok=True)
secret = os.path.join(outside, "secret.txt")
with open(secret, "w") as f:
f.write("top secret")
# Plant a symlink inside base that points at the outside directory.
# symlink creation can require elevated privileges / Developer Mode on
# Windows, so skip cleanly where it isn't available (same guard as the
# sibling test in test_ghsa_779p_02_preview_traversal.py).
link = os.path.join(base, "escape_link")
try:
os.symlink(outside, link)
except (OSError, NotImplementedError):
pytest.skip("symlinks not supported on this platform/filesystem")
# Accessing the secret "through" the in-base symlink must be rejected.
target_via_link = os.path.join(link, "secret.txt")
assert folder_paths.is_within_directory(base, target_via_link) is False
# ---------------------------------------------------------------------------
# get_annotated_filepath()
# ---------------------------------------------------------------------------
def test_get_annotated_filepath_legit_name(sandbox):
result = folder_paths.get_annotated_filepath("image.png")
assert result == os.path.join(sandbox["input"], "image.png")
assert folder_paths.is_within_directory(sandbox["input"], result)
def test_get_annotated_filepath_input_annotation(sandbox):
result = folder_paths.get_annotated_filepath("image.png [input]")
assert result == os.path.join(sandbox["input"], "image.png")
def test_get_annotated_filepath_output_annotation(sandbox):
result = folder_paths.get_annotated_filepath("image.png [output]")
assert result == os.path.join(sandbox["output"], "image.png")
def test_get_annotated_filepath_temp_annotation(sandbox):
result = folder_paths.get_annotated_filepath("image.png [temp]")
assert result == os.path.join(sandbox["temp"], "image.png")
def test_get_annotated_filepath_dotdot_raises(sandbox):
with pytest.raises(ValueError):
folder_paths.get_annotated_filepath("../etc/passwd")
def test_get_annotated_filepath_dotdot_with_annotation_raises(sandbox):
with pytest.raises(ValueError):
folder_paths.get_annotated_filepath("../../etc/passwd [output]")
def test_get_annotated_filepath_absolute_escape_raises(sandbox):
with pytest.raises(ValueError):
folder_paths.get_annotated_filepath("/etc/passwd")
# ---------------------------------------------------------------------------
# exists_annotated_filepath()
# ---------------------------------------------------------------------------
def test_exists_annotated_filepath_existing_legit_file(sandbox):
real = os.path.join(sandbox["input"], "real.png")
with open(real, "w") as f:
f.write("data")
assert folder_paths.exists_annotated_filepath("real.png") is True
def test_exists_annotated_filepath_traversal_returns_false(sandbox):
"""A traversal name must return False without raising and without probing
outside the base directory (must never reach os.path.exists for the escape).
"""
# /etc/passwd exists on POSIX; the function must still report False because
# the resolved path escapes the input directory.
assert folder_paths.exists_annotated_filepath("../../../../../../etc/passwd") is False
def test_exists_annotated_filepath_absolute_returns_false(sandbox):
assert folder_paths.exists_annotated_filepath("/etc/passwd") is False

View File

@ -0,0 +1,147 @@
"""
CI unit tests for FIX #4 of GHSA-779p-m5rp-r4h4.
Stored-XSS hardening on GET /userdata/{file} in app/user_manager.py.
User data files are arbitrary user-supplied content and must never render
inline in the app origin. The getuserdata handler:
- forces Content-Type to application/octet-stream for any type in
folder_paths.DANGEROUS_CONTENT_TYPES (text/html, image/svg+xml,
text/javascript, ...),
- sets X-Content-Type-Options: nosniff,
- sets Content-Disposition: attachment.
These tests pre-create files in tmp_path and GET them back, asserting the
secure response headers. They mirror the aiohttp_client pattern in
tests-unit/prompt_server_test/user_manager_test.py.
"""
import pytest
import os
from aiohttp import web
from app.user_manager import UserManager
pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module
@pytest.fixture
def user_manager(tmp_path):
um = UserManager()
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
tmp_path, file
) if file else tmp_path
return um
@pytest.fixture
def app(user_manager):
app = web.Application()
routes = web.RouteTableDef()
user_manager.add_routes(routes)
app.add_routes(routes)
return app
async def test_html_served_as_octet_stream(aiohttp_client, app, tmp_path):
(tmp_path / "evil.html").write_text(
"<script>console.log('xss-marker-ghsa-779p')</script>"
)
client = await aiohttp_client(app)
resp = await client.get("/userdata/evil.html")
assert resp.status == 200
ct = resp.headers.get("Content-Type", "")
# The load-bearing assertion: a .html file must NOT be served as text/html.
assert "text/html" not in ct.lower(), (
f"Content-Type {ct!r} would let a browser render/execute the file (stored XSS)."
)
assert ct == "application/octet-stream"
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
assert "attachment" in resp.headers.get("Content-Disposition", "")
async def test_svg_served_as_octet_stream(aiohttp_client, app, tmp_path):
(tmp_path / "evil.svg").write_text(
'<?xml version="1.0"?>'
'<svg xmlns="http://www.w3.org/2000/svg">'
'<script>console.log("xss-marker-ghsa-779p")</script>'
"</svg>"
)
client = await aiohttp_client(app)
resp = await client.get("/userdata/evil.svg")
assert resp.status == 200
ct = resp.headers.get("Content-Type", "")
# SVG can carry inline <script>; it must not be served as image/svg+xml.
assert "svg" not in ct.lower(), (
f"Content-Type {ct!r} would let a browser render the SVG and execute embedded scripts."
)
assert ct == "application/octet-stream"
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
assert "attachment" in resp.headers.get("Content-Disposition", "")
async def test_js_served_as_octet_stream(aiohttp_client, app, tmp_path):
(tmp_path / "evil.js").write_text("alert('xss-marker-ghsa-779p')")
client = await aiohttp_client(app)
resp = await client.get("/userdata/evil.js")
assert resp.status == 200
ct = resp.headers.get("Content-Type", "").lower()
# Must not be served as any executable JavaScript content type.
assert "javascript" not in ct, (
f"Content-Type {ct!r} is an executable JS type."
)
assert "ecmascript" not in ct, (
f"Content-Type {ct!r} is an executable JS type."
)
assert ct == "application/octet-stream"
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
assert "attachment" in resp.headers.get("Content-Disposition", "")
async def test_xml_dialect_served_as_octet_stream(aiohttp_client, app, tmp_path):
"""An XML dialect outside the original blocklist (.xslt -> application/xslt+xml)
must still be forced to download. This pins the normalised *+xml family rule
in folder_paths.is_dangerous_content_type(); a plain set-membership test would
have served this inline."""
(tmp_path / "evil.xslt").write_text(
'<?xml version="1.0"?>'
'<xsl:stylesheet version="1.0" '
'xmlns:xsl="http://www.w3.org/1999/XSL/Transform">'
"<!-- xss-marker-ghsa-779p -->"
"</xsl:stylesheet>"
)
client = await aiohttp_client(app)
resp = await client.get("/userdata/evil.xslt")
assert resp.status == 200
ct = resp.headers.get("Content-Type", "")
assert ct == "application/octet-stream", (
f"Content-Type {ct!r}: an *+xml dialect must be forced to octet-stream "
f"(it can carry inline script via stylesheet/entity tricks)."
)
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
assert "attachment" in resp.headers.get("Content-Disposition", "")
async def test_benign_txt_still_served(aiohttp_client, app, tmp_path):
(tmp_path / "note.txt").write_text("just a harmless note")
client = await aiohttp_client(app)
resp = await client.get("/userdata/note.txt")
assert resp.status == 200
assert await resp.text() == "just a harmless note"
ct = resp.headers.get("Content-Type", "")
# text/plain is not in the dangerous set, so it is acceptable here. The
# defence-in-depth headers must still be present regardless.
assert "text/plain" in ct.lower()
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
assert "attachment" in resp.headers.get("Content-Disposition", "")

View File

@ -0,0 +1,138 @@
"""CI unit guard for FIX #5 of GHSA-779p-m5rp-r4h4 — the /view forced-download set.
Vuln #5 was stored XSS via SVG upload: the /view endpoint's Content-Type
blocklist covered text/html, text/javascript, etc. but was missing
image/svg+xml, so an uploaded SVG carrying an inline <script> was served as
image/svg+xml and executed in the page origin when rendered.
The /view forced-download decision lives in the view_image closure registered by
server.PromptServer.add_routes (server.py ~line 596), which calls
`folder_paths.is_dangerous_content_type(content_type)` — a normalising check that
strips charset/boundary parameters and casing and folds in the whole */xml and
*+xml dialect family — rather than a bypassable raw
`content_type in folder_paths.DANGEROUS_CONTENT_TYPES` membership test. On a match
it rewrites the response to application/octet-stream with a
Content-Disposition: attachment header. server.py cannot be imported in a unit
test (importing it spins up the full PromptServer/aiohttp app and its global side
effects), so these tests pin the underlying dangerous-content data
(folder_paths.DANGEROUS_CONTENT_TYPES) and the normalising is_dangerous_content_type()
helper that the closure actually calls.
The end-to-end /view assertion (upload an SVG, GET /view, confirm the response
is not served as image/svg+xml) lives in the live POC at
.security/pocs/test_security_ghsa_779p.py::TestViewSvgContentType, which
requires a running server. This file is the fast, server-free CI guard on the
set contents so the blocklist can't silently regress.
"""
import folder_paths
# Active/renderable content types that must be forced to download. Each of these
# can carry an inline <script> (or otherwise execute) in the page origin if a
# browser renders it. image/svg+xml is the original missing item that caused
# vuln #5.
DANGEROUS = [
'image/svg+xml',
'application/xml',
'text/xml',
'text/html',
'text/html-sandboxed',
'application/xhtml+xml',
'text/javascript',
'application/javascript',
'application/x-javascript',
'application/ecmascript',
'text/css',
]
# Benign image types that browsers display inline and that must keep rendering;
# forcing these to download would break legitimate previews.
BENIGN_INLINE_IMAGES = [
'image/png',
'image/jpeg',
'image/webp',
'image/gif',
]
def test_dangerous_content_types_is_a_set():
assert isinstance(folder_paths.DANGEROUS_CONTENT_TYPES, set)
def test_svg_is_in_the_blocklist():
"""The specific item whose absence caused vuln #5."""
assert 'image/svg+xml' in folder_paths.DANGEROUS_CONTENT_TYPES, (
"image/svg+xml missing from DANGEROUS_CONTENT_TYPES — this is exactly "
"the regression that reopens GHSA-779p-m5rp-r4h4 vuln #5 (stored XSS "
"via SVG upload on /view)."
)
def test_all_dangerous_types_present():
missing = [ct for ct in DANGEROUS if ct not in folder_paths.DANGEROUS_CONTENT_TYPES]
assert not missing, (
f"DANGEROUS_CONTENT_TYPES is missing required active/renderable types: "
f"{missing}. The /view closure only forces a download for content types "
f"in this set; anything missing here is served inline and can execute."
)
def test_benign_inline_image_types_absent():
leaked = [ct for ct in BENIGN_INLINE_IMAGES if ct in folder_paths.DANGEROUS_CONTENT_TYPES]
assert not leaked, (
f"Benign inline-displayable image types found in DANGEROUS_CONTENT_TYPES: "
f"{leaked}. Forcing these to download would break legitimate image "
f"previews in /view — they must keep rendering inline."
)
# ---------------------------------------------------------------------------
# is_dangerous_content_type() — the normalising check the /view and /userdata
# handlers now call instead of a raw `in DANGEROUS_CONTENT_TYPES` membership
# test. An exact-string membership test was bypassable with a charset parameter
# or odd casing, and missed the wider XML dialect family; these tests pin the
# normalisation so that bypass can't reopen.
# ---------------------------------------------------------------------------
def test_function_matches_plain_dangerous_types():
for ct in DANGEROUS:
assert folder_paths.is_dangerous_content_type(ct) is True, ct
def test_function_strips_parameters_and_casing():
"""A charset/boundary parameter or casing must not slip a type past the check.
This is the bypass surfaced by review: the /view blake3 branch can serve an
attacker-controlled, unvalidated asset mime_type like 'text/html; charset=utf-8',
which an exact-string set test missed.
"""
for ct in (
'text/html; charset=utf-8',
'TEXT/HTML',
'Text/HTML; charset=UTF-8',
'image/svg+xml; charset=utf-8',
' text/html ',
):
assert folder_paths.is_dangerous_content_type(ct) is True, ct
def test_function_covers_xml_dialect_family():
"""Any *+xml / */xml dialect is dangerous without enumerating each one."""
for ct in (
'application/xslt+xml',
'application/rss+xml',
'application/atom+xml',
'application/rdf+xml',
'application/mathml+xml',
'message/rfc822',
):
assert folder_paths.is_dangerous_content_type(ct) is True, ct
def test_function_allows_benign_and_empty():
for ct in BENIGN_INLINE_IMAGES + ['application/octet-stream', 'text/plain']:
assert folder_paths.is_dangerous_content_type(ct) is False, ct
# None / empty (mimetypes.guess_type miss) must not be treated as dangerous.
assert folder_paths.is_dangerous_content_type(None) is False
assert folder_paths.is_dangerous_content_type('') is False