Compare commits

40 Commits

Author SHA1 Message Date
d657a40681 Add distinction in error messaging for gated models. 2026-07-02 12:41:23 +02:00
c98a212589 Add extension check on the final resolved url -> fix downloading from civitAI. 2026-07-01 15:37:30 +02:00
4c82c708a7 Add delete and clear all downloads functionality. 2026-07-01 15:04:30 +02:00
fe4d0c9722 Disable newline translations on Windows, \r\n -> \n only. 2026-07-01 12:06:17 +02:00
64c5853631 Add support for ENV based HF_TOKEN. 2026-07-01 12:02:19 +02:00
8a6e7906f7 Fix an issue with Windows lacking os.pwrite 2026-07-01 11:56:41 +02:00
e4b0a72e83 Fix running CI tests. 2026-07-01 11:45:50 +02:00
28b41d4d6d Fix more AI-detected issues. 2026-07-01 11:38:33 +02:00
3eb36377a8 Fix ruff. 2026-07-01 11:36:23 +02:00
27fd68a533 Update openapi.yaml. 2026-06-30 20:33:16 +02:00
61816d436b Remove sending url info over websockets for model downloads. 2026-06-30 20:33:16 +02:00
abc0b728ab Fix sweep deleting FAILED partials and fix segmented resume path trusting offsets blindly. 2026-06-30 20:33:16 +02:00
1bbd4a57db Add _positive_int in cli_args arguments. 2026-06-30 20:33:16 +02:00
b419fd8399 Redact urls in logging and fix concurrent enqueue issue that could corrupt the downloaded files. 2026-06-30 20:33:16 +02:00
e77983ca28 Simplify docstrings. 2026-06-30 20:33:16 +02:00
893ba2ad37 Normalize malformed safetensors headers into StructuralError. 2026-06-30 20:33:16 +02:00
312b282ca8 Prevent redirects to loopback/internal IPs (SSRF) 2026-06-30 20:33:16 +02:00
7690c52a34 Simplify docstrings. 2026-06-30 20:33:16 +02:00
e326ef3b16 Fix IPv4-mapped IPv6 addresses 2026-06-30 20:33:16 +02:00
115a4305ea Restrict cleartext HTTP redirects to explicit loopback/dev hosts. 2026-06-30 20:33:16 +02:00
9be31a4b7e Simplify docstrings. 2026-06-30 20:33:16 +02:00
7dba134cda Don't echo full URLs and raw exception text from probe failures. 2026-06-30 20:33:16 +02:00
e0b07014c0 Simplify docstrings. 2026-06-30 20:33:16 +02:00
3eade55077 Handle short pwrite() results. 2026-06-30 20:33:16 +02:00
c785130223 Switch to asyncio.to_thread for db calls in job.py 2026-06-30 20:33:16 +02:00
4ccaaa6f37 Improve _finalize checks for downloads. 2026-06-30 20:33:16 +02:00
58392bf7a6 Truncate file to 0 before restarting. 2026-06-30 20:33:16 +02:00
4ae294d2d5 Add max-download-size in case the server tries to send larger files than it reports. 2026-06-30 20:33:15 +02:00
95b0758a88 Clear error when the job leaves a failure state. 2026-06-30 20:33:15 +02:00
dba2bfbc02 Fix resuming of segmented download. 2026-06-30 20:33:15 +02:00
4bd7cc153e Simplify docstrings. 2026-06-30 20:33:15 +02:00
2b708d5af7 Fix potential concurrency and db integrity error. 2026-06-30 20:33:15 +02:00
9c845eeb9e Fix an issue when a credential is deleted, resuming of download fails. 2026-06-30 20:33:15 +02:00
e70f524d5f Simplify docstrings. 2026-06-30 20:33:15 +02:00
d058ffb761 Fix secret_last4 2026-06-30 20:33:15 +02:00
53ec95b87e Improve normalize_host 2026-06-30 20:33:15 +02:00
e02c7a0890 Docstring simplification. 2026-06-30 20:33:15 +02:00
1744026eca Fix url parsing. 2026-06-30 20:33:15 +02:00
f660307489 Simplify migration docstring. 2026-06-30 20:33:15 +02:00
c7c18377a3 Add initial commit for model downloader. 2026-06-30 20:33:15 +02:00
68 changed files with 7098 additions and 1720 deletions

294
AGENTS.md
@ -1,294 +0,0 @@
## Engineering Style
- Keep changes small and direct. Most fixes should touch the narrowest code path
that explains the bug, performance issue, dtype issue, model-format issue, or
user-facing behavior.
- Change as few files as possible. A change that touches many files is
more likely to be a bad change than a good one unless the broader scope is
directly required.
- Prefer practical fixes over broad architecture work. Add abstractions only
when they remove real repeated logic or match an existing ComfyUI pattern.
- Prefer fewer dependencies. Do not add new dependencies to ComfyUI unless they
are absolutely necessary.
- Delete obsolete code aggressively when newer infrastructure makes it useless.
Remove dead fallbacks, migration paths, unused options, debug prints, and
compatibility branches that are no longer needed. Do not leave dead branches,
unreachable code, or functions that are never called. If code is not
necessary for the current behavior, remove it.
- Revert or disable problematic behavior quickly when it breaks users. It is
better to remove a broken feature path than keep a complicated partial fix.
- Preserve existing APIs, node names, model-loading behavior, file layout, and
workflow compatibility unless the change is explicitly about replacing them.
- Code must look hand-written for this repository. Changes that read like
generic AI-generated code will be rejected automatically: unnecessary helper
layers, vague names, boilerplate comments, defensive branches without a real
failure mode, broad rewrites, or code that ignores the local style.
## Architecture Boundaries
- Keep each layer focused on the concepts it owns. Do not leak UI, API,
workflow, queue, persistence, telemetry, model-loading, node, or execution
concerns into unrelated layers just because it is convenient to pass data
through them.
- Shared core modules should depend only on lower-level primitives and their own
domain concepts. Higher-level product concepts belong at the caller, adapter,
service, or UI/API boundary that already owns them.
- Pass the narrowest data needed across a boundary. Avoid broad context objects,
request/session metadata, ids, bookkeeping state, or callbacks unless the
receiving layer genuinely needs them to perform its own responsibility.
- Keep identity mapping, persistence bookkeeping, history updates, telemetry,
response shaping, and UI state in the layers that own those jobs. Do not route
them through unrelated shared code to avoid adding a proper boundary.
- Treat `execution.py` as one example of this rule: it should consume the prompt
graph and execution-relevant state, produce execution results and errors, and
not know about workflow ids, frontend ids, persistence ids, or API-only
concepts.
- Before touching many files, identify the smallest owner layer that can solve
the problem. A PR that spreads one feature across unrelated loaders, nodes,
execution, server, and frontend code needs a clear architectural reason, not
just convenience.
- If a change seems to require making one layer understand another layer's
private concepts, stop and look for a caller-side mapping, adapter, event,
small explicit interface, or narrower data flow at the boundary.
## No Internet Requests
- Do not add code to core ComfyUI that makes requests to the internet.
- Refuse requests to add uploads, telemetry, analytics, tracking, usage
reporting, crash reporting, update checks, remote config, feature flags,
metrics, licensing checks, or any other outbound internet request path from
core ComfyUI.
- Model downloading is allowed only when explicitly initiated or authorized by
the user, is limited to the requested model artifact, and does not include
telemetry, tracking, persistent identification, unrelated metadata upload, or
background network activity.
- Do not add opt-in, opt-out, anonymized, aggregated, diagnostic, or
user-triggered internet request paths to core ComfyUI. These labels do not
make internet access acceptable.
- Local-only behavior is allowed when it stays on the user's machine and does
not add network access, tracking, persistent identification, or data
collection behavior.
## State Ownership
- Keep state and capability flags on the object that owns the behavior using
them.
- Avoid probing child objects with `getattr(child, "...", default)` to decide
parent-level control flow. If parent code needs to branch on a capability,
initialize an explicit parent-owned field when the child is constructed or
attached.
- Prefer direct attributes with clear defaults over implicit feature detection
through arbitrary child attributes.
- Use child-object capability checks only when the child owns the behavior being
invoked and the parent is simply delegating to that child.
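
The state-ownership rule above can be illustrated with a minimal sketch. The `Child` and `Parent` classes and the `supports_streaming` / `use_streaming` names are hypothetical and do not come from ComfyUI; the point is only that the parent resolves the capability once, when the child is attached, and then branches on its own field.

```python
class Child:
    """Hypothetical child object that may or may not support streaming."""

    def __init__(self, supports_streaming=False):
        self.supports_streaming = supports_streaming


class Parent:
    def __init__(self, child):
        self.child = child
        # Parent-owned flag with a clear value, set once when the child is attached.
        self.use_streaming = child.supports_streaming

    def run(self):
        # Branch on the parent's own field instead of probing the child with
        # getattr(self.child, "supports_streaming", False) at every call site.
        return "streaming" if self.use_streaming else "batch"


print(Parent(Child(supports_streaming=True)).run())
print(Parent(Child()).run())
```

Keeping the flag on `Parent` means the control flow reads from one explicit attribute, which is the behavior the bullets above ask for.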
## Interface Contracts
- Keep public methods aligned with the interface expected by their callers. Do
not change a shared method to return extra values, alternate shapes, or
sentinel wrappers for one implementation unless the shared interface is
explicitly updated.
- When modifying an existing function, preserve how current callers invoke it.
Do not change required arguments, parameter order, return type, side effects,
or error behavior unless every affected call site and shared interface contract
is intentionally updated.
- Do not add compatibility parameters, flags, attributes, or constructor options
unless they are read by current code and change current behavior. Remove
pass-through or stored-but-unused values instead of preserving upstream or
deprecated API baggage.
- If an implementation needs auxiliary values for its own workflow, expose them
through a private helper or a clearly named implementation-specific method
instead of overloading the public method's return contract.
- Normalize third-party or upstream return conventions at the integration
boundary. Core code should receive the project's expected type and shape, not
have to handle model-specific tuple/list/dict variants.
- Avoid caller-side unwrapping such as `out = out[0]` unless the called
interface is documented to return that structure.
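
As a rough sketch of normalizing return conventions at the integration boundary: `upstream_encode`, `encode`, and `total_length` below are hypothetical names, and the tuple-returning upstream function is an assumption made purely for illustration.

```python
def upstream_encode(text):
    """Hypothetical third-party call that returns a (result, extras) tuple."""
    return ([len(word) for word in text.split()], {"debug": True})


def encode(text):
    """Integration boundary: unwrap here so callers always receive a plain list."""
    result, _extras = upstream_encode(text)
    return result


def total_length(text):
    # Core code relies on the documented list return; no `out = out[0]` needed.
    return sum(encode(text))


print(total_length("small direct fix"))
```

Only `encode` knows about the upstream tuple shape, so a change in the third-party convention touches one function rather than every caller.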
## Autograd and Model Freezing
- Do not add `torch.no_grad`, `torch.inference_mode`, or inference-mode helper
wrappers in ComfyUI code. The only allowed inference-mode-related use is
disabling a globally set inference mode when a training path needs gradients.
- Do not add freeze, unfreeze, or trainability toggles to model classes. ComfyUI
models are always treated as frozen for inference, so explicit freeze
functionality is redundant and should not be added.
- Remove training-only behavior such as dropout from inference model code, but
preserve checkpoint and state-dict compatibility when doing so. If deleting a
module would change state-dict keys, module ordering, or checkpoint loading
behavior, replace it with a no-op such as `nn.Identity` instead of removing the
slot outright.
## Python Style
- Keep imports at module scope. Avoid inline imports unless they are already part
of an established optional-backend probe or are needed to avoid an import
cycle.
- Do not add unnecessary `try`/`except` blocks. Use them for optional dependency,
platform, or backend capability detection only when the program has a useful
fallback. Prefer specific exception types when adding or changing code.
- Remove any workarounds for PyTorch versions that ComfyUI no longer officially
supports. Deprecated workarounds include catching an exception and rerunning
the same op with the input cast to float. If a workaround does not have a
comment naming the exact PyTorch version or versions that still need it,
remove it.
- Let unsupported model formats, invalid quantization metadata, and bad states
fail with clear errors instead of silently producing lower quality output.
- Match the existing local style in the file you edit. This codebase tolerates
long lines, simple helper functions, module-level state, and direct tensor
operations when they make the code easier to follow.
- Keep comments sparse and useful. Strip useless comments that restate the code
or describe obvious behavior. Short TODOs are fine when they name the concrete
missing follow-up.
## Model, Device, and Memory Behavior
- Treat dtype, device placement, VRAM usage, and offloading behavior as core
correctness concerns. Check CPU, CUDA, ROCm, MPS, DirectML, XPU, NPU, and low
VRAM implications when touching shared execution or loading code.
- Prefer native ComfyUI formats and existing quantization/offload helpers over
adding parallel code paths. Use `comfy.quant_ops`, `comfy.model_management`,
`comfy.memory_management`, `comfy.pinned_memory`, `comfy_aimdo`, and
`comfy-kitchen` helpers where they already solve the problem.
- Use optimized comfy-kitchen ops in places where they improve performance
without changing the expected dtype, device, memory, or interface behavior.
- All models should use the optimized attention function selected by ComfyUI.
Treat optimized backend functions, dispatch helpers, and capability-selected
callables as opaque. Higher-level code must not inspect function identity,
names, modules, or implementation details to decide behavior.
- Apply the same opacity rule to similar patterns beyond attention: callers
should depend on the documented interface and result contract, not on which
backend implementation was selected underneath.
- Do not use custom inference ops that only duplicate an existing op while
upcasting to float32, such as custom RMSNorm variants. Use the generic ComfyUI
ops and/or native torch ops instead.
- If a model class `__init__` has an `operations` parameter, assume
`operations` is never `None`. Do not add fallback branches or default torch
ops for a missing `operations` object.
- Do not add unnecessary parameters to model, model block, or model ops related
classes. Constructor and forward signatures should carry only values that are
actually needed by that object for inference.
- Reuse existing model classes, blocks, ops, and helper modules when appropriate.
Before implementing a new version of a model component, search the existing
model code for a class or helper that already provides the behavior.
- Model detection code that inspects linear weight shapes should only use the
first dimension. The second dimension may be half the original size for
NVFP4 or other 4-bit quantized models.
- Avoid adding `einops` usage in core inference code. Use native torch tensor
ops such as `reshape`, `view`, `permute`, `transpose`, `flatten`, `unflatten`,
`unsqueeze`, and `squeeze` instead.
- Do not use tensors as general-purpose Python data structures. Keep metadata,
bookkeeping, counters, flags, shape math, padding math, index planning, memory
estimates, and control-flow decisions in plain Python values unless the data
must participate directly in tensor computation. Do not create tensors for
structural metadata that is only used for Python-side control flow. Sequence
lengths, cumulative offsets, split indices, window counts, slice boundaries,
and repeat counts should be kept as Python ints/lists from the point they are
computed. Do not build them as CPU/GPU tensors and then cast, move, validate,
or convert them back to Python for `split`, `tensor_split`, indexing plans,
loops, or cache keys. Avoid creating temporary tensors just to use tensor
methods for scalar or structural calculations.
- Avoid unnecessary casts and transfers. Preserve the intended compute dtype,
storage dtype, bias dtype, and original tensor shape metadata.
- Keep model-native latent layout handling inside the model or latent-format
owner, not in helper nodes. Do not collapse, expand, pack, or unpack latent
dimensions in nodes or other caller-side adapters just to satisfy a model
forward; the model path should consume and return the native latent shape for
that model family.
- Assume inputs to the main model forward are already in the compute dtype by
default, except integer inputs such as some model timestep tensors. Do not add
defensive or convenience casts in model code; it is better for invalid dtype
plumbing to error clearly than to hide it with unnecessary casts.
- Raw model parameters that are not owned by an op and may be initialized in a
dtype different from the compute dtype should be cast at use in forward or
inference code with `comfy.ops.cast_to_input` or
`comfy.model_management.cast_to` to avoid dtype mismatches.
- Model code should not care what dtype it is initialized in, and model
`__init__` methods should not contain workarounds for specific dtypes. Dtype
workaround code, such as making a model work with fp16 compute, belongs in the
execution or model-management layer that owns compute policy.
- Model code should not perform unnecessary device-to-CPU or CPU-to-device
transfers. New allocations must be created on the correct device and dtype;
never allocate on CPU and then move to GPU, or allocate in one dtype and then
convert to another.
- Model code itself should not perform memory management. Loading, unloading,
offloading, device movement, VRAM policy, cache lifetime, and cleanup belong
in the relevant model-management and execution layers, not inside model
implementations.
- Do not add global, module-level, class-level, singleton, or model-owned stores
for tensors or other large memory that persist across executions. Temporary
caches must be scoped to a single execution or forward/encode/decode call:
allocate them in the owning top-level call, pass them explicitly through the
call stack, and let them be discarded when that call returns.
- Follow the Wan VAE temporal cache pattern for temporary caches: create a local
cache such as `feat_map` for the encode/decode operation, pass it into the
blocks that need it, and do not retain it on the model or in global state.
- In model init code, prefer `torch.empty` for parameter/buffer placeholders
that are populated from the model state dict instead of zero-initializing with
`torch.zeros` or similar. If an allocation is not loaded from the state dict
and is useless for inference, do not include it.
- `nn.Parameter` tensors that are stored in and populated from the model state
dict should be initialized with `torch.empty`, not with zero, random, or
otherwise meaningful initialization.
- Model initialization should describe module structure, not fabricate
checkpoint-owned tensor contents. Parameters and buffers that are loaded from
the state dict must not be manually initialized, reassigned, or filled with
fallback values unless that value is actually used when no checkpoint key
exists.
- When slicing large tensors, copy the slice if the sliced tensor's lifetime
exceeds the current function scope. Do not keep a long-lived view into a large
backing tensor when a smaller copy would release memory sooner.
- Use fused or compound torch operations such as `addcmul` when they naturally
match the math. Reducing Python and torch dispatch overhead is a valid
optimization when it does not obscure the code or change dtype/device
behavior.
- Avoid caches that persist across different executions as much as possible.
Persistent caches are acceptable only when they use a very minimal amount of
memory and have a clear ownership and invalidation story.
- When optimizing, favor small measurable changes: fewer allocations, fewer
device transfers, less peak memory, better batching, or use of a faster
existing backend op.
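
The temporary-cache bullets above (a local `feat_map` created per encode/decode call and passed into the blocks) can be sketched as follows. This is a toy illustration that uses plain integers instead of tensors; `Block` and `Decoder` are hypothetical classes and this is not the actual Wan VAE code.

```python
class Block:
    """Hypothetical block that reuses the previous chunk through a caller-owned cache."""

    def __init__(self, idx):
        self.idx = idx

    def forward(self, chunk, feat_map):
        previous = feat_map[self.idx]
        feat_map[self.idx] = chunk  # stored in the caller's list, never on self
        return chunk if previous is None else chunk + previous


class Decoder:
    def __init__(self, num_blocks=2):
        self.blocks = [Block(i) for i in range(num_blocks)]

    def decode(self, chunks):
        # Local cache: allocated in the top-level call, discarded when it returns.
        feat_map = [None] * len(self.blocks)
        out = []
        for chunk in chunks:
            for block in self.blocks:
                chunk = block.forward(chunk, feat_map)
            out.append(chunk)
        return out


decoder = Decoder()
first = decoder.decode([1, 2, 3])
second = decoder.decode([1, 2, 3])
print(first, second)
```

Because nothing is retained on the model, two identical calls produce identical results and no memory outlives the call.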
## Nodes and User-Facing Behavior
- Follow existing node conventions: `INPUT_TYPES`, `RETURN_TYPES`, `FUNCTION`,
`CATEGORY`, and registration through the local mapping used by that file.
- Keep node changes backward compatible by default. Add inputs with sensible
defaults and avoid changing output types unless the request requires it.
- Model implementations should add the minimal number of ComfyUI nodes required
to run the model. Reuse existing nodes as much as possible; adapting the model
to work with existing nodes is strongly preferred over creating new nodes.
- Nodes should output only values they own. Do not add pass-through outputs for
workflow convenience unless the node is explicitly an output node. Existing
models, latents, conditioning, or other inputs should flow directly to the
next consumer instead of being re-emitted unchanged.
- Nodes should expose only inputs they actually read to produce current
behavior. Do not add placeholder, pass-through, compatibility, or
workflow-shaping inputs that are ignored or could flow directly to another
node.
- Node-level code must not patch model code directly. Any node behavior that
modifies, wraps, hooks, or changes model behavior must go through the model
patcher class instead of reaching into model internals.
- The official mascot of ComfyUI is a very cute anime girl with massive fennec
ears, a big fluffy tail, long blonde wavy hair, and blue eyes. Feel free to
use her in ComfyUI materials, UI text, examples, tests, generated assets, or
comments, but do not disrespect her.
- Warning and info messages should be short and actionable. Remove noisy or
misleading messages rather than adding more logging.
- Documentation and README edits should be concise, factual, and tied to the
changed behavior.
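
For the node conventions listed at the top of this section, a minimal sketch might look like the following. The `ExampleScale` node, its category, and its inputs are hypothetical; `NODE_CLASS_MAPPINGS` is assumed here to be the local registration mapping used by the file.

```python
class ExampleScale:
    """Hypothetical node; the name, category, and behavior are illustrative only."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("FLOAT", {"default": 1.0}),
            },
            "optional": {
                # A later-added input gets a default so existing workflows still run.
                "factor": ("FLOAT", {"default": 1.0}),
            },
        }

    RETURN_TYPES = ("FLOAT",)
    FUNCTION = "scale"
    CATEGORY = "examples"

    def scale(self, value, factor=1.0):
        # Returns only the value this node owns; no pass-through outputs.
        return (value * factor,)


# Registration through a mapping from node name to node class.
NODE_CLASS_MAPPINGS = {"ExampleScale": ExampleScale}

node = NODE_CLASS_MAPPINGS["ExampleScale"]()
print(getattr(node, node.FUNCTION)(2.0, factor=3.0))
```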
## Commit and Review Habits
- If asked to write commit messages, use short direct subjects like the existing
history: `Fix ...`, `Add ...`, `Support ...`, `Remove ...`, `Update ...`,
`Make ...`, `Use ...`, `Disable ...`, `Bump ...`, or `Revert ...`.
- Keep PR descriptions short and reviewable. State the problem, the behavioral
change, and the tests run; avoid long narrative explanations, implementation
diaries, or exhaustive file-by-file summaries unless the reviewer explicitly
needs that context.
- Prefer one coherent behavioral change per commit. Dependency pins, tests, and
the code that needs them may be in the same commit when they are inseparable.
- In reviews, prioritize real user impact: crashes, wrong dtype/device behavior,
memory regressions, broken model loading, workflow incompatibility, and noisy
or misleading user-facing output.

@ -0,0 +1,115 @@
"""
Download manager schema.

Adds the three tables that back the server-side model download manager:
transient job/queue state (``downloads`` + per-segment
``download_segments``) and one-API-key-per-host auth (``host_credentials``).

Revision ID: 0005_download_manager
Revises: 0004_drop_tag_type
Create Date: 2026-06-27
"""
from alembic import op
import sqlalchemy as sa

revision = "0005_download_manager"
down_revision = "0004_drop_tag_type"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "downloads",
        sa.Column("id", sa.String(length=36), primary_key=True),
        sa.Column("url", sa.Text(), nullable=False),
        sa.Column("final_url", sa.Text(), nullable=True),
        sa.Column("model_id", sa.String(length=1024), nullable=False),
        sa.Column("dest_path", sa.Text(), nullable=False),
        sa.Column("temp_path", sa.Text(), nullable=False),
        sa.Column("status", sa.String(length=16), nullable=False),
        sa.Column("priority", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("total_bytes", sa.BigInteger(), nullable=True),
        sa.Column("bytes_done", sa.BigInteger(), nullable=False, server_default="0"),
        sa.Column("etag", sa.String(length=512), nullable=True),
        sa.Column("last_modified", sa.String(length=128), nullable=True),
        sa.Column(
            "accept_ranges", sa.Boolean(), nullable=False, server_default=sa.text("false")
        ),
        sa.Column("expected_sha256", sa.String(length=64), nullable=True),
        sa.Column("credential_id", sa.String(length=36), nullable=True),
        sa.Column(
            "allow_any_extension",
            sa.Boolean(),
            nullable=False,
            server_default=sa.text("false"),
        ),
        sa.Column("attempts", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("error", sa.Text(), nullable=True),
        sa.Column("created_at", sa.BigInteger(), nullable=False),
        sa.Column("updated_at", sa.BigInteger(), nullable=False),
        sa.CheckConstraint("bytes_done >= 0", name="ck_downloads_bytes_done_nonneg"),
        sa.CheckConstraint(
            "total_bytes IS NULL OR total_bytes >= 0",
            name="ck_downloads_total_bytes_nonneg",
        ),
    )
    op.create_index("ix_downloads_status", "downloads", ["status"])
    op.create_index("ix_downloads_priority", "downloads", ["priority"])
    op.create_index("ix_downloads_model_id", "downloads", ["model_id"])

    op.create_table(
        "download_segments",
        sa.Column(
            "download_id",
            sa.String(length=36),
            sa.ForeignKey("downloads.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column("idx", sa.Integer(), nullable=False),
        sa.Column("start_offset", sa.BigInteger(), nullable=False),
        sa.Column("end_offset", sa.BigInteger(), nullable=False),
        sa.Column("bytes_done", sa.BigInteger(), nullable=False, server_default="0"),
        sa.PrimaryKeyConstraint("download_id", "idx", name="pk_download_segments"),
        sa.CheckConstraint("bytes_done >= 0", name="ck_segments_bytes_done_nonneg"),
        sa.CheckConstraint("end_offset >= start_offset", name="ck_segments_range"),
    )

    op.create_table(
        "host_credentials",
        sa.Column("id", sa.String(length=36), primary_key=True),
        sa.Column("host", sa.String(length=255), nullable=False),
        sa.Column(
            "match_subdomains",
            sa.Boolean(),
            nullable=False,
            server_default=sa.text("false"),
        ),
        sa.Column("label", sa.String(length=255), nullable=True),
        sa.Column(
            "auth_scheme", sa.String(length=16), nullable=False, server_default="bearer"
        ),
        sa.Column("header_name", sa.String(length=255), nullable=True),
        sa.Column("query_param", sa.String(length=255), nullable=True),
        sa.Column("secret", sa.Text(), nullable=False),
        sa.Column("secret_last4", sa.String(length=4), nullable=True),
        sa.Column("enabled", sa.Boolean(), nullable=False, server_default=sa.text("true")),
        sa.Column("created_at", sa.BigInteger(), nullable=False),
        sa.Column("updated_at", sa.BigInteger(), nullable=False),
    )
    op.create_index(
        "uq_host_credentials_host", "host_credentials", ["host"], unique=True
    )


def downgrade() -> None:
    op.drop_index("uq_host_credentials_host", table_name="host_credentials")
    op.drop_table("host_credentials")
    op.drop_table("download_segments")
    op.drop_index("ix_downloads_model_id", table_name="downloads")
    op.drop_index("ix_downloads_priority", table_name="downloads")
    op.drop_index("ix_downloads_status", table_name="downloads")
    op.drop_table("downloads")
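
To make the effect of the segment constraints in the migration above concrete, here is a small sketch that recreates a reduced version of the two download tables directly in SQLite with the standard `sqlite3` module. It is an approximation under stated assumptions: it bypasses Alembic, keeps only a few columns, and the row values (`job-1`, the byte offsets) are invented for illustration.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# SQLite only enforces ON DELETE CASCADE when foreign keys are switched on.
conn.execute("PRAGMA foreign_keys = ON")
conn.executescript(
    """
    CREATE TABLE downloads (
        id TEXT PRIMARY KEY,
        bytes_done INTEGER NOT NULL DEFAULT 0 CHECK (bytes_done >= 0)
    );
    CREATE TABLE download_segments (
        download_id TEXT NOT NULL REFERENCES downloads(id) ON DELETE CASCADE,
        idx INTEGER NOT NULL,
        start_offset INTEGER NOT NULL,
        end_offset INTEGER NOT NULL,
        bytes_done INTEGER NOT NULL DEFAULT 0 CHECK (bytes_done >= 0),
        PRIMARY KEY (download_id, idx),
        CHECK (end_offset >= start_offset)
    );
    """
)

insert_segment = (
    "INSERT INTO download_segments (download_id, idx, start_offset, end_offset) "
    "VALUES (?, ?, ?, ?)"
)
conn.execute("INSERT INTO downloads (id) VALUES ('job-1')")
conn.execute(insert_segment, ("job-1", 0, 0, 1023))

# A segment whose end precedes its start violates the range check.
try:
    conn.execute(insert_segment, ("job-1", 1, 2048, 1024))
    inverted_rejected = False
except sqlite3.IntegrityError:
    inverted_rejected = True

# Deleting the parent download removes its segments through the cascade.
conn.execute("DELETE FROM downloads WHERE id = 'job-1'")
remaining = conn.execute("SELECT COUNT(*) FROM download_segments").fetchone()[0]
print(inverted_rejected, remaining)
```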

@ -306,15 +306,12 @@ async def download_asset_content(request: web.Request) -> web.Response:
             404, "FILE_NOT_FOUND", "Underlying file not found on disk."
         )
 
-    # User-controlled asset content must never render inline in the app origin
-    # (stored XSS via SVG/HTML/XML). Force dangerous types to download and
-    # override any requested inline disposition. Centralised through
-    # folder_paths.is_dangerous_content_type so this can't drift from /view and
-    # /userdata (the previous inline set here omitted image/svg+xml and missed
-    # the charset/casing/+xml-dialect bypasses).
-    if folder_paths.is_dangerous_content_type(content_type):
+    _DANGEROUS_MIME_TYPES = {
+        "text/html", "text/html-sandboxed", "application/xhtml+xml",
+        "text/javascript", "text/css",
+    }
+    if content_type in _DANGEROUS_MIME_TYPES:
         content_type = "application/octet-stream"
         disposition = "attachment"
 
     safe_name = (filename or "").replace("\r", "").replace("\n", "")
     encoded = urllib.parse.quote(safe_name)

@ -4,7 +4,11 @@ import shutil
 from app.logger import log_startup_warning
 from utils.install_util import get_missing_requirements_message
 from filelock import FileLock, Timeout
-from comfy.cli_args import args
+# NOTE: import the module (not `from ... import args`) so we always read the
+# live `args` object. Tests reload `comfy.cli_args`, which replaces the module
+# global; a bound `args` reference would go stale and point at the default
+# database URL instead of the one configured for the test.
+import comfy.cli_args
 
 _DB_AVAILABLE = False
 Session = None
@ -21,6 +25,7 @@ try:
     from app.database.models import Base
     import app.assets.database.models  # noqa: F401 — register models with Base.metadata
+    import app.model_downloader.database.models  # noqa: F401 — register models with Base.metadata
 
     _DB_AVAILABLE = True
 except ImportError as e:
@ -57,13 +62,13 @@ def get_alembic_config():
     config = Config(config_path)
     config.set_main_option("script_location", scripts_path)
-    config.set_main_option("sqlalchemy.url", args.database_url)
+    config.set_main_option("sqlalchemy.url", comfy.cli_args.args.database_url)
     return config
 
 
 def get_db_path():
-    url = args.database_url
+    url = comfy.cli_args.args.database_url
     if url.startswith("sqlite:///"):
         return url.split("///")[1]
     else:
@ -97,7 +102,7 @@ def _is_memory_db(db_url):
 
 
 def init_db():
-    db_url = args.database_url
+    db_url = comfy.cli_args.args.database_url
     logging.debug(f"Database URL: {db_url}")
 
     if _is_memory_db(db_url):
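
The stale-reference problem described in the NOTE comment above can be reproduced with a short sketch. `fake_cli_args` is a hypothetical stand-in for `comfy.cli_args`, and the reload is simulated by rebinding the module global, which is the observable effect the comment attributes to the tests reloading the module.

```python
import types

# Hypothetical stand-in for the comfy.cli_args module.
fake_cli_args = types.ModuleType("fake_cli_args")
fake_cli_args.args = types.SimpleNamespace(database_url="sqlite:///default.db")

# Equivalent of `from fake_cli_args import args`: the name is bound once,
# to whatever object the module global pointed at during import.
args = fake_cli_args.args

# Simulated reload: the module body runs again and rebinds `args` to a new object.
fake_cli_args.args = types.SimpleNamespace(database_url="sqlite:///:memory:")

stale_url = args.database_url               # still the object captured at import time
live_url = fake_cli_args.args.database_url  # looked up through the module on each access
print(stale_url, live_url)
```

Reading `comfy.cli_args.args` at call time, as the diff does, performs the attribute lookup on every access and therefore follows the rebinding.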

@ -0,0 +1,220 @@
"""aiohttp routes for the download manager.

Endpoint surface (all under ``/api/download``), mirroring the response
envelope used by ``app/assets/api/routes.py``:

    POST   /api/download/enqueue
    GET    /api/download
    POST   /api/download/availability
    POST   /api/download/clear
    POST   /api/download/credentials
    GET    /api/download/credentials
    GET    /api/download/credentials/{id}
    DELETE /api/download/credentials/{id}
    GET    /api/download/{id}
    DELETE /api/download/{id}
    POST   /api/download/{id}/pause
    POST   /api/download/{id}/resume
    POST   /api/download/{id}/cancel
    POST   /api/download/{id}/priority

Note on ordering: the static ``credentials`` routes are registered before the
dynamic ``/api/download/{id}`` route so a request to ``.../credentials`` is not
captured as ``id == "credentials"``.
"""

from __future__ import annotations

import json

from aiohttp import web
from pydantic import BaseModel, ValidationError

from app.model_downloader.api import schemas_in, schemas_out
from app.model_downloader.credentials.store import (
    CREDENTIAL_STORE,
    CredentialValidationError,
)
from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError

ROUTES = web.RouteTableDef()


def register_routes(app: web.Application) -> None:
    """Wire the download-manager routes into the running aiohttp app."""
    app.add_routes(ROUTES)


# ----- envelope helpers (same shape as app/assets/api/routes.py) -----


def _error(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
    return web.json_response(
        {"error": {"code": code, "message": message, "details": details or {}}},
        status=status,
    )


def _ok(payload, status: int = 200) -> web.Response:
    return web.json_response(payload, status=status)


async def _parse(request: web.Request, model: type[BaseModel]):
    try:
        raw = await request.json()
    except json.JSONDecodeError:
        return _error(400, "INVALID_JSON", "Request body must be valid JSON.")
    try:
        return model.model_validate(raw)
    except ValidationError as ve:
        return _error(400, "INVALID_BODY", "Validation failed.", {"errors": json.loads(ve.json())})


def _from_download_error(e: DownloadError) -> web.Response:
    return _error(e.http_status, e.code, e.message)


# ----- downloads: collection + enqueue + availability -----


@ROUTES.post("/api/download/enqueue")
async def enqueue(request: web.Request) -> web.Response:
    parsed = await _parse(request, schemas_in.EnqueueRequest)
    if isinstance(parsed, web.Response):
        return parsed
    try:
        download_id = await DOWNLOAD_MANAGER.enqueue(
            parsed.url,
            parsed.model_id,
            priority=parsed.priority,
            expected_sha256=parsed.expected_sha256,
            allow_any_extension=parsed.allow_any_extension,
            credential_id=parsed.credential_id,
        )
    except DownloadError as e:
        return _from_download_error(e)
    return _ok({"download_id": download_id, "accepted": True}, status=202)


@ROUTES.get("/api/download")
async def list_downloads(request: web.Request) -> web.Response:
    return _ok({"downloads": await DOWNLOAD_MANAGER.list()})


@ROUTES.post("/api/download/availability")
async def availability(request: web.Request) -> web.Response:
    parsed = await _parse(request, schemas_in.AvailabilityRequest)
    if isinstance(parsed, web.Response):
        return parsed
    return _ok({"models": await DOWNLOAD_MANAGER.availability(parsed.models)})


@ROUTES.post("/api/download/clear")
async def clear(request: web.Request) -> web.Response:
    deleted = await DOWNLOAD_MANAGER.clear()
    return _ok({"deleted": deleted})


# ----- credentials (secrets are write-only) — must precede /{id} -----


@ROUTES.post("/api/download/credentials")
async def upsert_credential(request: web.Request) -> web.Response:
    parsed = await _parse(request, schemas_in.CredentialUpsertRequest)
    if isinstance(parsed, web.Response):
        return parsed
    try:
        view = await CREDENTIAL_STORE.upsert(
            parsed.host,
            parsed.secret,
            auth_scheme=parsed.auth_scheme,
            header_name=parsed.header_name,
            query_param=parsed.query_param,
            label=parsed.label,
            match_subdomains=parsed.match_subdomains,
            enabled=parsed.enabled,
        )
    except CredentialValidationError as e:
        return _error(400, "INVALID_CREDENTIAL", str(e))
    return _ok(schemas_out.credential_to_dict(view), status=201)


@ROUTES.get("/api/download/credentials")
async def list_credentials(request: web.Request) -> web.Response:
    views = await CREDENTIAL_STORE.list()
    return _ok({"credentials": [schemas_out.credential_to_dict(v) for v in views]})


@ROUTES.get("/api/download/credentials/{id}")
async def get_credential(request: web.Request) -> web.Response:
    view = await CREDENTIAL_STORE.get(request.match_info["id"])
    if view is None:
        return _error(404, "NOT_FOUND", "No such credential.")
    return _ok(schemas_out.credential_to_dict(view))


@ROUTES.delete("/api/download/credentials/{id}")
async def delete_credential(request: web.Request) -> web.Response:
    deleted = await CREDENTIAL_STORE.delete(request.match_info["id"])
    if not deleted:
        return _error(404, "NOT_FOUND", "No such credential.")
    return _ok({"deleted": True})


# ----- single download by id (dynamic; registered last) -----

@ROUTES.get("/api/download/{id}")
async def get_download(request: web.Request) -> web.Response:
    view = await DOWNLOAD_MANAGER.status(request.match_info["id"])
    if view is None:
        return _error(404, "NOT_FOUND", "No such download.")
    return _ok(view)


@ROUTES.delete("/api/download/{id}")
async def delete_download(request: web.Request) -> web.Response:
    try:
        await DOWNLOAD_MANAGER.delete(request.match_info["id"])
    except DownloadError as e:
        return _from_download_error(e)
    return _ok({"deleted": True})


@ROUTES.post("/api/download/{id}/pause")
async def pause(request: web.Request) -> web.Response:
    try:
        await DOWNLOAD_MANAGER.pause(request.match_info["id"])
    except DownloadError as e:
        return _from_download_error(e)
    return _ok({"ok": True})


@ROUTES.post("/api/download/{id}/resume")
async def resume(request: web.Request) -> web.Response:
    try:
        await DOWNLOAD_MANAGER.resume(request.match_info["id"])
    except DownloadError as e:
        return _from_download_error(e)
    return _ok({"ok": True})


@ROUTES.post("/api/download/{id}/cancel")
async def cancel(request: web.Request) -> web.Response:
    try:
        await DOWNLOAD_MANAGER.cancel(request.match_info["id"])
    except DownloadError as e:
        return _from_download_error(e)
    return _ok({"ok": True})


@ROUTES.post("/api/download/{id}/priority")
async def set_priority(request: web.Request) -> web.Response:
    parsed = await _parse(request, schemas_in.PriorityRequest)
    if isinstance(parsed, web.Response):
        return parsed
    try:
        await DOWNLOAD_MANAGER.set_priority(request.match_info["id"], parsed.priority)
    except DownloadError as e:
        return _from_download_error(e)
    return _ok({"ok": True})
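
For orientation, a client call to the enqueue route could be assembled roughly as follows. This is only a sketch using the standard library: the base URL, the placeholder model URL, and the helper name `build_enqueue_request` are assumptions and not part of this change, and the request is built but never sent.

```python
import json
import urllib.request

# Assumed base URL; adjust to wherever the server is actually listening.
BASE_URL = "http://127.0.0.1:8188"


def build_enqueue_request(url: str, model_id: str, priority: int = 0) -> urllib.request.Request:
    """Build (but do not send) a POST for /api/download/enqueue (hypothetical helper)."""
    body = {"url": url, "model_id": model_id, "priority": priority}
    return urllib.request.Request(
        f"{BASE_URL}/api/download/enqueue",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


req = build_enqueue_request(
    "https://example.com/models/demo.safetensors",  # placeholder URL
    "checkpoints/demo.safetensors",
)
print(req.get_method(), req.full_url)
```

Per the `enqueue` handler above, a successful call answers with status 202 and a `download_id`; the exact response envelope depends on `_ok`, which is not shown in this part of the diff.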


@ -0,0 +1,51 @@
"""Request schemas for the download manager API.

Pydantic enforces shape at the boundary; handlers operate only on validated
values past that point.
"""

from __future__ import annotations

from typing import Optional

from pydantic import BaseModel, Field

from app.model_downloader.constants import AUTH_SCHEME_BEARER


class EnqueueRequest(BaseModel):
    url: str
    model_id: str
    priority: int = 0
    expected_sha256: Optional[str] = None
    allow_any_extension: bool = False
    credential_id: Optional[str] = None


class PriorityRequest(BaseModel):
    priority: int


class AvailabilityRequest(BaseModel):
    """``{model_id: url}``: the URLs declared in the workflow JSON."""

    models: dict[str, str] = Field(default_factory=dict)


class CredentialUpsertRequest(BaseModel):
    host: str
    secret: str
    auth_scheme: str = AUTH_SCHEME_BEARER
    header_name: Optional[str] = None
    query_param: Optional[str] = None
    label: Optional[str] = None
    match_subdomains: bool = False
    enabled: bool = True


__all__ = [
    "EnqueueRequest",
    "PriorityRequest",
    "AvailabilityRequest",
    "CredentialUpsertRequest",
]


@ -0,0 +1,26 @@
"""Response helpers for the download manager API.

The download/status read models are plain dicts produced by the manager. This
module only needs to mask credentials for output (the secret is never returned).
"""

from __future__ import annotations

from app.model_downloader.credentials.store import CredentialView


def credential_to_dict(view: CredentialView) -> dict:
    """API-safe credential representation; never includes the secret."""
    return {
        "id": view.id,
        "host": view.host,
        "auth_scheme": view.auth_scheme,
        "header_name": view.header_name,
        "query_param": view.query_param,
        "label": view.label,
        "match_subdomains": view.match_subdomains,
        "enabled": view.enabled,
        "secret_last4": view.secret_last4,
        "created_at": view.created_at,
        "updated_at": view.updated_at,
    }


@ -0,0 +1,47 @@
"""Shared constants for the download manager.

Status values are persisted as TEXT in the ``downloads`` table; keep them
stable. The lifecycle is:

    queued -> active -> verifying -> completed
       |       |-> paused -> (resume) -> active
       |       |-> failed (network, retryable) -> queued (backoff)
       |-> cancelled
"""

from __future__ import annotations

# Auth schemes for HostCredential
AUTH_SCHEME_BEARER = "bearer"
AUTH_SCHEME_HEADER = "header"
AUTH_SCHEME_QUERY = "query"
AUTH_SCHEMES = (AUTH_SCHEME_BEARER, AUTH_SCHEME_HEADER, AUTH_SCHEME_QUERY)

# Hosts for which a bearer token can be sourced from the environment when no
# stored credential matches. Values are the env var names to try, in order.
# Only consulted during auto-resolve for an exact host match over https, so the
# same per-hop boundary rules apply (e.g. the token is dropped on a redirect to
# a CDN host). Kept here so the host->env-var mapping lives in one place.
ENV_TOKEN_HOSTS = {
    "huggingface.co": ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"),
}


class DownloadStatus:
    QUEUED = "queued"
    ACTIVE = "active"
    PAUSED = "paused"
    VERIFYING = "verifying"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"

    #: States in which a worker is doing (or is about to do) network I/O.
    LIVE = (QUEUED, ACTIVE, VERIFYING)
    #: Terminal states: the job will not transition again on its own.
    TERMINAL = (COMPLETED, FAILED, CANCELLED)


# Default temp-file suffix. Distinctive so the startup orphan sweep only
# removes files THIS subsystem created, never unrelated *.tmp files.
TMP_SUFFIX = ".comfy-download.part"
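
The lifecycle in the module docstring can be read as a small transition table. The sketch below is one such reading, not code from this change: the `ALLOWED` mapping and `can_transition` are hypothetical names, and the real manager may permit transitions the diagram does not spell out.

```python
# Status strings copied from DownloadStatus above.
QUEUED, ACTIVE, PAUSED = "queued", "active", "paused"
VERIFYING, COMPLETED, FAILED, CANCELLED = "verifying", "completed", "failed", "cancelled"

# Hypothetical transition table: one reading of the docstring diagram.
ALLOWED = {
    QUEUED: {ACTIVE, CANCELLED},
    ACTIVE: {VERIFYING, PAUSED, FAILED},
    PAUSED: {ACTIVE},       # resume
    VERIFYING: {COMPLETED},
    FAILED: {QUEUED},       # retryable network failure, after backoff
    COMPLETED: set(),
    CANCELLED: set(),
}


def can_transition(current: str, target: str) -> bool:
    """Return True if the diagram allows moving from current to target."""
    return target in ALLOWED.get(current, set())


print(can_transition(QUEUED, ACTIVE))
print(can_transition(COMPLETED, QUEUED))
```

Keeping the table as data makes it easy to assert in tests that no code path leaves a completed or cancelled job.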


@ -0,0 +1,111 @@
"""Turn a stored credential into a per-hop request modifier (PRD section 9.4.2).

The critical rule: a credential is only ever attached when *the current hop's
host* matches a stored credential, and only over https. This is recomputed
from scratch on every redirect hop, so a token bound to ``huggingface.co`` is
silently dropped when the request is redirected to a presigned CDN host,
which is exactly what these hubs expect.
"""

from __future__ import annotations

import asyncio
import os
from dataclasses import dataclass, field
from typing import Optional
from urllib.parse import urlencode, urlsplit, urlunsplit

from app.model_downloader.constants import (
    AUTH_SCHEME_BEARER,
    AUTH_SCHEME_HEADER,
    AUTH_SCHEME_QUERY,
    ENV_TOKEN_HOSTS,
)
from app.model_downloader.credentials.store import normalize_host
from app.model_downloader.database import queries
from app.model_downloader.database.models import HostCredential


@dataclass
class RequestAuth:
    """How to modify a single request to carry a credential."""

    headers: dict[str, str] = field(default_factory=dict)
    query: dict[str, str] = field(default_factory=dict)

    def apply_to_url(self, url: str) -> str:
        if not self.query:
            return url
        parts = urlsplit(url)
        # Append only the credential params, leaving the original query string
        # (including any repeated keys and existing encoding) untouched.
        creds = urlencode(self.query)
        query = f"{parts.query}&{creds}" if parts.query else creds
        return urlunsplit(parts._replace(query=query))


def _matches(cred: HostCredential, hop_host: str) -> bool:
    cred_host = cred.host
    if hop_host == cred_host:
        return True
    if cred.match_subdomains:
        # Label-boundary suffix: api.example.com matches example.com, but
        # evil-example.com does NOT.
        return hop_host.endswith("." + cred_host)
    return False


def _build_auth(cred: HostCredential) -> RequestAuth:
    if cred.auth_scheme == AUTH_SCHEME_BEARER:
        return RequestAuth(headers={"Authorization": f"Bearer {cred.secret}"})
    if cred.auth_scheme == AUTH_SCHEME_HEADER:
        name = cred.header_name or "Authorization"
        return RequestAuth(headers={name: cred.secret})
    if cred.auth_scheme == AUTH_SCHEME_QUERY and cred.query_param:
        return RequestAuth(query={cred.query_param: cred.secret})
    return RequestAuth()


def _resolve_sync(
    host: str, scheme: str, explicit_credential_id: Optional[str]
) -> Optional[RequestAuth]:
    # Never attach a secret over a non-https hop (PRD section 9.4.2).
    if scheme.lower() != "https":
        return None
    hop_host = normalize_host(host)
    if not hop_host:
        return None

    if explicit_credential_id is not None:
        cred = queries.get_credential(explicit_credential_id)
        # An explicit credential is still subject to the per-hop host check;
        # it is not forced onto a non-matching host.
        if cred is None or not cred.enabled or not _matches(cred, hop_host):
            return None
        return _build_auth(cred)

    # Auto-resolve: exact host first, then any subdomain-matching credential.
    cred = queries.get_credential_by_host(hop_host)
    if cred is not None and cred.enabled:
        return _build_auth(cred)
    for sub in queries.list_subdomain_credentials():
        if sub.enabled and _matches(sub, hop_host):
            return _build_auth(sub)

    # Env fallback: only for an exact host match, and only after the DB lookups
    # miss, so a user-set credential always takes precedence. The token is never
    # persisted; it is read fresh from the environment on each hop.
    for var in ENV_TOKEN_HOSTS.get(hop_host, ()):
        token = os.environ.get(var)
        if token:
            return RequestAuth(headers={"Authorization": f"Bearer {token}"})
    return None


async def resolve_auth_for_hop(
    host: str, scheme: str, *, explicit_credential_id: Optional[str] = None
) -> Optional[RequestAuth]:
    """Resolve the credential (if any) to attach for one request hop."""
    return await asyncio.to_thread(
        _resolve_sync, host, scheme, explicit_credential_id
    )
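
The two rules this module leans on, label-boundary host matching and appending a credential to the query string without re-encoding it, can be tried in isolation. The following is a standalone sketch with no database access; `host_matches` and `append_query` are illustrative names that mirror `_matches` and `RequestAuth.apply_to_url` rather than the functions themselves.

```python
from urllib.parse import urlencode, urlsplit, urlunsplit


def host_matches(cred_host: str, hop_host: str, match_subdomains: bool) -> bool:
    """Exact match, or a label-boundary suffix match when opted in."""
    if hop_host == cred_host:
        return True
    return match_subdomains and hop_host.endswith("." + cred_host)


def append_query(url: str, params: dict[str, str]) -> str:
    """Append params while leaving the existing query string untouched."""
    if not params:
        return url
    parts = urlsplit(url)
    extra = urlencode(params)
    query = f"{parts.query}&{extra}" if parts.query else extra
    return urlunsplit(parts._replace(query=query))


print(host_matches("example.com", "api.example.com", True))   # subdomain, opted in
print(host_matches("example.com", "evil-example.com", True))  # not a label boundary
print(append_query("https://example.com/f?a=1&a=2", {"token": "s3cret"}))
```

Requiring the leading dot in the suffix check is what keeps a look-alike host such as `evil-example.com` from receiving a token stored for `example.com`.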


@ -0,0 +1,141 @@
"""The credential store: one API key per host.

Secrets are write-only over the API: :class:`CredentialView` carries only
masked metadata (``secret_last4`` + scheme + label), never the secret itself.
At-rest protection for v1 is filesystem permissions on the shared DB (the DB
is the trust boundary); encryption-at-rest is a noted future seam.
"""

from __future__ import annotations

import asyncio
from dataclasses import dataclass
from typing import Optional
from urllib.parse import urlsplit

from app.model_downloader.constants import (
    AUTH_SCHEME_BEARER,
    AUTH_SCHEME_HEADER,
    AUTH_SCHEME_QUERY,
    AUTH_SCHEMES,
)
from app.model_downloader.database import queries
from app.model_downloader.database.models import HostCredential


def normalize_host(host: str) -> str:
    """Lowercase, strip port, IDNA-encode."""
    if not host:
        return ""
    host = host.strip()
    if "://" in host:  # a full URL was pasted; extract just the host
        host = urlsplit(host).hostname or ""
    host = host.lower()
    if host.startswith("[") and "]" in host:  # bracketed IPv6 literal
        host = host[1 : host.index("]")]
    elif host.count(":") == 1:  # host:port (not IPv6)
        host = host.split(":", 1)[0]
    try:
        host = host.encode("idna").decode("ascii")
    except (UnicodeError, ValueError):
        pass
    return host


@dataclass(frozen=True)
class CredentialView:
    """Masked, API-safe view of a credential; never includes the secret."""

    id: str
    host: str
    auth_scheme: str
    header_name: Optional[str]
    query_param: Optional[str]
    label: Optional[str]
    match_subdomains: bool
    enabled: bool
    secret_last4: Optional[str]
    created_at: int
    updated_at: int


def _to_view(row: HostCredential) -> CredentialView:
    return CredentialView(
        id=row.id,
        host=row.host,
        auth_scheme=row.auth_scheme,
        header_name=row.header_name,
        query_param=row.query_param,
        label=row.label,
        match_subdomains=row.match_subdomains,
        enabled=row.enabled,
        secret_last4=row.secret_last4,
        created_at=row.created_at,
        updated_at=row.updated_at,
    )


class CredentialValidationError(ValueError):
    """A credential upsert had inconsistent fields."""


class CredentialStore:
    """Async facade over the ``host_credentials`` table.

    DB access is synchronous (SQLite) and offloaded via ``asyncio.to_thread``.
    """

    async def upsert(
        self,
        host: str,
        secret: str,
        *,
        auth_scheme: str = AUTH_SCHEME_BEARER,
        header_name: Optional[str] = None,
        query_param: Optional[str] = None,
        label: Optional[str] = None,
        match_subdomains: bool = False,
        enabled: bool = True,
    ) -> CredentialView:
        host = normalize_host(host)
        if not host:
            raise CredentialValidationError("host is required")
        if not secret:
            raise CredentialValidationError("secret is required")
        if auth_scheme not in AUTH_SCHEMES:
            raise CredentialValidationError(
                f"auth_scheme must be one of {AUTH_SCHEMES}, got {auth_scheme!r}"
            )
        if auth_scheme == AUTH_SCHEME_HEADER and not header_name:
            header_name = "Authorization"
        if auth_scheme == AUTH_SCHEME_QUERY and not query_param:
            raise CredentialValidationError(
                "query_param is required when auth_scheme='query'"
            )

        values = {
            "host": host,
            "secret": secret,
            "secret_last4": secret[-4:] if len(secret) > 4 else None,
            "auth_scheme": auth_scheme,
            "header_name": header_name,
            "query_param": query_param,
            "label": label,
            "match_subdomains": match_subdomains,
            "enabled": enabled,
        }
        row = await asyncio.to_thread(queries.upsert_credential, values)
        return _to_view(row)

    async def list(self) -> list[CredentialView]:
        rows = await asyncio.to_thread(queries.list_credentials)
        return [_to_view(r) for r in rows]

    async def get(self, credential_id: str) -> Optional[CredentialView]:
        row = await asyncio.to_thread(queries.get_credential, credential_id)
        return _to_view(row) if row is not None else None

    async def delete(self, credential_id: str) -> bool:
        return await asyncio.to_thread(queries.delete_credential, credential_id)


CREDENTIAL_STORE = CredentialStore()
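
To see what the host normalization and secret masking do to typical inputs, the two rules can be run on their own. The block below is a sketch: `normalize_host` is a standalone copy of the helper above, `mask_last4` is a hypothetical name for the inline `secret_last4` expression in `upsert`, and the sample hosts and token are made up.

```python
from typing import Optional
from urllib.parse import urlsplit


def normalize_host(host: str) -> str:
    """Standalone copy of the helper above, for experimentation."""
    if not host:
        return ""
    host = host.strip()
    if "://" in host:  # a full URL was pasted
        host = urlsplit(host).hostname or ""
    host = host.lower()
    if host.startswith("[") and "]" in host:  # bracketed IPv6 literal
        host = host[1 : host.index("]")]
    elif host.count(":") == 1:  # host:port
        host = host.split(":", 1)[0]
    try:
        host = host.encode("idna").decode("ascii")
    except (UnicodeError, ValueError):
        pass
    return host


def mask_last4(secret: str) -> Optional[str]:
    """Mirror of the secret_last4 rule: nothing is exposed for short secrets."""
    return secret[-4:] if len(secret) > 4 else None


print(normalize_host("https://HuggingFace.co:443/some/path"))
print(normalize_host("civitai.com:8443"))
print(normalize_host("[::1]:8080"))
print(mask_last4("hf_abcdefgh"), mask_last4("abcd"))
```

Returning `None` for secrets of four characters or fewer avoids echoing an entire short secret back through the masked view.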


@ -0,0 +1,173 @@
"""SQLAlchemy models for the download manager.

Three tables:

- ``downloads``          one row per requested file (job + queue state).
- ``download_segments``  per-segment byte progress, for segmented resume.
- ``host_credentials``   one API key per host, reused across downloads.

On completion a finished file is registered into the assets catalog;
``downloads`` is kept only as job history.
"""

from __future__ import annotations

import time
import uuid

from sqlalchemy import (
    BigInteger,
    Boolean,
    CheckConstraint,
    ForeignKey,
    Index,
    Integer,
    String,
    Text,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship

from app.database.models import Base


def _uuid() -> str:
    return str(uuid.uuid4())


def _now() -> int:
    return int(time.time())


class Download(Base):
    __tablename__ = "downloads"

    id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)

    # Original requested URL and the final URL after validated redirects.
    url: Mapped[str] = mapped_column(Text, nullable=False)
    final_url: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Canonical "<directory>/<filename>" identifier (resolved via folder_paths).
    model_id: Mapped[str] = mapped_column(String(1024), nullable=False)

    # Final on-disk location and the .part write target.
    dest_path: Mapped[str] = mapped_column(Text, nullable=False)
    temp_path: Mapped[str] = mapped_column(Text, nullable=False)

    status: Mapped[str] = mapped_column(String(16), nullable=False)
    priority: Mapped[int] = mapped_column(Integer, nullable=False, default=0)

    total_bytes: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
    bytes_done: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)

    etag: Mapped[str | None] = mapped_column(String(512), nullable=True)
    last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True)
    accept_ranges: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)

    # Optional hub-provided checksum to verify against (NOT the dedup key).
    expected_sha256: Mapped[str | None] = mapped_column(String(64), nullable=True)

    # Explicit credential override; otherwise auto-resolved by host.
    # RESTRICT keeps a credential from being deleted while a download references it.
    credential_id: Mapped[str | None] = mapped_column(
        String(36),
        ForeignKey("host_credentials.id", ondelete="RESTRICT"),
        nullable=True,
    )

    allow_any_extension: Mapped[bool] = mapped_column(
        Boolean, nullable=False, default=False
    )

    # How many retryable failures we have seen (for backoff capping).
    attempts: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    error: Mapped[str | None] = mapped_column(Text, nullable=True)

    created_at: Mapped[int] = mapped_column(BigInteger, nullable=False, default=_now)
    updated_at: Mapped[int] = mapped_column(
        BigInteger, nullable=False, default=_now, onupdate=_now
    )

    segments: Mapped[list[DownloadSegment]] = relationship(
        "DownloadSegment",
        back_populates="download",
        cascade="all,delete-orphan",
        passive_deletes=True,
        order_by="DownloadSegment.idx",
    )
    credential: Mapped[HostCredential | None] = relationship(
        "HostCredential", back_populates="downloads"
    )

    __table_args__ = (
        Index("ix_downloads_status", "status"),
        Index("ix_downloads_priority", "priority"),
        Index("ix_downloads_model_id", "model_id"),
        CheckConstraint("bytes_done >= 0", name="ck_downloads_bytes_done_nonneg"),
        CheckConstraint(
            "total_bytes IS NULL OR total_bytes >= 0",
            name="ck_downloads_total_bytes_nonneg",
        ),
    )

    def __repr__(self) -> str:
        return f"<Download id={self.id} model_id={self.model_id!r} status={self.status}>"


class DownloadSegment(Base):
    __tablename__ = "download_segments"

    download_id: Mapped[str] = mapped_column(
        String(36),
        ForeignKey("downloads.id", ondelete="CASCADE"),
        primary_key=True,
    )
    idx: Mapped[int] = mapped_column(Integer, primary_key=True)

    start_offset: Mapped[int] = mapped_column(BigInteger, nullable=False)
    end_offset: Mapped[int] = mapped_column(BigInteger, nullable=False)
    bytes_done: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)

    download: Mapped[Download] = relationship("Download", back_populates="segments")

    __table_args__ = (
        CheckConstraint("bytes_done >= 0", name="ck_segments_bytes_done_nonneg"),
        CheckConstraint("end_offset >= start_offset", name="ck_segments_range"),
    )

    def __repr__(self) -> str:
        return (
            f"<DownloadSegment {self.download_id}#{self.idx} "
            f"{self.start_offset}-{self.end_offset} done={self.bytes_done}>"
        )


class HostCredential(Base):
    __tablename__ = "host_credentials"

    id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)

    # Normalized lowercase hostname, e.g. "civitai.com".
    host: Mapped[str] = mapped_column(String(255), nullable=False)
    match_subdomains: Mapped[bool] = mapped_column(
        Boolean, nullable=False, default=False
    )
    label: Mapped[str | None] = mapped_column(String(255), nullable=True)

    auth_scheme: Mapped[str] = mapped_column(
        String(16), nullable=False, default="bearer"
    )
    header_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
    query_param: Mapped[str | None] = mapped_column(String(255), nullable=True)

    # The API key itself. Write-only over the API; never returned. See PRD 9.4.4.
    secret: Mapped[str] = mapped_column(Text, nullable=False)
    secret_last4: Mapped[str | None] = mapped_column(String(4), nullable=True)

    enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)

    created_at: Mapped[int] = mapped_column(BigInteger, nullable=False, default=_now)
    updated_at: Mapped[int] = mapped_column(
        BigInteger, nullable=False, default=_now, onupdate=_now
    )

    downloads: Mapped[list[Download]] = relationship(
        "Download", back_populates="credential"
    )

    __table_args__ = (
        Index("uq_host_credentials_host", "host", unique=True),
    )

    def __repr__(self) -> str:
        return f"<HostCredential id={self.id} host={self.host!r} scheme={self.auth_scheme}>"
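
The two foreign-key policies used above (`RESTRICT` on `downloads.credential_id`, `CASCADE` on `download_segments.download_id`) can be observed at the database level with plain `sqlite3`. This is a reduced sketch with hand-written tables that keep only the key columns; it shows SQLite's behavior for raw SQL statements and says nothing about how the ORM relationships issue their own deletes.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("PRAGMA foreign_keys = ON")  # SQLite does not enforce FKs unless asked
conn.executescript(
    """
    CREATE TABLE host_credentials (id TEXT PRIMARY KEY);
    CREATE TABLE downloads (
        id TEXT PRIMARY KEY,
        credential_id TEXT REFERENCES host_credentials(id) ON DELETE RESTRICT
    );
    CREATE TABLE download_segments (
        download_id TEXT REFERENCES downloads(id) ON DELETE CASCADE,
        idx INTEGER,
        PRIMARY KEY (download_id, idx)
    );
    INSERT INTO host_credentials VALUES ('c1');
    INSERT INTO downloads VALUES ('d1', 'c1');
    INSERT INTO download_segments VALUES ('d1', 0), ('d1', 1);
    """
)

# RESTRICT: the credential cannot be deleted while a download references it.
try:
    conn.execute("DELETE FROM host_credentials WHERE id = 'c1'")
    restricted = False
except sqlite3.IntegrityError:
    restricted = True

# CASCADE: deleting the download removes its segment rows as well.
conn.execute("DELETE FROM downloads WHERE id = 'd1'")
remaining = conn.execute("SELECT COUNT(*) FROM download_segments").fetchone()[0]
print(restricted, remaining)
```

Both behaviors depend on `PRAGMA foreign_keys` being enabled on the connection; without it SQLite accepts the declarations but does not enforce them.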


@ -0,0 +1,272 @@
"""Synchronous DB access for the download manager.

All functions open their own short-lived session via ``create_session`` and
commit before returning, mirroring ``app/assets`` usage. They are blocking
(SQLite) and should be called from async code through ``asyncio.to_thread``.
"""

from __future__ import annotations

import time
from typing import Optional

from sqlalchemy import delete, select
from sqlalchemy.exc import IntegrityError

from app.database.db import create_session
from app.model_downloader.constants import DownloadStatus
from app.model_downloader.database.models import (
    Download,
    DownloadSegment,
    HostCredential,
)


# ----- downloads -----

def insert_download(values: dict) -> None:
    with create_session() as session:
        session.add(Download(**values))
        session.commit()


def get_download(download_id: str) -> Optional[Download]:
    with create_session() as session:
        row = session.get(Download, download_id)
        if row is not None:
            session.expunge_all()
        return row


def list_downloads() -> list[Download]:
    with create_session() as session:
        rows = list(
            session.execute(
                select(Download).order_by(Download.created_at.desc())
            ).scalars()
        )
        session.expunge_all()
        return rows


def list_segments(download_id: str) -> list[DownloadSegment]:
    with create_session() as session:
        rows = list(
            session.execute(
                select(DownloadSegment)
                .where(DownloadSegment.download_id == download_id)
                .order_by(DownloadSegment.idx)
            ).scalars()
        )
        session.expunge_all()
        return rows


def update_download(download_id: str, **fields) -> None:
    if not fields:
        return
    fields.setdefault("updated_at", int(time.time()))
    with create_session() as session:
        row = session.get(Download, download_id)
        if row is None:
            return
        for key, value in fields.items():
            setattr(row, key, value)
        session.commit()


def delete_download(download_id: str) -> None:
    with create_session() as session:
        row = session.get(Download, download_id)
        if row is not None:
            session.delete(row)
            session.commit()


def delete_downloads(download_ids: list[str]) -> int:
    """Delete many downloads in one transaction; returns the number removed.

    Uses a bulk ``DELETE ... WHERE id IN (...)``. Segment rows are removed by
    the ``ON DELETE CASCADE`` foreign key (SQLite ``PRAGMA foreign_keys=ON`` is
    set in ``app/database/db.py``), so this stays consistent without loading the
    ORM relationship.
    """
    if not download_ids:
        return 0
    with create_session() as session:
        result = session.execute(
            delete(Download).where(Download.id.in_(download_ids))
        )
        session.commit()
        return result.rowcount or 0


def replace_segments(download_id: str, segments: list[dict]) -> None:
    """Atomically replace the segment plan for a download."""
    with create_session() as session:
        session.query(DownloadSegment).filter(
            DownloadSegment.download_id == download_id
        ).delete()
        for seg in segments:
            session.add(DownloadSegment(download_id=download_id, **seg))
        session.commit()


def update_segment_progress(download_id: str, idx: int, bytes_done: int) -> None:
    with create_session() as session:
        row = session.get(DownloadSegment, {"download_id": download_id, "idx": idx})
        if row is None:
            return
        row.bytes_done = bytes_done
        session.commit()


def list_queued_downloads() -> list[Download]:
    """Queued rows ordered for admission (priority desc, then FIFO)."""
    with create_session() as session:
        rows = list(
            session.execute(
                select(Download)
                .where(Download.status == DownloadStatus.QUEUED)
                .order_by(Download.priority.desc(), Download.created_at.asc())
            ).scalars()
        )
        session.expunge_all()
        return rows


def reconcile_live_downloads() -> list[Download]:
    """Reset any ``active``/``verifying`` rows left by a previous run.

    After a restart there can be no live worker, so anything still marked
    live is stale. Move it back to ``queued`` (offsets are preserved on the
    segment rows) so the scheduler re-admits it. Returns every row that is now
    ``queued`` (including the ones just reset), ordered for admission; paused
    rows are left untouched and are not returned.
    """
    with create_session() as session:
        stale = list(
            session.execute(
                select(Download).where(
                    Download.status.in_([DownloadStatus.ACTIVE, DownloadStatus.VERIFYING])
                )
            ).scalars()
        )
        now = int(time.time())
        for row in stale:
            row.status = DownloadStatus.QUEUED
            row.updated_at = now
        session.commit()
        resumable = list(
            session.execute(
                select(Download)
                .where(Download.status == DownloadStatus.QUEUED)
                .order_by(Download.priority.desc(), Download.created_at.asc())
            ).scalars()
        )
        session.expunge_all()
        return resumable


# ----- host credentials -----

def get_credential(credential_id: str) -> Optional[HostCredential]:
    with create_session() as session:
        row = session.get(HostCredential, credential_id)
        if row is not None:
            session.expunge_all()
        return row


def get_credential_by_host(host: str) -> Optional[HostCredential]:
    with create_session() as session:
        row = (
            session.execute(
                select(HostCredential).where(HostCredential.host == host).limit(1)
            )
            .scalars()
            .first()
        )
        if row is not None:
            session.expunge_all()
        return row


def list_credentials() -> list[HostCredential]:
    with create_session() as session:
        rows = list(
            session.execute(
                select(HostCredential).order_by(HostCredential.host)
            ).scalars()
        )
        session.expunge_all()
        return rows


def list_subdomain_credentials() -> list[HostCredential]:
    """Credentials that opted into subdomain matching, for suffix checks."""
    with create_session() as session:
        rows = list(
            session.execute(
                select(HostCredential).where(HostCredential.match_subdomains.is_(True))
            ).scalars()
        )
        session.expunge_all()
        return rows


def upsert_credential(values: dict) -> HostCredential:
    """Insert or update a credential keyed by ``host``.

    Callers can target the same host concurrently (each runs in its own
    short-lived session on a separate connection), so the read-then-write here
    can race: two callers both see no existing row and both attempt an insert.
    The ``host`` column is uniquely indexed, so the loser's insert raises
    ``IntegrityError``. We recover by rolling back and retrying, at which point
    the now-committed row is found and updated in place, letting concurrent
    calls converge instead of failing or creating duplicates.
    """
    host = values["host"]
    now = int(time.time())
    last_error: IntegrityError | None = None
    for _ in range(2):
        with create_session() as session:
            row = (
                session.execute(
                    select(HostCredential).where(HostCredential.host == host).limit(1)
                )
                .scalars()
                .first()
            )
            if row is None:
                row = HostCredential(**values)
                row.created_at = now
                row.updated_at = now
                session.add(row)
            else:
                for key, value in values.items():
                    setattr(row, key, value)
                row.updated_at = now
            try:
                session.commit()
            except IntegrityError as exc:
                session.rollback()
                last_error = exc
                continue
            session.refresh(row)
            session.expunge(row)
            return row
    assert last_error is not None
    raise last_error


def delete_credential(credential_id: str) -> bool:
    with create_session() as session:
        row = session.get(HostCredential, credential_id)
        if row is None:
            return False
        session.delete(row)
        session.commit()
        return True
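
The retry-on-`IntegrityError` pattern described in the `upsert_credential` docstring can be reduced to a few lines of plain `sqlite3`. The sketch below is an illustration under simplified assumptions, not the code from this change: the `creds` table, its columns, and `upsert_by_host` are invented for the example, and it runs on a single connection, so it exercises the insert and update paths but not an actual concurrent race.

```python
import sqlite3
import time


def upsert_by_host(conn: sqlite3.Connection, host: str, secret: str, retries: int = 2) -> None:
    """Read-then-write upsert that retries if a unique-constraint race is lost."""
    last_error = None
    for _ in range(retries):
        now = int(time.time())
        row = conn.execute("SELECT id FROM creds WHERE host = ?", (host,)).fetchone()
        try:
            if row is None:
                conn.execute(
                    "INSERT INTO creds (host, secret, updated_at) VALUES (?, ?, ?)",
                    (host, secret, now),
                )
            else:
                conn.execute(
                    "UPDATE creds SET secret = ?, updated_at = ? WHERE id = ?",
                    (secret, now, row[0]),
                )
            conn.commit()
            return
        except sqlite3.IntegrityError as exc:
            # Another writer inserted the host first: roll back and retry,
            # which now finds the row and takes the UPDATE path.
            conn.rollback()
            last_error = exc
    raise last_error


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE creds (id INTEGER PRIMARY KEY, host TEXT UNIQUE, secret TEXT, updated_at INTEGER)"
)
upsert_by_host(conn, "example.com", "first")
upsert_by_host(conn, "example.com", "second")
rows = conn.execute("SELECT host, secret FROM creds").fetchall()
print(rows)
```

The unique index is what makes the retry safe: the losing writer fails loudly instead of silently creating a second row for the same host.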


@ -0,0 +1,612 @@
"""The per-download worker.

One :class:`DownloadJob` drives a single file from probe to verified, cataloged
completion. It supports cooperative pause / resume / cancel, segmented
multi-connection transfer with positioned writes, and a verification gate
(size + structural + optional sha256) before the atomic rename into place.

Control is cooperative: external callers flip ``_control`` via
:meth:`request_pause` / :meth:`request_cancel`; segment loops observe it between
chunks and raise, which unwinds cleanly and persists resume offsets.
"""

from __future__ import annotations

import asyncio
import logging
import os
import time
from dataclasses import dataclass, field
from typing import Callable, Optional

from comfy.cli_args import args

from app.model_downloader.constants import DownloadStatus
from app.model_downloader.database import queries
from app.model_downloader.engine.planner import (
    effective_segment_count,
    plan_segments,
)
from app.model_downloader.engine.writer import FileWriter
from app.model_downloader.net.http import open_validated, redact_url
from app.model_downloader.net.probe import gated_error_message, probe
from app.model_downloader.verify import checksum, dedup, structural

_RETRYABLE_STATUSES = {408, 429, 500, 502, 503, 504}
_PERSIST_INTERVAL = 2.0  # seconds between throttled progress persists


class Paused(Exception):
    pass


class Cancelled(Exception):
    pass


class RemoteChanged(Exception):
    """The remote file changed under a resume (got 200 where 206 expected)."""


class RetryableError(Exception):
    pass


class FatalError(Exception):
    """Non-retryable: 4xx, checksum mismatch, structural failure, gated, etc."""


@dataclass
class SegmentRuntime:
    idx: int
    start: int
    end: int  # inclusive; may be -1 for unknown-size single stream
    bytes_done: int = 0

    @property
    def length(self) -> int:
        return self.end - self.start + 1


@dataclass
class RuntimeState:
    download_id: str
    model_id: str
    url: str
    priority: int
    status: str
    total_bytes: Optional[int] = None
    bytes_done: int = 0
    error: Optional[str] = None
    segments: list[SegmentRuntime] = field(default_factory=list)
    started_at: float = field(default_factory=time.monotonic)
    _last_bytes: int = 0
    _last_time: float = field(default_factory=time.monotonic)
    speed_bps: float = 0.0

    @property
    def progress(self) -> Optional[float]:
        if not self.total_bytes:
            return None
        return min(1.0, self.bytes_done / self.total_bytes)

    @property
    def eta_seconds(self) -> Optional[float]:
        if not self.total_bytes or self.speed_bps <= 0:
            return None
        remaining = max(0, self.total_bytes - self.bytes_done)
        return remaining / self.speed_bps


@dataclass
class JobSpec:
    download_id: str
    url: str
    model_id: str
    dest_path: str
    temp_path: str
    priority: int = 0
    credential_id: Optional[str] = None
    expected_sha256: Optional[str] = None
    allow_any_extension: bool = False
    etag: Optional[str] = None
    attempts: int = 0


class DownloadJob:
    def __init__(
        self, spec: JobSpec, notify_cb: Optional[Callable[[str], None]] = None
    ) -> None:
        self.spec = spec
        self._notify = notify_cb
        self._control = "run"  # run | pause | cancel
        self.state = RuntimeState(
            download_id=spec.download_id,
            model_id=spec.model_id,
            url=spec.url,
            priority=spec.priority,
            status=DownloadStatus.QUEUED,
        )
        self._writer: Optional[FileWriter] = None
        self._etag: Optional[str] = spec.etag
        self._last_persist = 0.0

    # ----- external control -----

    def request_pause(self) -> None:
        if self._control == "run":
            self._control = "pause"

    def request_cancel(self) -> None:
        self._control = "cancel"

    def _check_control(self) -> None:
        if self._control == "cancel":
            raise Cancelled()
        if self._control == "pause":
            raise Paused()

    # ----- lifecycle -----

    async def run(self) -> str:
        """Run to a terminal/paused state; returns the final status string."""
        await self._set_status(DownloadStatus.ACTIVE, error=None)
        try:
            pr = await self._probe_and_plan()
            await self._transfer(pr)
            await self._finalize()
            await self._set_status(DownloadStatus.COMPLETED)
        except Paused:
            await self._persist_progress(force=True)
            await self._set_status(DownloadStatus.PAUSED)
        except Cancelled:
            await self._close_writer()
            self._remove_temp()
            await self._set_status(DownloadStatus.CANCELLED)
        except RemoteChanged:
            await self._reset_for_restart()
            await self._set_status(
                DownloadStatus.QUEUED, error="remote file changed; restarting"
            )
        except RetryableError as e:
            await self._persist_progress(force=True)
            await self._set_status(DownloadStatus.QUEUED, error=str(e))
        except FatalError as e:
            await self._close_writer()
            self._remove_temp()
            await self._set_status(DownloadStatus.FAILED, error=str(e))
        except Exception as e:  # unexpected -> treat as retryable
            logging.warning(
                "[model_downloader] %s unexpected error: %s",
                self.spec.model_id, e, exc_info=True,
            )
            await self._persist_progress(force=True)
            await self._set_status(DownloadStatus.QUEUED, error=f"{type(e).__name__}: {e}")
        finally:
            await self._close_writer()
        return self.state.status

    # ----- probe + plan -----

    async def _probe_and_plan(self):
        pr = await probe(self.spec.url, credential_id=self.spec.credential_id)
        if not pr.ok:
            if pr.gated:
                raise FatalError(gated_error_message(self.spec.url, pr))
            if pr.status == 0 or pr.status in _RETRYABLE_STATUSES:
                raise RetryableError(pr.error or "probe failed")
            raise FatalError(pr.error or f"probe returned HTTP {pr.status}")

        max_bytes = self._max_download_bytes()
        if max_bytes is not None and pr.total_bytes is not None and pr.total_bytes > max_bytes:
            raise FatalError(
                f"file size {pr.total_bytes} exceeds the maximum allowed "
                f"download size {max_bytes} (--download-max-bytes)"
            )

        self._etag = pr.etag or self._etag
        self.state.total_bytes = pr.total_bytes
        await asyncio.to_thread(
            queries.update_download,
            self.spec.download_id,
            final_url=pr.final_url,
            total_bytes=pr.total_bytes,
            accept_ranges=pr.accept_ranges,
            etag=pr.etag,
            last_modified=pr.last_modified,
        )

        seg_count = effective_segment_count(
            pr.total_bytes, pr.accept_ranges, max(1, args.download_segments)
        )
        existing = await asyncio.to_thread(queries.list_segments, self.spec.download_id)
        can_resume_segmented = (
            seg_count > 1
            and existing
            and pr.total_bytes is not None
            and existing[-1].end_offset == pr.total_bytes - 1
        )
        if can_resume_segmented and not self._segmented_part_valid(pr.total_bytes):
            # The persisted per-segment offsets describe bytes in a preallocated
            # .part that is now gone or the wrong size (e.g. the partial of a
            # failed download was swept on restart, or removed by a fatal
            # error). Trusting them would skip already-"complete" segments and
            # leave zero-filled holes. Discard the offsets and re-plan fresh.
            logging.info(
                "[model_downloader] %s discarding segmented resume offsets "
                "(preallocated .part missing or wrong size); restarting",
                self.spec.model_id,
            )
            self._remove_temp()
            await asyncio.to_thread(
                queries.replace_segments, self.spec.download_id, []
            )
            await asyncio.to_thread(
                queries.update_download, self.spec.download_id, bytes_done=0
            )
            existing = []
            can_resume_segmented = False

        if can_resume_segmented:
            # Resume an existing segmented plan.
            self.state.segments = [
                SegmentRuntime(s.idx, s.start_offset, s.end_offset, s.bytes_done)
                for s in existing
            ]
        elif seg_count > 1 and pr.total_bytes is not None:
            plans = plan_segments(pr.total_bytes, seg_count)
            await asyncio.to_thread(
                queries.replace_segments,
                self.spec.download_id,
                [
                    {"idx": p.idx, "start_offset": p.start, "end_offset": p.end, "bytes_done": 0}
                    for p in plans
                ],
            )
            self.state.segments = [SegmentRuntime(p.idx, p.start, p.end, 0) for p in plans]
        else:
            # Single-stream: one logical segment; bytes_done tracked on the row.
            row = await asyncio.to_thread(queries.get_download, self.spec.download_id)
            resume_from = row.bytes_done if row else 0
end = (pr.total_bytes - 1) if pr.total_bytes else -1
# ``row.bytes_done`` may be the SUM of per-segment offsets from a
# prior segmented run (a preallocated, non-contiguous .part). A
# single-stream resume writes a contiguous prefix, so the offset is
# only trustworthy when the on-disk file is exactly that many
# contiguous bytes. This guards the case where a download that ran
# segmented now resolves to one segment (server dropped
# Accept-Ranges, or --download-segments was lowered between runs):
# resuming over non-contiguous data would corrupt the output.
if resume_from > 0 and not self._contiguous_prefix_valid(resume_from):
logging.info(
"[model_downloader] %s discarding untrusted resume offset "
"%d (on-disk .part not a contiguous prefix); restarting",
self.spec.model_id, resume_from,
)
resume_from = 0
self._remove_temp()
if await asyncio.to_thread(queries.list_segments, self.spec.download_id):
await asyncio.to_thread(
queries.replace_segments, self.spec.download_id, []
)
await asyncio.to_thread(
queries.update_download, self.spec.download_id, bytes_done=0
)
self.state.segments = [SegmentRuntime(0, 0, end, resume_from)]
self._recompute_bytes_done()
return pr
# ----- transfer -----
async def _transfer(self, pr) -> None:
self._writer = FileWriter(self.spec.temp_path)
await self._writer.open()
segmented = len(self.state.segments) > 1
if segmented and self.state.total_bytes:
await self._writer.preallocate(self.state.total_bytes)
await self._run_segmented()
else:
await self._run_single()
await self._writer.flush()
async def _run_segmented(self) -> None:
pending = [
asyncio.create_task(self._run_segment(seg))
for seg in self.state.segments
if seg.bytes_done < seg.length
]
if not pending:
return
done, not_done = await asyncio.wait(
pending, return_when=asyncio.FIRST_EXCEPTION
)
first_exc: Optional[BaseException] = None
for task in done:
exc = task.exception()
if exc is not None and first_exc is None:
first_exc = exc
if first_exc is not None:
for task in not_done:
task.cancel()
await asyncio.gather(*not_done, return_exceptions=True)
raise first_exc
async def _run_segment(self, seg: SegmentRuntime) -> None:
offset = seg.start + seg.bytes_done
headers = {
"Range": f"bytes={offset}-{seg.end}",
"Accept-Encoding": "identity",
}
if self._etag:
headers["If-Range"] = self._etag
async with open_validated(
"GET", self.spec.url, credential_id=self.spec.credential_id, headers=headers
) as (resp, _final):
if resp.status == 200:
# Server ignored the range -> remote changed / no resume support.
raise RemoteChanged()
if resp.status != 206:
self._raise_for_status(resp.status)
async for chunk in resp.content.iter_chunked(args.download_chunk_size):
self._check_control()
# Never write past this segment's planned range: a
# non-conforming 206 that returns more than the requested
# bytes would otherwise overrun adjacent segments and the
# preallocated file. Cap the write and abort on overflow.
remaining = seg.length - seg.bytes_done
if remaining <= 0:
raise FatalError(
f"segment {seg.idx}: server returned more than the "
f"requested {seg.length} bytes"
)
overflow = len(chunk) > remaining
if overflow:
chunk = chunk[:remaining]
await self._writer.write_at(offset, chunk)
offset += len(chunk)
seg.bytes_done += len(chunk)
self._recompute_bytes_done()
await self._persist_progress()
if overflow:
raise FatalError(
f"segment {seg.idx}: server returned more than the "
f"requested {seg.length} bytes"
)
async def _run_single(self) -> None:
seg = self.state.segments[0]
offset = seg.bytes_done # resume from here for single-stream
headers = {"Accept-Encoding": "identity"}
if offset > 0:
headers["Range"] = f"bytes={offset}-"
if self._etag:
headers["If-Range"] = self._etag
async with open_validated(
"GET", self.spec.url, credential_id=self.spec.credential_id, headers=headers
) as (resp, _final):
if offset > 0 and resp.status == 200:
# Resume not honoured -> start over from the beginning. Truncate
# the existing partial so stale trailing bytes from the prior
# attempt cannot survive past the new (possibly shorter) end.
offset = 0
seg.bytes_done = 0
self.state.bytes_done = 0
await self._writer.truncate(0)
elif offset > 0 and resp.status != 206:
self._raise_for_status(resp.status)
elif offset == 0 and resp.status != 200:
self._raise_for_status(resp.status)
# Byte ceiling for this stream: the known total when the server
# reported a size, otherwise the configured maximum download size.
# Without a bound, a non-conforming response or an unknown-length
# stream (end == -1) that never closes could fill the disk (DoS).
limit = (seg.end + 1) if seg.end >= 0 else self._max_download_bytes()
async for chunk in resp.content.iter_chunked(args.download_chunk_size):
self._check_control()
overflow = False
if limit is not None:
remaining = limit - offset
if remaining <= 0:
raise FatalError(
f"download exceeded the maximum size {limit} bytes"
)
if len(chunk) > remaining:
chunk = chunk[:remaining]
overflow = True
await self._writer.write_at(offset, chunk)
offset += len(chunk)
seg.bytes_done = offset
self.state.bytes_done = offset
await self._persist_progress()
if overflow:
raise FatalError(
f"download exceeded the maximum size {limit} bytes"
)
def _max_download_bytes(self) -> Optional[int]:
"""Configured maximum download size in bytes, or ``None`` if disabled."""
cap = getattr(args, "download_max_bytes", 0)
return cap if cap and cap > 0 else None
def _raise_for_status(self, status: int) -> None:
if status in (401, 403):
raise FatalError(
f"{redact_url(self.spec.url)} returned {status}; add/update an API key for "
f"this host at /api/download/credentials."
)
if status in _RETRYABLE_STATUSES:
raise RetryableError(f"HTTP {status}")
raise FatalError(f"unexpected HTTP {status}")
# ----- finalize / verify (PRD section 8.4) -----
async def _finalize(self) -> None:
self._check_control()
await self._close_writer()
await self._set_status(DownloadStatus.VERIFYING)
total = self.state.total_bytes
segmented = len(self.state.segments) > 1
if segmented:
# The .part was preallocated to total_bytes, so its on-disk size is
# not evidence of completeness: a segment that ends short (truncated
# 206 / server closes mid-range) leaves a zero-filled hole while the
# file size still equals total. Verify each segment wrote its full
# planned range, and trust the byte counter (== sum of segments)
# rather than os.path.getsize for the total check.
for seg in self.state.segments:
if seg.bytes_done != seg.length:
raise FatalError(
f"segment {seg.idx} incomplete: wrote {seg.bytes_done} "
f"of {seg.length} bytes"
)
observed = self.state.bytes_done
else:
# Single-stream writes a contiguous prefix, so the on-disk size is
# an independent witness of how much actually landed.
observed = os.path.getsize(self.spec.temp_path)
if total is not None and observed != total:
raise FatalError(
f"size mismatch: wrote {observed} of {total} bytes"
)
# Structural gate (cheap, no full read) then optional sha256 (full read).
# Both failures are non-retryable (a truncated/corrupt or mismatched file
# will not heal on retry), so surface them as FatalError rather than
# letting the plain Exceptions fall through to the retryable handler.
# ``temp_path`` carries the ``.part`` suffix; pass ``dest_path`` so the
# structural check detects the real file format instead of skipping it.
try:
await asyncio.to_thread(
structural.validate, self.spec.temp_path, self.spec.dest_path
)
if self.spec.expected_sha256:
await asyncio.to_thread(
checksum.verify_sha256,
self.spec.temp_path,
self.spec.expected_sha256,
)
except (structural.StructuralError, checksum.ChecksumError) as e:
raise FatalError(str(e)) from e
os.makedirs(os.path.dirname(self.spec.dest_path), exist_ok=True)
os.replace(self.spec.temp_path, self.spec.dest_path)
logging.info(
"[model_downloader] completed %s (%d bytes)",
self.spec.model_id, observed,
)
# Catalog into the assets system (blake3 dedup identity). Best-effort.
await dedup.register_completed(self.spec.dest_path)
# ----- helpers -----
def _recompute_bytes_done(self) -> None:
self.state.bytes_done = sum(s.bytes_done for s in self.state.segments)
now = time.monotonic()
dt = now - self.state._last_time
if dt >= 0.5:
self.state.speed_bps = (self.state.bytes_done - self.state._last_bytes) / dt
self.state._last_bytes = self.state.bytes_done
self.state._last_time = now
async def _persist_progress(self, force: bool = False) -> None:
# Both the DB write and the websocket notify are gated by the same
# throttle: persisting hits SQLite, and notifying broadcasts to every
# client, so doing either per-chunk (small --download-chunk-size or
# many concurrent segments) would overwhelm both. Skip entirely inside
# the window; the next persist (or a forced one) ships the latest bytes.
now = time.monotonic()
if not force and now - self._last_persist < _PERSIST_INTERVAL:
return
self._last_persist = now
# SQLite is blocking; run it off the event loop per the queries module
# contract so progress persists don't stall the web server.
await asyncio.to_thread(self._write_progress)
if self._notify:
self._notify(self.spec.download_id)
def _write_progress(self) -> None:
queries.update_download(self.spec.download_id, bytes_done=self.state.bytes_done)
for seg in self.state.segments:
if seg.end >= seg.start: # skip unknown-size sentinel
queries.update_segment_progress(
self.spec.download_id, seg.idx, seg.bytes_done
)
async def _reset_for_restart(self) -> None:
await self._close_writer()
self._remove_temp()
for seg in self.state.segments:
seg.bytes_done = 0
self.state.bytes_done = 0
await asyncio.to_thread(
queries.update_download, self.spec.download_id, bytes_done=0
)
if await asyncio.to_thread(queries.list_segments, self.spec.download_id):
await asyncio.to_thread(
queries.replace_segments, self.spec.download_id, []
)
async def _close_writer(self) -> None:
if self._writer is not None:
try:
await self._writer.close()
except Exception:
logging.debug("[model_downloader] writer close error", exc_info=True)
self._writer = None
def _segmented_part_valid(self, total_bytes: int) -> bool:
"""True when the temp file is the preallocated segmented ``.part``.
A segmented transfer preallocates the .part to ``total_bytes`` up front
and tracks how much of each range landed via per-segment offsets. Those
offsets are only trustworthy when the file they describe is still on
disk at its full preallocated size. A missing file (swept after a
failure, removed on a fatal error, deleted by hand) or a wrong-sized one
means the persisted offsets no longer correspond to real bytes and must
not be resumed over. Doing so would skip "complete" segments and leave
zero-filled holes that pass the size-only verification gate.
"""
try:
return os.path.getsize(self.spec.temp_path) == total_bytes
except OSError:
return False
def _contiguous_prefix_valid(self, prefix_len: int) -> bool:
"""True when the temp file is exactly ``prefix_len`` contiguous bytes.
Single-stream resume appends sequentially, so a valid resume point
implies the .part size equals the persisted offset. A larger file (e.g.
one preallocated to ``total_bytes`` by a previous segmented run) or a
missing/short file means the persisted offset is not a trustworthy
contiguous prefix and must not be resumed over.
"""
try:
return os.path.getsize(self.spec.temp_path) == prefix_len
except OSError:
return False
def _remove_temp(self) -> None:
try:
os.remove(self.spec.temp_path)
except FileNotFoundError:
pass
except OSError as e:
logging.warning(
"[model_downloader] could not remove %s: %s", self.spec.temp_path, e
)
async def _set_status(self, status: str, error: Optional[str] = None) -> None:
# ``error`` is authoritative: passing None clears any prior failure
# text so transitions out of a failure state (retry/success) don't
# leave stale messages on RuntimeState or in the persisted row.
self.state.status = status
self.state.error = error
fields = {"status": status, "bytes_done": self.state.bytes_done, "error": error}
if status == DownloadStatus.QUEUED:
fields["attempts"] = self.spec.attempts + 1
self.spec.attempts += 1
await asyncio.to_thread(queries.update_download, self.spec.download_id, **fields)
if self._notify:
self._notify(self.spec.download_id)
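
The overflow guard that `_run_segment` and `_run_single` apply to each chunk can be looked at in isolation. What follows is a minimal sketch under stated assumptions, not the project's code: `cap_chunk`, `consume`, and `StreamOverflow` are hypothetical names introduced here, and `StreamOverflow` merely stands in for the `FatalError` raised in the diff. It shows the two-step behaviour: bytes that still fit the planned range are kept, and the overflow is reported only afterwards.

```python
class StreamOverflow(Exception):
    """Hypothetical stand-in for the FatalError used in the diff."""


def cap_chunk(chunk: bytes, written: int, limit: int) -> tuple[bytes, bool]:
    """Trim ``chunk`` so ``written + len(result)`` never exceeds ``limit``.

    Returns the (possibly shortened) chunk and a flag telling whether bytes
    were dropped. Raises if the budget was already used up before this chunk.
    """
    remaining = limit - written
    if remaining <= 0:
        raise StreamOverflow(f"stream exceeded {limit} bytes")
    if len(chunk) > remaining:
        return chunk[:remaining], True
    return chunk, False


def consume(chunks: list[bytes], limit: int) -> bytes:
    """Accumulate chunks, keeping the in-range bytes before failing."""
    out = bytearray()
    for chunk in chunks:
        data, overflow = cap_chunk(chunk, len(out), limit)
        out += data  # the capped bytes are still written, as in the diff
        if overflow:
            raise StreamOverflow(f"stream exceeded {limit} bytes")
    return bytes(out)


if __name__ == "__main__":
    print(cap_chunk(b"abcdef", 2, 5))
    print(consume([b"ab", b"cd"], 4))
```

Writing the capped bytes before raising keeps the byte counter consistent with what is on disk, so the guard never lets a misbehaving server write past the planned range.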

@ -0,0 +1,51 @@
"""Segment planning.

Split a known byte range into S roughly-equal segments, each fetched by its
own coroutine with ``Range: bytes=start-end``. Falls back to a single segment
when the server doesn't support ranges or the size is unknown/too small for
segmentation to be worthwhile.
"""
from __future__ import annotations

from dataclasses import dataclass

# Below this size, the per-connection setup cost outweighs any parallelism.
_MIN_SEGMENT_BYTES = 1 * 1024 * 1024


@dataclass(frozen=True)
class SegmentPlan:
    idx: int
    start: int
    end: int  # inclusive

    @property
    def length(self) -> int:
        return self.end - self.start + 1


def effective_segment_count(
    total_bytes: int | None, accept_ranges: bool, configured: int
) -> int:
    """How many segments to actually use for this file."""
    if not accept_ranges or total_bytes is None or total_bytes <= 0:
        return 1
    by_size = max(1, total_bytes // _MIN_SEGMENT_BYTES)
    return max(1, min(configured, by_size))


def plan_segments(total_bytes: int, num_segments: int) -> list[SegmentPlan]:
    """Return ``num_segments`` contiguous, inclusive byte ranges covering [0, total)."""
    if total_bytes <= 0 or num_segments <= 1:
        return [SegmentPlan(idx=0, start=0, end=max(0, total_bytes - 1))]
    base = total_bytes // num_segments
    plans: list[SegmentPlan] = []
    start = 0
    for i in range(num_segments):
        # Last segment soaks up the remainder.
        length = base if i < num_segments - 1 else total_bytes - start
        end = start + length - 1
        plans.append(SegmentPlan(idx=i, start=start, end=end))
        start = end + 1
    return plans
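
A small worked example may help check the planning arithmetic above. This is a standalone restatement for illustration, not an import of the module: `segment_count` and `ranges` are hypothetical names that mirror `effective_segment_count` and `plan_segments`, and plain tuples replace `SegmentPlan`.

```python
MIN_SEGMENT_BYTES = 1 * 1024 * 1024  # mirrors _MIN_SEGMENT_BYTES in the diff


def segment_count(total_bytes, accept_ranges, configured):
    """Clamp the configured segment count by what the file size justifies."""
    if not accept_ranges or total_bytes is None or total_bytes <= 0:
        return 1
    by_size = max(1, total_bytes // MIN_SEGMENT_BYTES)
    return max(1, min(configured, by_size))


def ranges(total_bytes, num_segments):
    """Inclusive (start, end) pairs; the last pair absorbs the remainder."""
    if total_bytes <= 0 or num_segments <= 1:
        return [(0, max(0, total_bytes - 1))]
    base = total_bytes // num_segments
    out = []
    start = 0
    for i in range(num_segments):
        length = base if i < num_segments - 1 else total_bytes - start
        out.append((start, start + length - 1))
        start += length
    return out


if __name__ == "__main__":
    # 10 bytes in 3 segments: base is 3, so the last segment gets 4 bytes.
    print(ranges(10, 3))  # [(0, 2), (3, 5), (6, 9)]
    # A 3.5 MiB file only justifies 3 segments even if 8 are configured.
    print(segment_count(3 * 1024 * 1024 + 512 * 1024, True, 8))  # 3
```

Because the ranges are inclusive and contiguous, their lengths always sum to `total_bytes`. The per-segment completeness check in `_finalize` relies on the same property.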

@ -0,0 +1,110 @@
"""Positioned, off-loop file writes.

Network I/O stays on the event loop; every blocking disk op (preallocate,
positioned write, fsync) is run in a bounded thread pool via
``run_in_executor`` so downloads never stall inference or the web server.

A single file descriptor is opened for the whole download. Segments write to
their own offsets with ``os.pwrite``, which is offset-addressed and atomic
per call, so concurrent segment writers need no extra locking. Per-chunk
fsync is avoided; we fsync once at completion.

``os.pwrite`` is unavailable on Windows, so there we fall back to
``os.lseek`` + ``os.write`` guarded by a per-writer lock (the seek/write pair
is not atomic, so concurrent segment writers must be serialized).
"""
from __future__ import annotations

import asyncio
import os
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

# One shared, bounded pool for all download disk I/O.
_EXECUTOR = ThreadPoolExecutor(max_workers=8, thread_name_prefix="dl-writer")

_HAS_PWRITE = hasattr(os, "pwrite")

# On Windows ``os.open`` defaults to text mode, which translates every ``\n``
# byte into ``\r\n`` on write and corrupts binary payloads (the file grows by
# one byte per 0x0A). ``O_BINARY`` disables that translation; it does not exist
# on POSIX, where the default is already binary.
_O_BINARY = getattr(os, "O_BINARY", 0)


class FileWriter:
    """Owns the ``.part`` file descriptor for one download."""

    def __init__(self, path: str) -> None:
        self.path = path
        self._fd: Optional[int] = None
        # Serializes lseek+write on platforms without os.pwrite (Windows).
        self._seek_lock = threading.Lock()

    def _open(self) -> None:
        os.makedirs(os.path.dirname(self.path), exist_ok=True)
        self._fd = os.open(self.path, os.O_RDWR | os.O_CREAT | _O_BINARY, 0o644)

    async def open(self) -> None:
        await asyncio.get_running_loop().run_in_executor(_EXECUTOR, self._open)

    async def preallocate(self, size: int) -> None:
        """Grow the file to ``size`` so segments write to their offsets."""
        if self._fd is None or size <= 0:
            return
        await asyncio.get_running_loop().run_in_executor(
            _EXECUTOR, os.ftruncate, self._fd, size
        )

    async def truncate(self, size: int = 0) -> None:
        """Truncate the file to ``size`` bytes (default: empty it)."""
        if self._fd is None:
            return
        await asyncio.get_running_loop().run_in_executor(
            _EXECUTOR, os.ftruncate, self._fd, size
        )

    def _pwrite_all(self, data: bytes, offset: int) -> None:
        """A positioned write may write fewer bytes than requested (signal
        interruption, near-ENOSPC); loop until every byte lands so we never
        leave a gap while the caller advances by the full chunk length.

        Uses ``os.pwrite`` where available (offset-addressed, atomic per call).
        On Windows it falls back to ``os.lseek`` + ``os.write`` under a lock,
        since that pair is not atomic across concurrent segment writers."""
        assert self._fd is not None, "writer not opened"
        view = memoryview(data)
        written = 0
        total = len(view)
        while written < total:
            if _HAS_PWRITE:
                n = os.pwrite(self._fd, view[written:], offset + written)
            else:
                with self._seek_lock:
                    os.lseek(self._fd, offset + written, os.SEEK_SET)
                    n = os.write(self._fd, view[written:])
            if n == 0:
                raise OSError(
                    f"positioned write wrote 0 bytes at offset {offset + written} "
                    f"({written}/{total} bytes written)"
                )
            written += n

    async def write_at(self, offset: int, data: bytes) -> None:
        assert self._fd is not None, "writer not opened"
        await asyncio.get_running_loop().run_in_executor(
            _EXECUTOR, self._pwrite_all, data, offset
        )

    async def flush(self) -> None:
        if self._fd is None:
            return
        await asyncio.get_running_loop().run_in_executor(_EXECUTOR, os.fsync, self._fd)

    async def close(self) -> None:
        if self._fd is None:
            return
        fd, self._fd = self._fd, None
        await asyncio.get_running_loop().run_in_executor(_EXECUTOR, os.close, fd)
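
The positioned-write idea behind `FileWriter` can be tried without the executor or the asyncio plumbing. The sketch below is an illustration under assumptions rather than the module itself: `write_all_at` and `demo` are hypothetical names, and the code is single-threaded, so the `lseek` fallback needs no lock here (the real writer takes one). It preallocates a file, fills the second half before the first, and reads the result back.

```python
import os
import tempfile


def write_all_at(fd: int, data: bytes, offset: int) -> None:
    """Write all of ``data`` at ``offset``, retrying after short writes."""
    view = memoryview(data)
    written = 0
    while written < len(view):
        if hasattr(os, "pwrite"):
            # Offset-addressed: does not move the descriptor's file position.
            n = os.pwrite(fd, view[written:], offset + written)
        else:
            # Fallback for platforms without os.pwrite: seek, then write.
            os.lseek(fd, offset + written, os.SEEK_SET)
            n = os.write(fd, view[written:])
        if n == 0:
            raise OSError("positioned write made no progress")
        written += n


def demo() -> bytes:
    fd, path = tempfile.mkstemp()
    try:
        os.ftruncate(fd, 8)           # preallocate, as FileWriter.preallocate does
        write_all_at(fd, b"WXYZ", 4)  # the "second segment" lands first
        write_all_at(fd, b"abcd", 0)  # then the "first segment"
        os.fsync(fd)                  # one fsync at the end, not per chunk
        os.lseek(fd, 0, os.SEEK_SET)
        return os.read(fd, 8)
    finally:
        os.close(fd)
        os.remove(path)


if __name__ == "__main__":
    print(demo())  # b'abcdWXYZ'
```

The order of the two writes does not matter because each call carries its own offset. That is the property that lets segment coroutines share one descriptor.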

@ -0,0 +1,454 @@
"""Public facade for the download manager.
This is the only object the server imports. It validates requests, owns the
:class:`Scheduler`, and exposes a small async API plus read models for status.
"""
from __future__ import annotations
import asyncio
import logging
import os
import uuid
from typing import Callable, Optional
from app.model_downloader.constants import DownloadStatus
from app.model_downloader.database import queries
from app.model_downloader.net.probe import gated_error_message, probe
from app.model_downloader.scheduler import SCHEDULER
from app.model_downloader.security import paths
from app.model_downloader.net.http import redact_url
from app.model_downloader.security.allowlist import (
ALLOWED_MODEL_EXTENSIONS,
filename_extension,
is_host_allowed_url,
is_url_downloadable,
url_path_extension,
)
from app.model_downloader.security.paths import InvalidModelId
# Non-terminal statuses: an existing row in one of these blocks a re-enqueue.
_LIVE_STATUSES = (
DownloadStatus.QUEUED,
DownloadStatus.ACTIVE,
DownloadStatus.PAUSED,
DownloadStatus.VERIFYING,
)
class DownloadError(Exception):
"""A user-facing error with a stable machine-readable code."""
def __init__(self, code: str, message: str, status: int = 400) -> None:
super().__init__(message)
self.code = code
self.message = message
self.http_status = status
class DownloadManager:
def __init__(self) -> None:
self._scheduler = SCHEDULER
self._notify_cb: Optional[Callable[[str], None]] = None
# Serializes the "check for a live download, then write" critical section
# per model_id. ``downloads`` has no uniqueness constraint on model_id
# (history rows are kept), so without this two concurrent enqueue/resume
# calls could both pass the live check and admit two jobs sharing one
# temp/dest path. The manager is a process singleton over a local SQLite
# DB, so an in-process lock is sufficient (and avoids a migration).
self._model_locks: dict[str, asyncio.Lock] = {}
def set_notify(self, cb: Optional[Callable[[str], None]]) -> None:
self._notify_cb = cb
self._scheduler.set_notify(cb)
async def start(self) -> None:
await self._scheduler.start()
# ----- enqueue -----
async def enqueue(
self,
url: str,
model_id: str,
*,
priority: int = 0,
expected_sha256: Optional[str] = None,
allow_any_extension: bool = False,
credential_id: Optional[str] = None,
) -> str:
# Coarse gate first: host/scheme must be allowlisted, and any extension
# present in the URL path must be a known model type. A URL whose path
# carries NO extension (e.g. Civitai's ``/api/download/models/<id>``) is
# admitted here and its real extension is resolved from the network
# below before the download is finally accepted.
if allow_any_extension:
if not is_host_allowed_url(url):
raise DownloadError(
"URL_NOT_ALLOWED",
"URL is not on the download allowlist (host/scheme).",
)
elif not is_url_downloadable(url):
raise DownloadError(
"URL_NOT_ALLOWED",
"URL is not on the download allowlist (host/scheme/extension).",
)
# When the URL path has no extension, follow it to where it resolves and
# adopt the real extension from the response, forcing the stored
# filename to match. Skipped when the caller opted into any extension.
if not allow_any_extension and url_path_extension(url) == "":
resolved_ext = await self._resolve_extension(url, credential_id)
model_id = paths.apply_extension(model_id, resolved_ext)
try:
paths.parse_model_id(model_id, allow_any_extension)
dest_path, temp_path = paths.resolve_destination(model_id, allow_any_extension)
except InvalidModelId as e:
raise DownloadError("INVALID_MODEL_ID", str(e))
if await asyncio.to_thread(
paths.resolve_existing, model_id, allow_any_extension
):
raise DownloadError(
"ALREADY_AVAILABLE",
f"Model already exists on disk: {model_id}",
status=409,
)
download_id = str(uuid.uuid4())
# Hold the per-model lock across the live check and the insert so a
# concurrent enqueue/resume for the same model_id cannot interleave
# between them and create a second job against the same temp/dest path.
async with self._model_lock(model_id):
if await self._has_live_download(model_id):
raise DownloadError(
"ALREADY_DOWNLOADING",
f"A download for {model_id} is already in progress.",
status=409,
)
await asyncio.to_thread(
queries.insert_download,
{
"id": download_id,
"url": url,
"model_id": model_id,
"dest_path": dest_path,
"temp_path": temp_path,
"status": DownloadStatus.QUEUED,
"priority": priority,
"expected_sha256": expected_sha256,
"credential_id": credential_id,
"allow_any_extension": allow_any_extension,
},
)
logging.info("[model_downloader] enqueued %s -> %s", redact_url(url), model_id)
await self._scheduler.pump()
return download_id
async def _resolve_extension(
self, url: str, credential_id: Optional[str]
) -> str:
"""Follow ``url`` to its final response and return the real extension.
Used for allowlisted URLs whose path has no extension (e.g. Civitai
download endpoints): the filename lives in the ``Content-Disposition``
header or the post-redirect URL. Raises :class:`DownloadError` when the
URL can't be resolved, needs credentials, or resolves to something that
is not a known model file — so we never persist a bogus destination.
"""
pr = await probe(url, credential_id=credential_id)
if not pr.ok:
if pr.gated:
raise DownloadError(
"GATED_REPO" if pr.is_gated_repo else "CREDENTIALS_REQUIRED",
gated_error_message(url, pr),
status=401,
)
raise DownloadError(
"URL_RESOLVE_FAILED",
f"Could not resolve {redact_url(url)}: {pr.error or 'unknown error'}",
status=502,
)
ext = filename_extension(pr.filename) if pr.filename else ""
if ext not in ALLOWED_MODEL_EXTENSIONS:
raise DownloadError(
"URL_NOT_ALLOWED",
f"URL resolves to {pr.filename or '<unknown>'!r}, which is not a "
f"known model file type {ALLOWED_MODEL_EXTENSIONS}.",
)
return ext
def _model_lock(self, model_id: str) -> asyncio.Lock:
# Lazily create one lock per model_id. There is no ``await`` between the
# lookup and the insert, so under the single asyncio thread this is
# atomic and cannot hand out two different locks for the same model_id.
lock = self._model_locks.get(model_id)
if lock is None:
lock = asyncio.Lock()
self._model_locks[model_id] = lock
return lock
async def _has_live_download(
self, model_id: str, *, exclude_id: Optional[str] = None
) -> bool:
rows = await asyncio.to_thread(queries.list_downloads)
return any(
r.model_id == model_id
and r.id != exclude_id
and r.status in _LIVE_STATUSES
for r in rows
)
# ----- control -----
async def pause(self, download_id: str) -> None:
job = self._scheduler.get_job(download_id)
if job is not None:
job.request_pause()
return
row = await asyncio.to_thread(queries.get_download, download_id)
if row is None:
raise DownloadError("NOT_FOUND", "No such download.", status=404)
if row.status == DownloadStatus.QUEUED:
await asyncio.to_thread(
queries.update_download, download_id, status=DownloadStatus.PAUSED
)
async def resume(self, download_id: str) -> None:
row = await asyncio.to_thread(queries.get_download, download_id)
if row is None:
raise DownloadError("NOT_FOUND", "No such download.", status=404)
if row.status not in (DownloadStatus.PAUSED, DownloadStatus.FAILED):
return
# Re-queueing a paused/failed row must respect the single-live-per-model
# invariant: another download (e.g. a newer enqueue) may already be live
# for this model_id and would share this row's temp/dest path. Hold the
# per-model lock across the check and the status flip, and exclude this
# row itself (a paused row is already a "live" status).
async with self._model_lock(row.model_id):
if await self._has_live_download(row.model_id, exclude_id=download_id):
raise DownloadError(
"ALREADY_DOWNLOADING",
f"A download for {row.model_id} is already in progress.",
status=409,
)
await asyncio.to_thread(
queries.update_download,
download_id,
status=DownloadStatus.QUEUED,
error=None,
)
await self._scheduler.pump()
async def cancel(self, download_id: str) -> None:
job = self._scheduler.get_job(download_id)
if job is not None:
job.request_cancel()
return
row = await asyncio.to_thread(queries.get_download, download_id)
if row is None:
raise DownloadError("NOT_FOUND", "No such download.", status=404)
if row.status in _LIVE_STATUSES:
try:
os.remove(row.temp_path)
except OSError:
pass
await asyncio.to_thread(
queries.update_download, download_id, status=DownloadStatus.CANCELLED
)
async def set_priority(self, download_id: str, priority: int) -> None:
row = await asyncio.to_thread(queries.get_download, download_id)
if row is None:
raise DownloadError("NOT_FOUND", "No such download.", status=404)
await asyncio.to_thread(
queries.update_download, download_id, priority=priority
)
# Priority affects admission order only: a higher-priority download is
# picked up the next time a slot frees. Pump in case a slot is free now.
await self._scheduler.pump()
async def delete(self, download_id: str) -> None:
"""Delete a terminal download so it stays gone from history.
Refuses to delete a live download so a record is never removed out from
under a running worker; cancel it first. Any leftover ``.part`` temp
file (e.g. from a failed transfer) is removed, but the finished model
file on disk is never touched.
"""
if self._scheduler.get_job(download_id) is not None:
raise DownloadError(
"DOWNLOAD_ACTIVE",
"Cannot delete a download that is still in progress.",
status=409,
)
row = await asyncio.to_thread(queries.get_download, download_id)
if row is None:
raise DownloadError("NOT_FOUND", "No such download.", status=404)
if row.status in _LIVE_STATUSES:
raise DownloadError(
"DOWNLOAD_ACTIVE",
"Cannot delete a download that is still in progress.",
status=409,
)
try:
os.remove(row.temp_path)
except OSError:
pass
await asyncio.to_thread(queries.delete_download, download_id)
async def clear(self) -> int:
"""Delete all terminal downloads from history in one transaction.
Skips anything still live (queued/active/paused/verifying, or a running
job) so an in-flight download is never removed out from under a worker.
Finished model files on disk are never touched; only leftover ``.part``
temp files from failed/cancelled transfers are removed. Returns the
number of history rows deleted.
"""
rows = await asyncio.to_thread(queries.list_downloads)
deletable = [
r
for r in rows
if r.status not in _LIVE_STATUSES
and self._scheduler.get_job(r.id) is None
]
if not deletable:
return 0
for r in deletable:
try:
os.remove(r.temp_path)
except OSError:
pass
return await asyncio.to_thread(
queries.delete_downloads, [r.id for r in deletable]
)
# ----- read models -----
def _view(self, row) -> dict:
"""Combine the persisted row with live in-memory progress, if running."""
job = self._scheduler.get_job(row.id)
bytes_done = row.bytes_done
total = row.total_bytes
speed = None
eta = None
segments = None
if job is not None:
st = job.state
bytes_done = st.bytes_done
total = st.total_bytes if st.total_bytes is not None else total
speed = st.speed_bps
eta = st.eta_seconds
segments = [
{"idx": s.idx, "bytes_done": s.bytes_done, "length": s.length}
for s in st.segments
if s.end >= s.start
]
progress = (bytes_done / total) if total else None
return {
"download_id": row.id,
"model_id": row.model_id,
"url": redact_url(row.url),
"status": row.status,
"priority": row.priority,
"total_bytes": total,
"bytes_done": bytes_done,
"progress": progress,
"speed_bps": speed,
"eta_seconds": eta,
"segments": segments,
"error": row.error,
"created_at": row.created_at,
"updated_at": row.updated_at,
}
def _view_from_state(self, job) -> dict:
"""Build a view purely from the live in-memory job state (no DB)."""
st = job.state
return {
"download_id": st.download_id,
"model_id": st.model_id,
"url": redact_url(st.url),
"status": st.status,
"priority": st.priority,
"total_bytes": st.total_bytes,
"bytes_done": st.bytes_done,
"progress": st.progress,
"speed_bps": st.speed_bps,
"eta_seconds": st.eta_seconds,
"segments": [
{"idx": s.idx, "bytes_done": s.bytes_done, "length": s.length}
for s in st.segments
if s.end >= s.start
],
"error": st.error,
}
def status_sync(self, download_id: str) -> Optional[dict]:
"""Synchronous status read for the websocket notify path.
Uses live in-memory state when the job is running (no DB round-trip on
the hot path); falls back to a quick DB read otherwise.
"""
job = self._scheduler.get_job(download_id)
if job is not None:
return self._view_from_state(job)
row = queries.get_download(download_id)
return self._view(row) if row is not None else None
async def status(self, download_id: str) -> Optional[dict]:
row = await asyncio.to_thread(queries.get_download, download_id)
return self._view(row) if row is not None else None
async def list(self) -> list[dict]:
rows = await asyncio.to_thread(queries.list_downloads)
return [self._view(r) for r in rows]
async def availability(self, models: dict[str, str]) -> dict[str, dict]:
"""Bulk per-id ``{state, progress, ...}`` for the frontend poll.
``state`` is ``available`` (on disk), ``downloading`` (live row), or
``missing``. Cheap: a path lookup plus an in-memory/DB status check.
"""
rows = await asyncio.to_thread(queries.list_downloads)
by_model: dict[str, object] = {}
for r in rows:
if r.status in _LIVE_STATUSES or r.model_id not in by_model:
by_model[r.model_id] = r
# ``url_allowed`` mirrors the coarse enqueue gate (host/scheme + a
# non-disallowed extension); URLs whose extension is only known after a
# network resolve — e.g. Civitai download endpoints — report allowed.
out: dict[str, dict] = {}
for model_id, url in models.items():
try:
exists = await asyncio.to_thread(paths.resolve_existing, model_id)
except InvalidModelId:
out[model_id] = {"state": "missing", "url_allowed": is_url_downloadable(url)}
continue
if exists:
out[model_id] = {"state": "available", "url_allowed": is_url_downloadable(url)}
continue
row = by_model.get(model_id)
if row is not None and row.status in _LIVE_STATUSES:
view = self._view(row)
out[model_id] = {
"state": "downloading",
"url_allowed": is_url_downloadable(url),
"download_id": view["download_id"],
"progress": view["progress"],
"bytes_done": view["bytes_done"],
"total_bytes": view["total_bytes"],
"speed_bps": view["speed_bps"],
}
else:
out[model_id] = {"state": "missing", "url_allowed": is_url_downloadable(url)}
return out
DOWNLOAD_MANAGER = DownloadManager()
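
The row-selection step in `availability` can be illustrated in isolation. The sketch below is not the module's code: `Row`, `pick_rows` and the contents of `LIVE_STATUSES` are hypothetical stand-ins (the real `_LIVE_STATUSES` set is not shown in this diff). It only demonstrates the rule that the first row seen for a model is kept unless a live row replaces it.

```python
from dataclasses import dataclass

# Assumption: placeholder values; the real _LIVE_STATUSES is defined elsewhere.
LIVE_STATUSES = {"queued", "active", "verifying", "paused"}


@dataclass
class Row:
    """Hypothetical minimal stand-in for a persisted download row."""
    id: str
    model_id: str
    status: str


def pick_rows(rows):
    """Keep one row per model_id: the first one seen, unless a live row appears."""
    by_model = {}
    for r in rows:
        # A live row always wins; a non-live row is kept only if it is the first.
        if r.status in LIVE_STATUSES or r.model_id not in by_model:
            by_model[r.model_id] = r
    return by_model


rows = [
    Row("d1", "loras/a.safetensors", "failed"),
    Row("d2", "loras/a.safetensors", "active"),
    Row("d3", "loras/a.safetensors", "completed"),
    Row("d4", "vae/b.safetensors", "completed"),
]
picked = pick_rows(rows)
```

With these inputs the live row `d2` replaces the earlier failed row for the same model, while the later non-live row `d3` is ignored.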

@ -0,0 +1,148 @@
"""Manual, validated redirect-following request opener.
Automatic redirects are disabled. We follow hops ourselves
so that on *every* hop we (a) re-validate scheme + reject credentials-in-URL,
(b) recompute which stored credential — if any — applies to that hop's host,
and (c) let the connector's resolver screen the IP. This is the single place
that attaches credentials, so a token can never ride a redirect to a CDN host.
"""
from __future__ import annotations
import logging
import re
from contextlib import asynccontextmanager
from typing import AsyncIterator, Optional
from urllib.parse import unquote, urljoin, urlsplit, urlunsplit
import aiohttp
from app.model_downloader.credentials.resolver import resolve_auth_for_hop
from app.model_downloader.net.session import get_session
from app.model_downloader.security.ssrf import (
MAX_REDIRECTS,
SSRFError,
check_redirect_hop,
)
_REDIRECT_CODES = {301, 302, 303, 307, 308}
DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=None, sock_connect=30, sock_read=120)
def redact_url(url: str) -> str:
"""Drop the query string so a query-scheme secret is never logged/stored."""
try:
parts = urlsplit(url)
except ValueError:
return "<unparseable-url>"
return urlunsplit(parts._replace(query=""))
_CD_FILENAME_STAR = re.compile(
r"filename\*\s*=\s*[^']*'[^']*'([^;]+)", re.IGNORECASE
)
_CD_FILENAME_QUOTED = re.compile(r'filename\s*=\s*"([^"]+)"', re.IGNORECASE)
_CD_FILENAME_BARE = re.compile(r"filename\s*=\s*([^;]+)", re.IGNORECASE)
def filename_from_content_disposition(value: Optional[str]) -> Optional[str]:
"""Extract the download filename from a ``Content-Disposition`` header.
Prefers the RFC 5987 ``filename*=`` form (percent-decoded) over the plain
``filename=`` form. Any directory components in the value are stripped so a
hostile header can only influence the *name*, never the target directory.
Returns ``None`` when no filename is present.
"""
if not value:
return None
for pat, decode in (
(_CD_FILENAME_STAR, True),
(_CD_FILENAME_QUOTED, False),
(_CD_FILENAME_BARE, False),
):
m = pat.search(value)
if not m:
continue
raw = m.group(1).strip().strip('"')
if decode:
try:
raw = unquote(raw)
except Exception:
pass
name = raw.replace("\\", "/").rsplit("/", 1)[-1].strip()
if name:
return name
return None
async def _resolve_final_response(
method: str,
url: str,
credential_id: Optional[str],
base_headers: dict[str, str],
timeout: aiohttp.ClientTimeout,
) -> tuple[aiohttp.ClientResponse, str]:
"""Follow redirects manually until a non-redirect response.
Each intermediate redirect response is released before the next hop.
Returns the final ``(response, final_url)``; the caller owns releasing it.
"""
session = await get_session()
current = url
hops = 0
while True:
check_redirect_hop(current, is_initial_url=(hops == 0))
parts = urlsplit(current)
auth = await resolve_auth_for_hop(
parts.hostname or "", parts.scheme, explicit_credential_id=credential_id
)
req_headers = dict(base_headers)
req_url = current
if auth is not None:
req_headers.update(auth.headers)
req_url = auth.apply_to_url(current)
resp = await session.request(
method,
req_url,
allow_redirects=False,
headers=req_headers,
timeout=timeout,
)
if resp.status in _REDIRECT_CODES and resp.headers.get("Location"):
next_url = urljoin(str(resp.url), resp.headers["Location"])
await resp.release()
hops += 1
if hops > MAX_REDIRECTS:
raise SSRFError(
f"too many redirects (> {MAX_REDIRECTS}) for {redact_url(url)}"
)
current = next_url
continue
return resp, redact_url(str(resp.url))
@asynccontextmanager
async def open_validated(
method: str,
url: str,
*,
credential_id: Optional[str] = None,
headers: Optional[dict[str, str]] = None,
timeout: aiohttp.ClientTimeout = DEFAULT_TIMEOUT,
) -> AsyncIterator[tuple[aiohttp.ClientResponse, str]]:
"""Open ``method url`` following redirects manually and validated.
Yields ``(response, final_url)`` where ``final_url`` is redacted of any
query string. The response is released automatically on exit.
"""
resp, final_url = await _resolve_final_response(
method, url, credential_id, dict(headers or {}), timeout
)
try:
yield resp, final_url
finally:
try:
await resp.release()
except Exception: # pragma: no cover - best-effort cleanup
logging.debug("[model_downloader] response release error", exc_info=True)
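
The per-hop credential rule described in the module docstring can be sketched without any network access. Everything below is illustrative: `CREDENTIALS`, `FAKE_RESPONSES`, `follow` and the example hosts are hypothetical, and a dictionary stands in for the HTTP transport. The point is only that the header set is rebuilt from scratch on each hop, so a token stored for one host is not sent to the redirect target.

```python
from urllib.parse import urljoin, urlsplit

MAX_REDIRECTS = 5
REDIRECT_CODES = {301, 302, 303, 307, 308}

# Hypothetical credential store keyed by exact host name.
CREDENTIALS = {"huggingface.co": "Bearer hf_example"}

# Hypothetical transport: URL -> (status, Location header or None).
START = "https://huggingface.co/org/model/resolve/main/m.safetensors"
CDN = "https://cdn.example.net/blob/abc"
FAKE_RESPONSES = {START: (302, CDN), CDN: (200, None)}


def follow(url):
    """Follow redirects by hand; return (final_url, [(host, headers), ...])."""
    sent = []
    current = url
    hops = 0
    while True:
        parts = urlsplit(current)
        # Re-validate every hop: https only, no credentials embedded in the URL.
        if parts.scheme != "https" or parts.username or parts.password:
            raise ValueError("hop rejected")
        # Recompute auth per hop from the hop's own host (exact match only).
        headers = {}
        token = CREDENTIALS.get((parts.hostname or "").lower())
        if token is not None:
            headers["Authorization"] = token
        sent.append((parts.hostname, headers))
        status, location = FAKE_RESPONSES[current]
        if status in REDIRECT_CODES and location:
            hops += 1
            if hops > MAX_REDIRECTS:
                raise ValueError("too many redirects")
            current = urljoin(current, location)
            continue
        return current, sent


final_url, sent = follow(START)
```

Here the first hop carries the `Authorization` header and the CDN hop carries none, because no credential is stored for `cdn.example.net`.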

@ -0,0 +1,157 @@
"""Pre-download probe.
Issues a tiny ranged GET (``Range: bytes=0-0``) — which doubles as a
range-support test — to discover ``Content-Length``, ``Accept-Ranges``,
``ETag``/``Last-Modified``, and the final post-redirect URL. For HuggingFace
LFS files the true size also appears in the non-standard ``X-Linked-Size``
header, which we read as a fallback.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Optional
from urllib.parse import urlparse, urlsplit
import aiohttp
from app.model_downloader.net.http import (
filename_from_content_disposition,
open_validated,
redact_url,
)
from app.model_downloader.net.session import parse_int_header
_PROBE_TIMEOUT = aiohttp.ClientTimeout(total=60, sock_connect=30, sock_read=30)
@dataclass
class ProbeResult:
ok: bool
status: int
final_url: Optional[str] = None
total_bytes: Optional[int] = None
accept_ranges: bool = False
etag: Optional[str] = None
last_modified: Optional[str] = None
gated: bool = False # 401/403 — needs (or has wrong) credentials
error: Optional[str] = None
# HuggingFace's ``X-Error-Code`` header (e.g. ``GatedRepo``,
# ``RepoNotFound``) when the host reports one. Lets us tell "this repo is
# gated — request access" apart from "you just need a token".
error_code: Optional[str] = None
# Filename the server intends this response to be saved as: the
# ``Content-Disposition`` name if present, else the post-redirect URL's
# basename. Used to resolve the real extension for URLs (e.g. Civitai's
# ``/api/download`` endpoints) that carry no extension in their path.
filename: Optional[str] = None
@property
def is_gated_repo(self) -> bool:
"""True when the host says the repo is gated (access must be granted).
Distinct from a plain missing/invalid token: even a valid credential
won't help until the user accepts the model's terms on its page.
"""
return (self.error_code or "").lower() == "gatedrepo"
def _total_from_content_range(value: Optional[str]) -> Optional[int]:
# "bytes 0-0/12345" -> 12345 ; "bytes 0-0/*" -> None
if not value or "/" not in value:
return None
total = value.rsplit("/", 1)[1].strip()
return parse_int_header(total)
def _filename_from_response(
content_disposition: Optional[str], final_url: Optional[str]
) -> Optional[str]:
name = filename_from_content_disposition(content_disposition)
if name:
return name
if final_url:
base = urlsplit(final_url).path.rsplit("/", 1)[-1]
if base:
return base
return None
async def probe(url: str, *, credential_id: Optional[str] = None) -> ProbeResult:
"""Probe ``url`` and return discovered metadata, failing soft."""
try:
async with open_validated(
"GET",
url,
credential_id=credential_id,
headers={"Range": "bytes=0-0", "Accept-Encoding": "identity"},
timeout=_PROBE_TIMEOUT,
) as (resp, final_url):
if resp.status in (401, 403):
error_code = resp.headers.get("X-Error-Code")
error_message = resp.headers.get("X-Error-Message")
return ProbeResult(
ok=False, status=resp.status, final_url=final_url, gated=True,
error_code=error_code,
error=(
error_message
or f"host returned {resp.status} (authentication required)"
),
)
if resp.status not in (200, 206):
return ProbeResult(
ok=False, status=resp.status, final_url=final_url,
error=f"probe returned HTTP {resp.status}",
)
headers = resp.headers
accept_ranges = False
total: Optional[int] = None
if resp.status == 206:
accept_ranges = True
total = _total_from_content_range(headers.get("Content-Range"))
else: # 200: server ignored the range
accept_ranges = headers.get("Accept-Ranges", "").lower() == "bytes"
total = parse_int_header(headers.get("Content-Length"))
if total is None:
total = parse_int_header(headers.get("X-Linked-Size"))
return ProbeResult(
ok=True,
status=resp.status,
final_url=final_url,
total_bytes=total,
accept_ranges=accept_ranges,
etag=headers.get("ETag"),
last_modified=headers.get("Last-Modified"),
filename=_filename_from_response(
headers.get("Content-Disposition"), final_url
),
)
except Exception as e: # network / SSRF / timeout
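# Parse the redacted URL: redact_url() never raises on a malformed URL, and
# .hostname (unlike .netloc) cannot leak userinfo into the log line.
host = urlparse(redact_url(url)).hostname or "<unknown>"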
logging.debug("[model_downloader] probe failed for %s: %s", host, type(e).__name__)
return ProbeResult(ok=False, status=0, error="probe failed: network error")
def gated_error_message(url: str, pr: ProbeResult) -> str:
"""Build a user-facing message for a gated/auth-required probe result.
Distinguishes a *gated* repo (access must be requested/granted on the model
page — a token alone is not enough) from a plain missing/invalid credential.
"""
redacted = redact_url(url)
if pr.is_gated_repo:
detail = (pr.error or "access is restricted").rstrip()
if detail and not detail.endswith((".", "!", "?")):
detail += "."
return (
f"{redacted} is a gated model — {detail} Request access on the model's "
f"page, add an API key for this host at /api/download/credentials, and retry."
)
return (
f"{redacted} requires authentication. Add an API key for this host at "
f"/api/download/credentials and retry."
)
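
The size-discovery branch of the probe can be exercised on plain dictionaries. This is a sketch under stated assumptions rather than the module's code: `discover_size` and `parse_int` are hypothetical names, a plain `dict` is case-sensitive (aiohttp's header mapping is not), and the `X-Linked-Size` fallback is assumed to apply after both branches, which the flattened diff does not make certain.

```python
from typing import Mapping, Optional, Tuple


def parse_int(value: Optional[str]) -> Optional[int]:
    """Parse a non-negative integer header value, or None if bad/absent."""
    if not value:
        return None
    try:
        n = int(value)
    except (TypeError, ValueError):
        return None
    return n if n >= 0 else None


def discover_size(status: int, headers: Mapping[str, str]) -> Tuple[bool, Optional[int]]:
    """Return (accept_ranges, total_bytes) from a 200/206 probe response."""
    if status == 206:
        # The server honoured the range: the total is after the slash in
        # "bytes 0-0/12345"; "bytes 0-0/*" means the total is unknown.
        accept_ranges = True
        content_range = headers.get("Content-Range", "")
        total = None
        if "/" in content_range:
            total = parse_int(content_range.rsplit("/", 1)[1].strip())
    else:
        # 200: the range was ignored, so Content-Length is the full size.
        accept_ranges = headers.get("Accept-Ranges", "").lower() == "bytes"
        total = parse_int(headers.get("Content-Length"))
    if total is None:
        # Assumed fallback position: non-standard header used for LFS files.
        total = parse_int(headers.get("X-Linked-Size"))
    return accept_ranges, total


ranged = discover_size(206, {"Content-Range": "bytes 0-0/12345"})
unknown_total = discover_size(206, {"Content-Range": "bytes 0-0/*", "X-Linked-Size": "99"})
plain = discover_size(200, {"Content-Length": "10"})
```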

@ -0,0 +1,72 @@
"""Lazily-created shared :class:`aiohttp.ClientSession`.
A single session reuses TLS handshakes and TCP connections across the probe
and the many segment GETs to the same host (HuggingFace is the dominant
case), which is a large speedup on cold connections and exactly the
connection-reuse strategy that lets us match aria2c.
The connector uses :class:`ValidatingResolver` so every connection — initial
or post-redirect — is screened for private/special-use IPs at connect time.
TLS is pinned to certifi's CA bundle because the OS trust store is not wired
up on some Python installs (python.org macOS, slim containers).
"""
from __future__ import annotations
import asyncio
import ssl
from typing import Optional
import aiohttp
try:
import certifi
_CA_FILE = certifi.where()
except Exception: # pragma: no cover - certifi is a transitive dep of aiohttp
_CA_FILE = None
from comfy.cli_args import args
from app.model_downloader.security.ssrf import ValidatingResolver
_session: Optional[aiohttp.ClientSession] = None
_lock = asyncio.Lock()
def ssl_context() -> ssl.SSLContext:
if _CA_FILE is not None:
return ssl.create_default_context(cafile=_CA_FILE)
return ssl.create_default_context()
async def get_session() -> aiohttp.ClientSession:
"""Return the shared session, creating it on first use."""
global _session
if _session is not None and not _session.closed:
return _session
async with _lock:
if _session is None or _session.closed:
connector = aiohttp.TCPConnector(
limit_per_host=max(1, getattr(args, "download_max_connections_per_host", 16)),
ssl=ssl_context(),
resolver=ValidatingResolver(),
)
_session = aiohttp.ClientSession(connector=connector)
return _session
async def close_session() -> None:
global _session
if _session is not None and not _session.closed:
await _session.close()
_session = None
def parse_int_header(value: Optional[str]) -> Optional[int]:
"""Parse a non-negative integer header value, or None if bad/absent."""
if not value:
return None
try:
n = int(value)
except (TypeError, ValueError):
return None
return n if n >= 0 else None
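
The lazy, lock-guarded creation in `get_session` follows a double-checked pattern. The sketch below reproduces that pattern with a hypothetical `FakeSession` in place of `aiohttp.ClientSession`, so it runs without any network dependency; the `asyncio.sleep(0)` is an added stand-in for asynchronous setup work, and the lock is created lazily here (unlike the module-level lock in the diff) so the sketch also runs on older Python versions.

```python
import asyncio


class FakeSession:
    """Hypothetical stand-in for aiohttp.ClientSession."""

    created = 0

    def __init__(self) -> None:
        FakeSession.created += 1
        self.closed = False


_session = None
_lock = None


async def get_session() -> FakeSession:
    """Return the shared session, creating it on first use."""
    global _session, _lock
    # Fast path: no lock needed once a usable session exists.
    if _session is not None and not _session.closed:
        return _session
    if _lock is None:
        _lock = asyncio.Lock()
    async with _lock:
        # Re-check under the lock: another task may have created it meanwhile.
        if _session is None or _session.closed:
            await asyncio.sleep(0)  # simulated async setup, lets other tasks run
            _session = FakeSession()
    return _session


async def demo():
    # Five concurrent callers race for the first session.
    return await asyncio.gather(*(get_session() for _ in range(5)))


sessions = asyncio.run(demo())
```

All five callers receive the same object, and the constructor runs once, because the callers that lose the race re-check the shared variable after acquiring the lock.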

@ -0,0 +1,177 @@
"""Priority scheduler + lifecycle.
Owns the set of running jobs and admits queued downloads up to a global
concurrency limit (K), highest priority first, FIFO within a priority. Runs
entirely on the existing ComfyUI asyncio loop; blocking work (disk, hashing,
DB) is offloaded by the job/writer layers.
On startup it reconciles DB vs. disk: ``active``/``verifying`` rows left by a
previous run are reset to ``queued`` and resumed from persisted offsets, and
orphaned ``.part`` files with no live download row are swept.
"""
from __future__ import annotations
import asyncio
import logging
import os
import random
import time
from typing import Callable, Optional
from comfy.cli_args import args
from app.model_downloader.constants import DownloadStatus
from app.model_downloader.database import queries
from app.model_downloader.engine.job import DownloadJob, JobSpec
from app.model_downloader.security import paths
# Backoff for retryable failures
_BACKOFF_BASE = 2.0
_BACKOFF_CAP = 300.0
_MAX_ATTEMPTS = 6
class Scheduler:
def __init__(self) -> None:
self._jobs: dict[str, DownloadJob] = {}
self._tasks: dict[str, asyncio.Task] = {}
self._backoff_until: dict[str, float] = {}
self._pump_lock = asyncio.Lock()
self._notify_cb: Optional[Callable[[str], None]] = None
self._started = False
@property
def max_active(self) -> int:
return max(1, getattr(args, "download_max_active", 3))
def set_notify(self, cb: Optional[Callable[[str], None]]) -> None:
self._notify_cb = cb
def get_job(self, download_id: str) -> Optional[DownloadJob]:
return self._jobs.get(download_id)
def is_active(self, download_id: str) -> bool:
return download_id in self._tasks
# ----- startup -----
async def start(self) -> None:
if self._started:
return
self._started = True
try:
await asyncio.to_thread(queries.reconcile_live_downloads)
await asyncio.to_thread(self._sweep_orphan_temp_files)
except Exception as e:
logging.warning("[model_downloader] startup reconcile failed: %s", e)
await self.pump()
@staticmethod
def _sweep_orphan_temp_files() -> None:
"""Remove ``.part`` files not referenced by a resumable download row.
Resumable partials are preserved; only truly orphaned temp files from
crashed runs are deleted. ``FAILED`` is included because
:meth:`DownloadManager.resume` explicitly permits resuming a
retry-exhausted failed row: deleting its partial here while the
per-segment offsets survive in the DB would make the next resume
preallocate a fresh sparse file, skip every "complete" segment, and
leave zero-filled holes that pass the size-only verification gate.
"""
live = {
row.temp_path
for row in queries.list_downloads()
if row.status
in (
DownloadStatus.QUEUED,
DownloadStatus.PAUSED,
DownloadStatus.FAILED,
)
}
for path in paths.iter_all_tmp_paths():
if path in live:
continue
try:
os.remove(path)
logging.info("[model_downloader] removed orphan temp file: %s", path)
except OSError as e:
logging.warning("[model_downloader] could not remove %s: %s", path, e)
# ----- admission -----
async def pump(self) -> None:
async with self._pump_lock:
slots = self.max_active - len(self._tasks)
if slots <= 0:
return
now = time.monotonic()
candidates = await asyncio.to_thread(queries.list_queued_downloads)
for row in candidates:
if slots <= 0:
break
if row.id in self._tasks:
continue
if self._backoff_until.get(row.id, 0.0) > now:
continue
self._admit(row)
slots -= 1
def _admit(self, row) -> None:
spec = JobSpec(
download_id=row.id,
url=row.url,
model_id=row.model_id,
dest_path=row.dest_path,
temp_path=row.temp_path,
priority=row.priority,
credential_id=row.credential_id,
expected_sha256=row.expected_sha256,
allow_any_extension=row.allow_any_extension,
etag=row.etag,
attempts=row.attempts,
)
job = DownloadJob(spec, notify_cb=self._notify_cb)
self._jobs[row.id] = job
self._tasks[row.id] = asyncio.ensure_future(self._run_job(job))
async def _run_job(self, job: DownloadJob) -> None:
download_id = job.spec.download_id
status = DownloadStatus.FAILED
try:
status = await job.run()
except Exception as e: # run() is defensive, but never let a task die silently
logging.error("[model_downloader] job %s crashed: %s", download_id, e)
queries.update_download(
download_id,
status=DownloadStatus.FAILED,
error=f"internal error: {e}",
)
if self._notify_cb:
self._notify_cb(download_id)
finally:
self._tasks.pop(download_id, None)
self._jobs.pop(download_id, None)
if status == DownloadStatus.QUEUED:
if job.spec.attempts >= _MAX_ATTEMPTS:
queries.update_download(
download_id,
status=DownloadStatus.FAILED,
error=f"giving up after {job.spec.attempts} attempts",
)
if self._notify_cb:
self._notify_cb(download_id)
else:
delay = min(
_BACKOFF_CAP, _BACKOFF_BASE ** job.spec.attempts
) + random.uniform(0, 1.0)
self._backoff_until[download_id] = time.monotonic() + delay
asyncio.ensure_future(self._delayed_pump(delay))
await self.pump()
async def _delayed_pump(self, delay: float) -> None:
await asyncio.sleep(delay)
await self.pump()
SCHEDULER = Scheduler()
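
The retry delay computed in `_run_job` is a capped exponential plus jitter. A minimal sketch of that arithmetic, using the constants from this file and a hypothetical helper name `backoff_delay`, with the jitter passed in explicitly so the result is deterministic:

```python
BACKOFF_BASE = 2.0
BACKOFF_CAP = 300.0
MAX_ATTEMPTS = 6


def backoff_delay(attempts: int, jitter: float = 0.0) -> float:
    """Capped exponential delay in seconds, plus caller-supplied jitter."""
    return min(BACKOFF_CAP, BACKOFF_BASE ** attempts) + jitter


# Delays for the attempts that are still retried (attempts 1..5).
schedule = [backoff_delay(n) for n in range(1, MAX_ATTEMPTS)]
capped = backoff_delay(9)
```

In the scheduler the jitter is `random.uniform(0, 1.0)`. With the constants shown, a job is abandoned once `attempts` reaches 6, so the largest exponent actually used is 5 (32 seconds); the 300-second cap would only come into play if the attempt limit were raised.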

@ -0,0 +1,140 @@
"""URL allowlist for server-side model fetches.
Default-deny. A URL is downloadable only when its parsed host + scheme are
allowlisted AND (unless explicitly relaxed) its final filename ends in a
known model extension.
The built-in host defaults mirror the frontend's ``isModelDownloadable``
allowlist so the two flows agree on what is eligible; ``--download-allowed-hosts``
extends it for self-hosted mirrors. Matching is done on ``urlparse().hostname``
(never a raw string prefix) so userinfo tricks like
``http://127.0.0.1@169.254.169.254/x.safetensors`` — whose real host is the
metadata IP — cannot slip past.
"""
from __future__ import annotations
from urllib.parse import urlparse
from comfy.cli_args import args
# host -> set of allowed schemes. Frontend parity (HuggingFace / Civitai /
# localhost). Extra hosts from --download-allowed-hosts are https-only.
_DEFAULT_ALLOWED_HOSTS: dict[str, set[str]] = {
"huggingface.co": {"https"},
"civitai.com": {"https"},
"localhost": {"http", "https"},
"127.0.0.1": {"http", "https"},
}
# Hosts for which loopback addresses are intentionally permitted (the localhost
# "download a local model" feature). Every other host's loopback resolution is
# rejected by the SSRF resolver.
LOOPBACK_HOSTS = frozenset({"localhost", "127.0.0.1", "::1"})
# Known model file extensions (frontend parity). Checked on the final filename.
ALLOWED_MODEL_EXTENSIONS = (
".safetensors",
".sft",
".ckpt",
".pth",
".pt",
".gguf",
".bin",
)
def _allowed_hosts() -> dict[str, set[str]]:
hosts = {h: set(s) for h, s in _DEFAULT_ALLOWED_HOSTS.items()}
for extra in getattr(args, "download_allowed_hosts", []) or []:
host = extra.strip().lower()
if host:
hosts.setdefault(host, set()).add("https")
return hosts
def is_host_allowed(host: str | None, scheme: str | None) -> bool:
"""True iff ``host`` is allowlisted for ``scheme``.
Used both for the initial URL and re-checked on every redirect hop,
so a whitelisted URL cannot 30x into an off-list host.
"""
if not host or not scheme:
return False
allowed = _allowed_hosts().get(host.lower())
return allowed is not None and scheme.lower() in allowed
def has_allowed_extension(path: str, allow_any_extension: bool = False) -> bool:
if allow_any_extension:
return True
return path.lower().endswith(ALLOWED_MODEL_EXTENSIONS)
def filename_extension(name: str) -> str:
"""Lowercased extension (including the leading dot) of a bare filename.
Returns ``""`` when there is no extension. A leading-dot name
(``.safetensors``) is treated as having no extension (all stem), matching
``os.path.splitext`` semantics so dotfiles aren't mistaken for typed files.
"""
base = name.replace("\\", "/").rsplit("/", 1)[-1]
dot = base.rfind(".")
if dot <= 0:
return ""
return base[dot:].lower()
def is_allowed_extension_name(name: str) -> bool:
"""True iff ``name`` ends in one of the known model extensions."""
return name.lower().endswith(ALLOWED_MODEL_EXTENSIONS)
def is_host_allowed_url(url: str) -> bool:
"""True iff ``url`` parses and its host+scheme are allowlisted."""
if not isinstance(url, str) or not url:
return False
try:
parsed = urlparse(url)
except ValueError:
return False
return is_host_allowed(parsed.hostname, parsed.scheme)
def url_path_extension(url: str) -> str:
"""Extension of the URL *path* basename (query ignored), or ``""``."""
try:
parsed = urlparse(url)
except ValueError:
return ""
return filename_extension(parsed.path)
def is_url_downloadable(url: str) -> bool:
"""Coarse enqueue gate: host/scheme allowed and extension not disallowed.
Unlike :func:`is_url_allowed` (which demands a known extension *in the URL*),
this also admits URLs whose path carries no extension at all — e.g. a Civitai
``/api/download/models/<id>`` endpoint whose real filename only shows up in
the redirect target / ``Content-Disposition``. The true extension is then
resolved from the network and re-validated before the download is admitted.
A path bearing an explicit *non-model* extension (``.zip``, ``.html``, ...)
is still rejected here.
"""
if not is_host_allowed_url(url):
return False
ext = url_path_extension(url)
return ext == "" or ext in ALLOWED_MODEL_EXTENSIONS
def is_url_allowed(url: str, allow_any_extension: bool = False) -> bool:
"""Check whether ``url`` is permitted as a server-side download source."""
if not isinstance(url, str) or not url:
return False
try:
parsed = urlparse(url)
except ValueError:
return False
if not is_host_allowed(parsed.hostname, parsed.scheme):
return False
return has_allowed_extension(parsed.path, allow_any_extension)
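
The userinfo trick mentioned in the module docstring is easy to reproduce. The sketch below contrasts a naive string-prefix check with matching on the parsed hostname; `ALLOWED`, `naive_is_allowed` and `parsed_is_allowed` are hypothetical names used only for this illustration.

```python
from urllib.parse import urlparse

# Hypothetical, trimmed-down allowlist: host -> allowed schemes.
ALLOWED = {"127.0.0.1": {"http", "https"}, "huggingface.co": {"https"}}


def naive_is_allowed(url: str) -> bool:
    """Unsafe: compares a raw string prefix instead of the parsed host."""
    return any(
        url.startswith(f"{scheme}://{host}")
        for host, schemes in ALLOWED.items()
        for scheme in schemes
    )


def parsed_is_allowed(url: str) -> bool:
    """Safe: matches on the host the request would actually connect to."""
    try:
        parsed = urlparse(url)
    except ValueError:
        return False
    host = (parsed.hostname or "").lower()
    return parsed.scheme.lower() in ALLOWED.get(host, set())


# "127.0.0.1" is only the userinfo part; the real host is the metadata IP.
tricky = "http://127.0.0.1@169.254.169.254/x.safetensors"
real_host = urlparse(tricky).hostname
```

The prefix check accepts `tricky`, while the parsed check rejects it because `urlparse` reports `169.254.169.254` as the hostname.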

@ -0,0 +1,132 @@
"""Path resolution + traversal safety for downloads.
A ``model_id`` is a *relative destination path* of the form
``<directory>/<filename>`` (e.g. ``loras/my_lora.safetensors``). This module
turns one into an absolute on-disk path under one of ComfyUI's registered
model folders, rejecting unknown folders, path traversal, and symlink escape.
This is the only thing that composes destination paths, so the engine never
touches user-supplied path strings directly.
"""
from __future__ import annotations
import os
import re
from typing import Iterator, Optional
import folder_paths
from app.model_downloader.constants import TMP_SUFFIX
from app.model_downloader.security.allowlist import ALLOWED_MODEL_EXTENSIONS
# A model_id component is a single path segment of safe characters: no slashes,
# no "..", no leading dot that could escape the target directory. ``\Z`` is used
# instead of ``$`` because ``$`` would also accept a trailing newline.
_SEGMENT_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*\Z")
class InvalidModelId(ValueError):
"""Raised when a model_id is malformed or names an unknown model folder."""
def parse_model_id(model_id: str, allow_any_extension: bool = False) -> tuple[str, str]:
"""Split ``<directory>/<filename>`` and validate both components.
Returns ``(directory, filename)``. Does not touch the filesystem.
"""
if not isinstance(model_id, str) or "/" not in model_id:
raise InvalidModelId(
f"model_id must be '<directory>/<filename>', got {model_id!r}"
)
directory, _, filename = model_id.partition("/")
if "/" in filename or not directory or not filename:
raise InvalidModelId(
f"model_id must have exactly one '/' separator, got {model_id!r}"
)
if not _SEGMENT_RE.match(directory):
raise InvalidModelId(f"invalid directory segment {directory!r}")
if not _SEGMENT_RE.match(filename):
raise InvalidModelId(f"invalid filename segment {filename!r}")
if not allow_any_extension and not filename.lower().endswith(
ALLOWED_MODEL_EXTENSIONS
):
raise InvalidModelId(
f"filename must end with a known model extension "
f"{ALLOWED_MODEL_EXTENSIONS}, got {filename!r}"
)
if directory not in folder_paths.folder_names_and_paths:
raise InvalidModelId(f"unknown model folder {directory!r}")
return directory, filename
def apply_extension(model_id: str, ext: str) -> str:
"""Return ``model_id`` with its filename forced to end in ``ext``.
``ext`` includes the leading dot (e.g. ``".safetensors"``). If the filename
already ends in a *known model extension* it is replaced; otherwise ``ext``
is appended (so ``loras/mymodel`` -> ``loras/mymodel.safetensors`` and
``loras/mymodel.ckpt`` -> ``loras/mymodel.safetensors``). A filename with a
non-model suffix (``my.model.v2``) is treated as an extensionless stem and
``ext`` is appended. The directory part is left untouched; validation is
still the caller's job via :func:`parse_model_id`.
"""
directory, sep, filename = model_id.partition("/")
if not sep:
return model_id # malformed; parse_model_id will reject it
low = filename.lower()
for known in ALLOWED_MODEL_EXTENSIONS:
if low.endswith(known):
filename = filename[: -len(known)]
break
return f"{directory}{sep}{filename}{ext}"
def resolve_existing(model_id: str, allow_any_extension: bool = False) -> Optional[str]:
"""Return the absolute path of an installed model, or None if missing.
Honours ``extra_model_paths.yaml`` transparently via ``get_full_path``.
"""
directory, filename = parse_model_id(model_id, allow_any_extension)
return folder_paths.get_full_path(directory, filename)
def resolve_destination(
model_id: str, allow_any_extension: bool = False
) -> tuple[str, str]:
"""Return ``(final_path, temp_path)`` for a download.
Downloads land at the first registered path for the model's directory
(the "primary" location). ``temp_path`` is a sibling ``.part`` file that
is atomically renamed onto ``final_path`` on success. The result is
asserted to stay within the registered root (defence in depth on top of
the segment regex).
"""
directory, filename = parse_model_id(model_id, allow_any_extension)
roots = folder_paths.get_folder_paths(directory)
if not roots:
raise InvalidModelId(f"no on-disk path registered for folder {directory!r}")
root = os.path.realpath(roots[0])
final_path = os.path.realpath(os.path.join(root, filename))
if final_path != root and not final_path.startswith(root + os.sep):
raise InvalidModelId(f"resolved path escapes model root: {model_id!r}")
temp_path = f"{final_path}{TMP_SUFFIX}"
return final_path, temp_path
def iter_all_tmp_paths() -> Iterator[str]:
"""Yield this subsystem's temp files under every registered model folder.
Matches only the distinctive ``TMP_SUFFIX`` so the startup orphan sweep
can never delete temp files created by other tools.
"""
seen_roots: set[str] = set()
for directory in list(folder_paths.folder_names_and_paths.keys()):
for root in folder_paths.get_folder_paths(directory):
if root in seen_roots or not os.path.isdir(root):
continue
seen_roots.add(root)
try:
for entry in os.scandir(root):
if entry.is_file() and entry.name.endswith(TMP_SUFFIX):
yield entry.path
except OSError:
continue
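
The containment check in `resolve_destination` can be tried on a temporary directory. This is a sketch with a hypothetical helper name, `is_within_root`; it shows why the comparison appends `os.sep` to the root before the prefix test, so that a sibling directory sharing the same name prefix is not mistaken for a child.

```python
import os
import tempfile


def is_within_root(root: str, candidate: str) -> bool:
    """True if candidate, after resolving symlinks and '..', stays under root."""
    root = os.path.realpath(root)
    resolved = os.path.realpath(candidate)
    # Appending os.sep stops "/models/loras_evil" from matching "/models/loras".
    return resolved == root or resolved.startswith(root + os.sep)


with tempfile.TemporaryDirectory() as base:
    root = os.path.join(base, "loras")
    os.mkdir(root)
    inside = is_within_root(root, os.path.join(root, "my_lora.safetensors"))
    escaped = is_within_root(root, os.path.join(root, "..", "secret.bin"))
    sibling = is_within_root(root, os.path.join(base, "loras_evil", "x.safetensors"))
```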

@ -0,0 +1,163 @@
"""SSRF / exfiltration defenses.
Two cooperating layers:
1. :class:`ValidatingResolver` is installed on the shared connector. Every
connection — the initial probe and every segment GET, including ones made
after a redirect — resolves its host through this resolver, which rejects
any address that lands on a private / special-use IP range. Because the
resolve and the connect happen together inside the connector, there is no
check-then-connect window for DNS rebinding to exploit.
2. :func:`check_redirect_hop` re-validates every hop. The host allowlist gates
only the *initial* user-supplied URL (anti-SSRF for arbitrary input);
legitimate downloads from allowlisted origins redirect to presigned CDN
hosts that are deliberately NOT on the allowlist (HF ->
``cdn-lfs*.huggingface.co``, Civitai -> signed Cloudflare/S3), so hops are
instead screened for scheme, embedded credentials, and — via the resolver
above — private IPs. Credentials are only ever attached when a hop's host
exactly matches a stored credential, so they are dropped on the CDN hop.
Loopback (the "download a local model" feature) is exempt from IP filtering
only for the initial URL: a *redirect* may never target a loopback host or
a blocked IP-literal, which the resolver alone can't enforce (it exempts
loopback literals and never sees IP literals through DNS).
"""
from __future__ import annotations
import ipaddress
import socket
from urllib.parse import urlparse
from aiohttp.abc import AbstractResolver
from aiohttp.resolver import DefaultResolver
from app.model_downloader.security.allowlist import LOOPBACK_HOSTS
# Maximum number of redirects followed for a single request.
MAX_REDIRECTS = 5
class SSRFError(Exception):
"""A hop failed an SSRF / allowlist check."""
def is_scheme_allowed(scheme: str | None, host: str | None) -> bool:
"""True iff ``scheme`` is permitted for ``host`` on a download hop.
https is always allowed; plain http only for loopback/approved dev hosts.
"""
if not scheme:
return False
scheme = scheme.lower()
if scheme == "https":
return True
if scheme == "http":
return bool(host) and host.lower() in LOOPBACK_HOSTS
return False
def is_blocked_ip(ip_str: str) -> bool:
"""True for any address we refuse to connect to.
Covers loopback, link-local (incl. 169.254.169.254 cloud metadata),
RFC1918 private ranges, unique-local (ULA), unspecified (0.0.0.0/::),
multicast and other reserved ranges.
"""
try:
ip = ipaddress.ip_address(ip_str)
except ValueError:
return True # unparseable -> refuse
# On CPython before the gh-113171 fix (backported to 3.12.4/3.11.9/
# 3.10.14/3.9.19) the is_* properties don't see through IPv4-mapped IPv6
# (e.g. ::ffff:169.254.169.254), so unwrap the embedded IPv4 and re-check it
# to keep mapped metadata/private addresses from slipping past the filter.
mapped = getattr(ip, "ipv4_mapped", None)
if mapped is not None:
ip = mapped
return (
ip.is_private
or ip.is_loopback
or ip.is_link_local
or ip.is_multicast
or ip.is_reserved
or ip.is_unspecified
)
class ValidatingResolver(AbstractResolver):
"""Delegating resolver that drops blocked IPs from every resolution.
If a hostname resolves only to blocked addresses, the connection fails
closed with an :class:`OSError`, which aiohttp surfaces as a connection
error to the caller.
"""
def __init__(self) -> None:
self._inner = DefaultResolver()
async def resolve(self, host, port=0, family=socket.AF_INET):
infos = await self._inner.resolve(host, port, family)
# localhost/127.0.0.1 are an explicit, opt-in allowlist feature.
if isinstance(host, str) and host.lower() in LOOPBACK_HOSTS:
return infos
safe = [info for info in infos if not is_blocked_ip(info["host"])]
if not safe:
raise OSError(
f"refusing to connect to {host!r}: resolves only to "
f"private/special-use addresses"
)
return safe
async def close(self) -> None:
await self._inner.close()


def check_redirect_hop(url: str, *, is_initial_url: bool = False) -> str:
    """Validate one hop's URL.

    Returns the URL unchanged on success; raises :class:`SSRFError` otherwise.
    Requires https for external hosts (http only for loopback/approved dev
    hosts) and forbids credentials-in-URL. The host is NOT re-checked against
    the allowlist (CDN redirect targets are off-list by design); credential
    leakage is prevented by exact host matching at attach time, and the landing
    filename's extension is gated separately by the caller.

    Loopback/blocked-IP screening: the connector's resolver filters resolvable
    hostnames but exempts literal loopback hosts (``localhost``/``127.0.0.1``/
    ``::1``) and never sees IP literals through DNS. That loopback exemption is
    legitimate only for the *initial* user-supplied URL (``is_initial_url``);
    on a redirect hop we reject loopback hosts and any blocked IP-literal here,
    so a 30x can't steer a server-side GET at loopback/internal services.
    """
    try:
        parsed = urlparse(url)
    except ValueError as e:
        raise SSRFError(f"unparseable redirect URL {url!r}: {e}") from e
    host = parsed.hostname
    if not host:
        raise SSRFError(f"redirect URL has no host: {url!r}")
    if not is_scheme_allowed(parsed.scheme, host):
        raise SSRFError(
            f"redirect to disallowed scheme {parsed.scheme!r} for host "
            f"{host!r} (https required for external hosts)"
        )
    if parsed.username or parsed.password:
        raise SSRFError("credentials-in-URL are not allowed")
    host_is_loopback = host.lower() in LOOPBACK_HOSTS
    if not is_initial_url and host_is_loopback:
        raise SSRFError(f"redirect to loopback host {host!r} is not allowed")
    # IP-literal targets never go through DNS, so the connector's resolver can't
    # screen them — check them directly. The only blocked IP allowed through is
    # a loopback literal on the initial URL (handled by the exemption above).
    try:
        ipaddress.ip_address(host)
    except ValueError:
        is_ip_literal = False
    else:
        is_ip_literal = True
    if is_ip_literal and is_blocked_ip(host) and not (
        is_initial_url and host_is_loopback
    ):
        raise SSRFError(f"redirect to blocked internal address {host!r}")
    return url
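
The per-hop rules described in the `check_redirect_hop` docstring can be exercised with a few concrete URLs. The following is a minimal, simplified sketch and not the code from this change: `HopRejected`, `is_internal_literal`, `check_hop` and `hop_allowed` are hypothetical names, the contents of `LOOPBACK_HOSTS` are assumed, and the scheme rule is reduced to "https everywhere, http only for loopback hosts" because `is_scheme_allowed` is defined outside this excerpt.

```python
import ipaddress
from urllib.parse import urlparse

# Assumed contents; the real LOOPBACK_HOSTS is defined outside this excerpt.
LOOPBACK_HOSTS = frozenset({"localhost", "127.0.0.1", "::1"})


class HopRejected(Exception):
    """Hypothetical stand-in for the SSRFError raised in the diff."""


def is_internal_literal(host: str) -> bool:
    """True if host is an IP literal in a private or special-use range."""
    try:
        ip = ipaddress.ip_address(host)
    except ValueError:
        return False  # a hostname; screening happens at DNS resolution time
    mapped = getattr(ip, "ipv4_mapped", None)
    if mapped is not None:
        ip = mapped  # classify the embedded IPv4 of ::ffff:a.b.c.d
    return (
        ip.is_private or ip.is_loopback or ip.is_link_local
        or ip.is_multicast or ip.is_reserved or ip.is_unspecified
    )


def check_hop(url: str, *, is_initial_url: bool = False) -> str:
    """Return url unchanged if the hop is acceptable, else raise HopRejected."""
    parsed = urlparse(url)
    host = parsed.hostname
    if not host:
        raise HopRejected("no host")
    loopback = host.lower() in LOOPBACK_HOSTS
    # Simplified scheme rule (assumption): http is tolerated only for loopback.
    if parsed.scheme != "https" and not (parsed.scheme == "http" and loopback):
        raise HopRejected("disallowed scheme")
    if parsed.username or parsed.password:
        raise HopRejected("credentials in URL")
    if loopback and not is_initial_url:
        raise HopRejected("redirect to loopback")
    if is_internal_literal(host) and not (is_initial_url and loopback):
        raise HopRejected("blocked internal address")
    return url


def hop_allowed(url: str, *, is_initial_url: bool = False) -> bool:
    """Boolean wrapper around check_hop, convenient for quick experiments."""
    try:
        check_hop(url, is_initial_url=is_initial_url)
    except HopRejected:
        return False
    return True
```

With this sketch, `hop_allowed("http://127.0.0.1:8188/model.safetensors", is_initial_url=True)` is accepted, while the same URL as a redirect target is rejected, which is the asymmetry the docstring calls out.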


@ -0,0 +1,49 @@
"""Hub-checksum verification = SHA256.
Only used to confirm a download matches a *provided* ``expected_sha256``. It
is NOT the dedup key (that is blake3, owned by the assets system). The full
sequential read happens at most once, here, only when a checksum was supplied.
"""
from __future__ import annotations

import hashlib
from typing import Callable, Optional

_CHUNK = 8 * 1024 * 1024

InterruptCheck = Callable[[], bool]


class ChecksumError(Exception):
    """The computed SHA256 did not match the expected value."""


def sha256_file(path: str, interrupt_check: Optional[InterruptCheck] = None) -> Optional[str]:
    """Stream the file and return its lowercase hex SHA256.

    Returns ``None`` if interrupted via ``interrupt_check``.
    """
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while True:
            if interrupt_check is not None and interrupt_check():
                return None
            chunk = f.read(_CHUNK)
            if not chunk:
                break
            h.update(chunk)
    return h.hexdigest()


def verify_sha256(
    path: str, expected: str, interrupt_check: Optional[InterruptCheck] = None
) -> None:
    """Raise :class:`ChecksumError` unless the file's SHA256 matches ``expected``."""
    actual = sha256_file(path, interrupt_check)
    if actual is None:
        return  # interrupted; caller will re-verify on resume
    if actual.lower() != expected.lower():
        raise ChecksumError(
            f"sha256 mismatch: expected {expected.lower()}, got {actual.lower()}"
        )
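
To illustrate the streaming pattern used by `sha256_file`, here is a small self-contained sketch. It assumes nothing from the rest of the change: `stream_sha256` and `demo` are hypothetical names, and the chunk size is shrunk so a tiny temp file still takes several reads. It shows that a chunked digest equals the one-shot digest and that an interrupt callback yields `None` rather than a partial hash.

```python
import hashlib
import os
import tempfile
from typing import Callable, Optional


def stream_sha256(
    path: str,
    chunk_size: int = 1024 * 1024,
    should_stop: Optional[Callable[[], bool]] = None,
) -> Optional[str]:
    """Hash a file chunk by chunk; return None if should_stop() fires."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while True:
            # Poll before every read so a cancel takes effect within one chunk.
            if should_stop is not None and should_stop():
                return None
            chunk = f.read(chunk_size)
            if not chunk:
                break
            h.update(chunk)
    return h.hexdigest()


def demo() -> tuple:
    """Return (chunked digest matches one-shot digest, interrupted result)."""
    payload = b"model-bytes" * 1000
    fd, path = tempfile.mkstemp()
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(payload)
        full = stream_sha256(path, chunk_size=4096)
        interrupted = stream_sha256(path, chunk_size=4096, should_stop=lambda: True)
    finally:
        os.remove(path)
    return full == hashlib.sha256(payload).hexdigest(), interrupted
```

Returning `None` on interruption keeps "not finished" distinguishable from "finished and wrong", which is why `verify_sha256` treats it as a no-op instead of a mismatch.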


@ -0,0 +1,53 @@
"""Dedup + catalog handoff — reuse the assets system.
We do NOT build a parallel indexer. "Do I already have it?" is answered by
``resolve_existing`` (path) at enqueue time and, where a hash is known, by the
assets blake3 catalog. After a completed download we register the file
through the assets ingest path so it is cataloged and (eventually) hashed by
the existing enrichment worker.
"""
from __future__ import annotations

import asyncio
import logging
import os
from typing import Optional


def _register_sync(abs_path: str) -> Optional[str]:
    """Register a finished file into the assets catalog. Returns asset hash."""
    try:
        from app.assets.services.ingest import register_file_in_place
    except Exception as e:  # assets package import failure — non-fatal
        logging.debug("[model_downloader] assets ingest unavailable: %s", e)
        return None
    try:
        result = register_file_in_place(abs_path, name=os.path.basename(abs_path), tags=[])
        return result.asset.hash if result and result.asset else None
    except Exception as e:
        # The file is already safely on disk; cataloging is best-effort.
        logging.warning(
            "[model_downloader] could not register %s into assets catalog: %s",
            abs_path, e,
        )
        return None


async def register_completed(abs_path: str) -> Optional[str]:
    """Catalog a completed download via the assets system (off the event loop)."""
    return await asyncio.to_thread(_register_sync, abs_path)


def _find_by_hash_sync(blake3_hex: str) -> Optional[str]:
    try:
        from app.assets.services.asset_management import get_asset_by_hash
    except Exception:
        return None
    asset = get_asset_by_hash("blake3:" + blake3_hex)
    return asset.hash if asset is not None else None


async def find_existing_by_hash(blake3_hex: str) -> Optional[str]:
    """Pure DB lookup — never triggers hashing on the hot path."""
    return await asyncio.to_thread(_find_by_hash_sync, blake3_hex)
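
Both async wrappers above push a blocking call onto a worker thread with `asyncio.to_thread`. A minimal sketch of that pattern, with a hypothetical `blocking_lookup` standing in for the catalog call, shows that the work runs off the event-loop thread:

```python
import asyncio
import threading
import time


def blocking_lookup(key: str) -> tuple:
    """Hypothetical blocking call (stands in for a DB or ingest operation)."""
    time.sleep(0.05)  # simulate I/O that would otherwise stall the event loop
    return key.upper(), threading.current_thread().name


async def lookup(key: str) -> tuple:
    # to_thread runs the callable in the default executor and awaits its result.
    return await asyncio.to_thread(blocking_lookup, key)


async def main() -> tuple:
    loop_thread = threading.current_thread().name
    value, worker_thread = await lookup("abc")
    # Report the value and whether the work ran on a different thread.
    return value, worker_thread != loop_thread
```

Running `asyncio.run(main())` returns the uppercased key together with a flag confirming the lookup executed on a worker thread, so other coroutines stay responsive while it blocks.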


@ -0,0 +1,86 @@
"""Cheap structural validation, no full read.
For ``.safetensors``/``.sft`` we parse the header (first few KB): it carries
the tensor table and the byte length of the data region. We assert
``file_size == 8 + header_len + data_region_len``. This detects truncation
and most corruption for free, before any crypto hashing. Other extensions
have no cheap structural check and pass through.
"""
from __future__ import annotations

import json
import os
import struct
from typing import Optional

_SAFETENSORS_EXTS = (".safetensors", ".sft")

# A sane upper bound so a corrupt header length can't make us read gigabytes.
_MAX_HEADER_BYTES = 100 * 1024 * 1024


class StructuralError(Exception):
    """The file failed its structural integrity check."""


def validate(path: str, name_hint: Optional[str] = None) -> None:
    """Validate the file at ``path``. Raises :class:`StructuralError` on failure.

    The file format is detected from ``name_hint`` when provided, otherwise from
    ``path``. Callers that download into a temp file with an opaque suffix (e.g.
    ``*.comfy-download.part``) must pass the final destination name as
    ``name_hint`` so the format check is not silently skipped.
    """
    lower = (name_hint or path).lower()
    if lower.endswith(_SAFETENSORS_EXTS):
        _validate_safetensors(path)
    # No structural check for other formats; the size + (optional) checksum
    # gates in the engine cover those.


def _validate_safetensors(path: str) -> None:
    file_size = os.path.getsize(path)
    if file_size < 8:
        raise StructuralError(f"file too small to be safetensors ({file_size} bytes)")
    with open(path, "rb") as f:
        header_len = struct.unpack("<Q", f.read(8))[0]
        if header_len <= 0 or header_len > _MAX_HEADER_BYTES:
            raise StructuralError(f"implausible safetensors header length {header_len}")
        if 8 + header_len > file_size:
            raise StructuralError("safetensors header extends past end of file")
        try:
            header = json.loads(f.read(header_len).decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError) as e:
            raise StructuralError(f"safetensors header is not valid JSON: {e}") from e
    if not isinstance(header, dict):
        raise StructuralError("safetensors header is not a JSON object")
    data_len = 0
    for name, entry in header.items():
        if name == "__metadata__":
            continue
        if not isinstance(entry, dict) or "data_offsets" not in entry:
            raise StructuralError(f"tensor {name!r} missing data_offsets")
        offsets = entry["data_offsets"]
        if not (isinstance(offsets, list) and len(offsets) == 2):
            raise StructuralError(f"tensor {name!r} has malformed data_offsets")
        begin, end = offsets
        # bool is an int subclass; reject it explicitly to avoid True/False offsets.
        if (
            not isinstance(begin, int)
            or not isinstance(end, int)
            or isinstance(begin, bool)
            or isinstance(end, bool)
            or begin < 0
            or end < begin
        ):
            raise StructuralError(f"tensor {name!r} has malformed data_offsets")
        data_len = max(data_len, end)
    expected = 8 + header_len + data_len
    if file_size != expected:
        raise StructuralError(
            f"size mismatch: file is {file_size} bytes, header implies {expected} "
            f"(8 + {header_len} header + {data_len} data)"
        )
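
The size identity `file_size == 8 + header_len + data_region_len` can be checked on a tiny in-memory example. This is an illustrative sketch only: `build_blob` and `implied_size` are hypothetical helpers, the blob is a minimal hand-built layout (8-byte little-endian header length, JSON header, raw data) rather than output from the safetensors library, and none of the validation from `_validate_safetensors` is repeated here.

```python
import json
import struct


def build_blob(tensors: dict) -> bytes:
    """Assemble a minimal safetensors-style blob from name -> raw bytes."""
    header = {}
    data = b""
    for name, payload in tensors.items():
        begin = len(data)
        data += payload
        header[name] = {
            "dtype": "U8",
            "shape": [len(payload)],
            "data_offsets": [begin, begin + len(payload)],
        }
    header_bytes = json.dumps(header).encode("utf-8")
    # Layout: u64 little-endian header length, JSON header, then tensor data.
    return struct.pack("<Q", len(header_bytes)) + header_bytes + data


def implied_size(blob: bytes) -> int:
    """Total size the header claims: 8 + header_len + largest end offset."""
    (header_len,) = struct.unpack("<Q", blob[:8])
    header = json.loads(blob[8:8 + header_len].decode("utf-8"))
    data_len = max(
        (entry["data_offsets"][1] for name, entry in header.items()
         if name != "__metadata__"),
        default=0,
    )
    return 8 + header_len + data_len
```

For an intact blob `implied_size(blob) == len(blob)`; dropping even a few trailing bytes breaks the equality, which is how truncation is caught without reading the tensor data.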


@ -50,45 +50,21 @@ class ModelFileManager:
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
async def get_model_preview(request):
folder_name = request.match_info.get("folder", None)
path_index = int(request.match_info.get("path_index", None))
filename = request.match_info.get("filename", None)
if folder_name not in folder_paths.folder_names_and_paths:
return web.Response(status=404)
# The "{filename:.*}" capture also matches the empty string, which
# would resolve to the folder itself; reject it explicitly.
if not filename:
return web.Response(status=400)
try:
path_index = int(request.match_info.get("path_index", None))
except (TypeError, ValueError):
return web.Response(status=400)
folders = folder_paths.folder_names_and_paths[folder_name]
if path_index < 0 or path_index >= len(folders[0]):
return web.Response(status=404)
folder = folders[0][path_index]
full_filename = os.path.normpath(os.path.join(folder, filename))
# Prevent path traversal: the requested file must stay within the
# configured model folder. `filename` is an unrestricted ".*" capture,
# so values like "../../../../etc/passwd" would otherwise escape it.
if not folder_paths.is_within_directory(folder, full_filename):
return web.Response(status=403)
full_filename = os.path.join(folder, filename)
previews = self.get_model_previews(full_filename)
default_preview = previews[0] if len(previews) > 0 else None
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
return web.Response(status=404)
# The preview is selected by a glob inside get_model_previews, so a
# companion file (e.g. "model.preview.png") could itself be a symlink
# resolving outside the model folder. Re-validate the file actually
# opened: is_within_directory realpaths it, catching symlink escape.
if isinstance(default_preview, str) and not folder_paths.is_within_directory(folder, default_preview):
return web.Response(status=403)
try:
with Image.open(default_preview) as img:
img_bytes = BytesIO()
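
The comments in this hunk rely on a containment check that resolves symlinks before comparing paths. The sketch below shows one common way to write such a check; it is an assumption-laden illustration, not the actual `folder_paths.is_within_directory` implementation, and `is_within` and `demo` are hypothetical names.

```python
import os
import tempfile


def is_within(directory: str, target: str) -> bool:
    """True if target, after resolving symlinks, lies inside directory."""
    root = os.path.realpath(directory)
    resolved = os.path.realpath(target)
    try:
        # commonpath compares whole path components, so "/models-evil"
        # is not mistaken for a child of "/models".
        return os.path.commonpath([root, resolved]) == root
    except ValueError:
        # Raised e.g. on Windows when the paths are on different drives.
        return False


def demo() -> tuple:
    """Check an in-folder file and a traversal attempt against a temp folder."""
    with tempfile.TemporaryDirectory() as folder:
        inside = os.path.join(folder, "sub", "model.preview.png")
        escape = os.path.normpath(
            os.path.join(folder, "..", "..", "etc", "passwd")
        )
        return is_within(folder, inside), is_within(folder, escape)
```

Resolving with `realpath` on both sides is what lets the same check cover `../` traversal and a companion file that is itself a symlink pointing outside the folder.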


@ -6,7 +6,6 @@ import glob
import shutil
import logging
import tempfile
import mimetypes
from aiohttp import web
from urllib import parse
from comfy.cli_args import args
@ -337,20 +336,7 @@ class UserManager():
if not isinstance(path, str):
return path
# User data files are arbitrary user-supplied content and are never
# meant to render inline. Disable MIME sniffing and force a download
# so uploaded markup/scripts can't execute in the app origin (stored
# XSS). Content-Disposition: attachment is the load-bearing guard;
# the content-type override and nosniff are defence in depth.
content_type = mimetypes.guess_type(path)[0] or 'application/octet-stream'
if folder_paths.is_dangerous_content_type(content_type):
content_type = 'application/octet-stream'
return web.FileResponse(path, headers={
"Content-Type": content_type,
"X-Content-Type-Options": "nosniff",
"Content-Disposition": "attachment",
})
return web.FileResponse(path)
@routes.post("/userdata/{file}")
async def post_userdata(request):


@ -33,6 +33,28 @@ class EnumAction(argparse.Action):
setattr(namespace, self.dest, value)


def _positive_int(value: str) -> int:
    """argparse type that rejects zero and negative integers."""
    try:
        ivalue = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"{value!r} is not an integer")
    if ivalue <= 0:
        raise argparse.ArgumentTypeError(f"{value!r} must be a positive integer (> 0)")
    return ivalue


def _non_negative_int(value: str) -> int:
    """argparse type that rejects negatives but allows zero (a disable sentinel)."""
    try:
        ivalue = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"{value!r} is not an integer")
    if ivalue < 0:
        raise argparse.ArgumentTypeError(f"{value!r} must be a non-negative integer (>= 0)")
    return ivalue

parser = argparse.ArgumentParser()
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
@ -240,10 +262,18 @@ database_default_path = os.path.abspath(
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
parser.add_argument("--enable-asset-hashing", action="store_true", help="Compute blake3 content hashes when scanning assets. Hashing enables future asset-portability features (deduplication, cross-machine model resolution) but adds startup cost and per-output cost on large models directories. Off by default; enable to opt in.")
parser.add_argument("--feature-flag", type=str, action='append', default=[], metavar="KEY[=VALUE]", help="Set a server feature flag. Use KEY=VALUE to set an explicit value, or bare KEY to set it to true. Can be specified multiple times. Boolean values (true/false) and numbers are auto-converted. Examples: --feature-flag show_signin_button=true or --feature-flag show_signin_button")
parser.add_argument("--list-feature-flags", action="store_true", help="Print the registry of known CLI-settable feature flags as JSON and exit.")
# ----- Model download manager (PRD: docs/prd-download-manager.md) -----
parser.add_argument("--download-segments", type=_positive_int, default=8, metavar="N", help="Number of parallel HTTP range segments per file for the model download manager (default: 8).")
parser.add_argument("--download-max-active", type=_positive_int, default=3, metavar="N", help="Maximum number of model downloads running concurrently (default: 3).")
parser.add_argument("--download-max-connections-per-host", type=_positive_int, default=16, metavar="N", help="Maximum simultaneous connections to a single host for the download manager (default: 16).")
parser.add_argument("--download-chunk-size", type=_positive_int, default=4 * 1024 * 1024, metavar="BYTES", help="Read chunk size in bytes for the download manager (default: 4 MiB).")
parser.add_argument("--download-max-bytes", type=_non_negative_int, default=1024 * 1024 * 1024 * 1024, metavar="BYTES", help="Maximum size in bytes of a single download; aborts transfers that exceed it (guards against malicious/non-conforming hosts filling the disk). Set to 0 to disable (default: 1 TiB).")
parser.add_argument("--download-allowed-hosts", type=str, nargs="*", default=[], metavar="HOST", help="Additional hostnames to add to the download manager allowlist (https only). The built-in defaults always include huggingface.co and civitai.com.")
parser.add_argument("--download-allow-any-extension", action="store_true", help="Allow the download manager to fetch files with any extension (default: only known model extensions like .safetensors).")
if comfy.options.args_parsing:
args = parser.parse_args()
else:
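
As a quick illustration of how a custom argparse `type=` callable like `_positive_int` behaves, the following standalone sketch wires a similar validator into a throwaway parser. The names `positive_int`, `build_parser`, `rejects` and the `--segments` option are hypothetical and only mirror the shape of the options added above.

```python
import argparse


def positive_int(value: str) -> int:
    """argparse type callable: accept only integers greater than zero."""
    try:
        ivalue = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"{value!r} is not an integer") from None
    if ivalue <= 0:
        raise argparse.ArgumentTypeError(f"{value!r} must be a positive integer (> 0)")
    return ivalue


def build_parser() -> argparse.ArgumentParser:
    """Throwaway parser with one validated option (hypothetical name)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--segments", type=positive_int, default=8, metavar="N")
    return parser


def rejects(value: str) -> bool:
    """True if the validator refuses the given raw command-line string."""
    try:
        positive_int(value)
    except argparse.ArgumentTypeError:
        return True
    return False
```

When the validator raises `ArgumentTypeError` during `parse_args`, argparse reports the message as a usage error and exits, so invalid values such as `0` or `abc` never reach the download manager.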


@ -167,7 +167,7 @@ class Qwen3VLTokenizer(sd1_clip.SD1Tokenizer):
embed_count = 0
for r in tokens[key_name]:
for i in range(len(r)):
if isinstance(r[i][0], (int, float)) and r[i][0] == 151655: # <|image_pad|>
if r[i][0] == 151655: # <|image_pad|>
if len(images) > embed_count:
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
embed_count += 1


@ -104,6 +104,7 @@ _CORE_FEATURE_FLAGS: dict[str, Any] = {
"extension": {"manager": {"supports_v4": True}},
"node_replacements": True,
"assets": args.enable_assets,
"server_side_model_downloads": True,
}
# CLI-provided flags cannot overwrite core flags


@ -1,4 +1,4 @@
from typing import Any, Literal
from typing import Literal
from pydantic import BaseModel, Field
@ -316,36 +316,3 @@ VIDEO_TASKS_EXECUTION_TIME = {
"1080p": 150,
},
}
class SeedAudioConfig(BaseModel):
format: str = Field(default="mp3")
sample_rate: int = Field(default=24000)
speech_rate: int = Field(default=0)
loudness_rate: int = Field(default=0)
pitch_rate: int = Field(default=0)
class SeedAudioReference(BaseModel):
speaker: str | None = Field(default=None)
audio_data: str | None = Field(default=None)
audio_url: str | None = Field(default=None)
image_data: str | None = Field(default=None)
image_url: str | None = Field(default=None)
class SeedAudioRequest(BaseModel):
model: str = Field(default="seed-audio-1.0")
text_prompt: str = Field(...)
references: list[SeedAudioReference] | None = Field(default=None)
audio_config: SeedAudioConfig = Field(default_factory=SeedAudioConfig)
watermark: dict[str, Any] = Field(default_factory=dict)
class SeedAudioResponse(BaseModel):
audio: str | None = Field(default=None)
url: str | None = Field(default=None)
duration: float | None = Field(default=None)
original_duration: float | None = Field(default=None)
code: int | None = Field(default=None)
message: str | None = Field(default=None)


@ -121,7 +121,6 @@ class GeminiGenerationConfig(BaseModel):
topK: int | None = Field(None, ge=1)
topP: float | None = Field(None, ge=0.0, le=1.0)
thinkingConfig: GeminiThinkingConfig | None = Field(None)
responseModalities: list[str] | None = Field(None)
class GeminiImageOutputOptions(BaseModel):


@ -33,6 +33,53 @@ class IdeogramColorPalette(
)
class ImageRequest(BaseModel):
aspect_ratio: Optional[str] = Field(
None,
description="Optional. The aspect ratio (e.g., 'ASPECT_16_9', 'ASPECT_1_1'). Cannot be used with resolution. Defaults to 'ASPECT_1_1' if unspecified.",
)
color_palette: Optional[Dict[str, Any]] = Field(
None, description='Optional. Color palette object. Only for V_2, V_2_TURBO.'
)
magic_prompt_option: Optional[str] = Field(
None, description="Optional. MagicPrompt usage ('AUTO', 'ON', 'OFF')."
)
model: str = Field(..., description="The model used (e.g., 'V_2', 'V_2A_TURBO')")
negative_prompt: Optional[str] = Field(
None,
description='Optional. Description of what to exclude. Only for V_1, V_1_TURBO, V_2, V_2_TURBO.',
)
num_images: Optional[int] = Field(
1,
description='Optional. Number of images to generate (1-8). Defaults to 1.',
ge=1,
le=8,
)
prompt: str = Field(
..., description='Required. The prompt to use to generate the image.'
)
resolution: Optional[str] = Field(
None,
description="Optional. Resolution (e.g., 'RESOLUTION_1024_1024'). Only for model V_2. Cannot be used with aspect_ratio.",
)
seed: Optional[int] = Field(
None,
description='Optional. A number between 0 and 2147483647.',
ge=0,
le=2147483647,
)
style_type: Optional[str] = Field(
None,
description="Optional. Style type ('AUTO', 'GENERAL', 'REALISTIC', 'DESIGN', 'RENDER_3D', 'ANIME'). Only for models V_2 and above.",
)
class IdeogramGenerateRequest(BaseModel):
image_request: ImageRequest = Field(
..., description='The image generation request parameters.'
)
class Datum(BaseModel):
is_image_safe: Optional[bool] = Field(
None, description='Indicates whether the image is considered safe.'
@ -66,6 +113,20 @@ class StyleCode(RootModel[str]):
root: str = Field(..., pattern='^[0-9A-Fa-f]{8}$')
class Datum1(BaseModel):
is_image_safe: Optional[bool] = None
prompt: Optional[str] = None
resolution: Optional[str] = None
seed: Optional[int] = None
style_type: Optional[str] = None
url: Optional[str] = None
class IdeogramV3IdeogramResponse(BaseModel):
created: Optional[datetime] = None
data: Optional[List[Datum1]] = None
class RenderingSpeed1(str, Enum):
TURBO = 'TURBO'
DEFAULT = 'DEFAULT'


@ -0,0 +1,147 @@
from enum import Enum
from typing import Optional

from pydantic import BaseModel, Field, confloat


class StabilityFormat(str, Enum):
    png = 'png'
    jpeg = 'jpeg'
    webp = 'webp'


class StabilityAspectRatio(str, Enum):
    ratio_1_1 = "1:1"
    ratio_16_9 = "16:9"
    ratio_9_16 = "9:16"
    ratio_3_2 = "3:2"
    ratio_2_3 = "2:3"
    ratio_5_4 = "5:4"
    ratio_4_5 = "4:5"
    ratio_21_9 = "21:9"
    ratio_9_21 = "9:21"


def get_stability_style_presets(include_none=True):
    presets = []
    if include_none:
        presets.append("None")
    return presets + [x.value for x in StabilityStylePreset]


class StabilityStylePreset(str, Enum):
    _3d_model = "3d-model"
    analog_film = "analog-film"
    anime = "anime"
    cinematic = "cinematic"
    comic_book = "comic-book"
    digital_art = "digital-art"
    enhance = "enhance"
    fantasy_art = "fantasy-art"
    isometric = "isometric"
    line_art = "line-art"
    low_poly = "low-poly"
    modeling_compound = "modeling-compound"
    neon_punk = "neon-punk"
    origami = "origami"
    photographic = "photographic"
    pixel_art = "pixel-art"
    tile_texture = "tile-texture"


class Stability_SD3_5_Model(str, Enum):
    sd3_5_large = "sd3.5-large"
    # sd3_5_large_turbo = "sd3.5-large-turbo"
    sd3_5_medium = "sd3.5-medium"


class Stability_SD3_5_GenerationMode(str, Enum):
    text_to_image = "text-to-image"
    image_to_image = "image-to-image"


class StabilityStable3_5Request(BaseModel):
    model: str = Field(...)
    mode: str = Field(...)
    prompt: str = Field(...)
    negative_prompt: Optional[str] = Field(None)
    aspect_ratio: Optional[str] = Field(None)
    seed: Optional[int] = Field(None)
    output_format: Optional[str] = Field(StabilityFormat.png.value)
    image: Optional[str] = Field(None)
    style_preset: Optional[str] = Field(None)
    cfg_scale: float = Field(...)
    strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None)


class StabilityUpscaleConservativeRequest(BaseModel):
    prompt: str = Field(...)
    negative_prompt: Optional[str] = Field(None)
    seed: Optional[int] = Field(None)
    output_format: Optional[str] = Field(StabilityFormat.png.value)
    image: Optional[str] = Field(None)
    creativity: Optional[confloat(ge=0.2, le=0.5)] = Field(None)


class StabilityUpscaleCreativeRequest(BaseModel):
    prompt: str = Field(...)
    negative_prompt: Optional[str] = Field(None)
    seed: Optional[int] = Field(None)
    output_format: Optional[str] = Field(StabilityFormat.png.value)
    image: Optional[str] = Field(None)
    creativity: Optional[confloat(ge=0.1, le=0.5)] = Field(None)
    style_preset: Optional[str] = Field(None)


class StabilityStableUltraRequest(BaseModel):
    prompt: str = Field(...)
    negative_prompt: Optional[str] = Field(None)
    aspect_ratio: Optional[str] = Field(None)
    seed: Optional[int] = Field(None)
    output_format: Optional[str] = Field(StabilityFormat.png.value)
    image: Optional[str] = Field(None)
    style_preset: Optional[str] = Field(None)
    strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None)


class StabilityStableUltraResponse(BaseModel):
    image: Optional[str] = Field(None)
    finish_reason: Optional[str] = Field(None)
    seed: Optional[int] = Field(None)


class StabilityResultsGetResponse(BaseModel):
    image: Optional[str] = Field(None)
    finish_reason: Optional[str] = Field(None)
    seed: Optional[int] = Field(None)
    id: Optional[str] = Field(None)
    name: Optional[str] = Field(None)
    errors: Optional[list[str]] = Field(None)
    status: Optional[str] = Field(None)
    result: Optional[str] = Field(None)


class StabilityAsyncResponse(BaseModel):
    id: Optional[str] = Field(None)


class StabilityTextToAudioRequest(BaseModel):
    model: str = Field(...)
    prompt: str = Field(...)
    duration: int = Field(190, ge=1, le=190)
    seed: int = Field(0, ge=0, le=4294967294)
    steps: int = Field(8, ge=4, le=8)
    output_format: str = Field("wav")


class StabilityAudioToAudioRequest(StabilityTextToAudioRequest):
    strength: float = Field(0.01, ge=0.01, le=1.0)


class StabilityAudioInpaintRequest(StabilityTextToAudioRequest):
    mask_start: int = Field(30, ge=0, le=190)
    mask_end: int = Field(190, ge=0, le=190)


class StabilityAudioResponse(BaseModel):
    audio: Optional[str] = Field(None)


@ -1,4 +1,3 @@
import base64
import hashlib
import logging
import math
@ -21,10 +20,6 @@ from comfy_api_nodes.apis.bytedance import (
GetAssetResponse,
Image2VideoTaskCreationRequest,
ImageTaskCreationResponse,
SeedAudioConfig,
SeedAudioReference,
SeedAudioRequest,
SeedAudioResponse,
Seedance2TaskCreationRequest,
SeedanceCreateAssetRequest,
SeedanceCreateAssetResponse,
@ -48,8 +43,6 @@ from comfy_api_nodes.apis.bytedance import (
)
from comfy_api_nodes.util import (
ApiEndpoint,
audio_bytes_to_audio_input,
audio_input_to_mp3,
download_url_to_image_tensor,
download_url_to_video_output,
downscale_image_tensor_by_max_side,
@ -58,14 +51,11 @@ from comfy_api_nodes.util import (
image_tensor_pair_to_batch,
poll_op,
sync_op,
tensor_to_base64_string,
upload_audio_to_comfyapi,
upload_image_to_comfyapi,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
upscale_image_tensor_to_min_pixels,
upscale_video_to_min_pixels,
validate_audio_duration,
validate_image_aspect_ratio,
validate_image_dimensions,
validate_string,
@ -2484,311 +2474,6 @@ class ByteDanceCreateVideoAsset(IO.ComfyNode):
return IO.NodeOutput(asset_id, resolved_group)
MODE_TEXT = "text only"
MODE_AUDIO = "audio reference"
MODE_IMAGE = "image reference"
MODE_SPEAKER = "preset voice"
# (speaker_id, display_label) for built-in TTS 2.0 voices; resolvable ids are account-scoped.
SEED_AUDIO_PRESET_VOICES: list[tuple[str, str]] = [
("zh_female_vv_uranus_bigtts", "Vivi (Female, multilingual)"),
("zh_female_xiaohe_uranus_bigtts", "Mindy (Female, multilingual)"),
("en_female_stokie_uranus_bigtts", "Stokie (Female, English)"),
("en_female_dacey_uranus_bigtts", "Dacey (Female, English)"),
("en_male_tim_uranus_bigtts", "Tim (Male, English)"),
("zh_male_m191_uranus_bigtts", "Kian (Male, multilingual)"),
("zh_male_taocheng_uranus_bigtts", "Cedric (Male, multilingual)"),
("zh_male_sophie_uranus_bigtts", "Sophie (Female, multilingual)"),
("zh_female_yingyujiaoxue_uranus_bigtts", "Jean (Female, multilingual)"),
("zh_male_dayi_uranus_bigtts", "Magnus (Male, multilingual)"),
("zh_female_mizai_uranus_bigtts", "Mabel (Female, multilingual)"),
("zh_female_jitangnv_uranus_bigtts", "Nadia (Female, multilingual)"),
("zh_female_meilinvyou_uranus_bigtts", "Opal (Female, multilingual)"),
("zh_female_liuchangnv_uranus_bigtts", "Pearl (Female, multilingual)"),
("zh_male_ruyayichen_uranus_bigtts", "Quentin (Male, multilingual)"),
("zh_female_vivo_uranus_bigtts", "Vienna (Female, multilingual)"),
("zh_female_xiaoai_uranus_bigtts", "Alina (Female, multilingual)"),
("zh_female_cancan_uranus_bigtts", "Corinne (Female, multilingual)"),
("zh_female_tianmeixiaoyuan_uranus_bigtts", "Esther (Female, multilingual)"),
("zh_female_tianmeitaozi_uranus_bigtts", "Freya (Female, multilingual)"),
("zh_female_shuangkuaisisi_uranus_bigtts", "Gigi (Female, multilingual)"),
("zh_female_peiqi_uranus_bigtts", "Holly (Female, multilingual)"),
("zh_female_xiaoxue_uranus_bigtts", "Lyla (Female, multilingual)"),
("zh_female_yuanqi_uranus_bigtts", "Daisy (Female, multilingual)"),
("zh_female_kefunvsheng_uranus_bigtts", "Tracy (Female, multilingual)"),
("zh_male_shaonianzixin_uranus_bigtts", "Jess (Male, multilingual)"),
("zh_female_linjianvhai_uranus_bigtts", "Pinky (Female, multilingual)"),
("zh_female_kiwi_uranus_bigtts", "Sweety (Female, multilingual)"),
("zh_female_sajiaoxuemei_uranus_bigtts", "Sandy (Female, multilingual)"),
("de_male_seven_uranus_bigtts", "Sven (Male, German)"),
("jp_female_minimi_uranus_bigtts", "Minimi (Female, Japanese)"),
("fr_male_usseau_uranus_bigtts", "Usseau (Male, French)"),
("es_male_felipe_uranus_bigtts", "Felipe (Male, Spanish)"),
("id_male_han_uranus_bigtts", "Han (Male, Indonesian)"),
("pt_male_martins_uranus_bigtts", "Martins (Male, Portuguese)"),
("it_male_enzo_uranus_bigtts", "Enzo (Male, Italian)"),
("kr_male_shane_uranus_bigtts", "Shane (Male, Korean)"),
("zh_male_liufei_uranus_bigtts", "Felix (Male, Chinese)"),
("zh_female_qingxinnvsheng_uranus_bigtts", "Celeste (Female, Chinese)"),
("zh_male_sunwukong_uranus_bigtts", "Monkey King (Male, Chinese)"),
]
SEED_AUDIO_VOICE_OPTIONS = [label for _, label in SEED_AUDIO_PRESET_VOICES]
SEED_AUDIO_VOICE_MAP = {label: speaker_id for speaker_id, label in SEED_AUDIO_PRESET_VOICES}
_AUDIO_TAG_RE = re.compile(r"@Audio(\d+)", re.IGNORECASE)
def max_audio_tag(prompt: str) -> int:
"""Highest N referenced as @AudioN in the prompt (0 if none)."""
nums = [int(m) for m in _AUDIO_TAG_RE.findall(prompt or "")]
return max(nums) if nums else 0
def connected_audio_indices(reference_mode: dict) -> list[int]:
"""Indices (1-based) of connected reference_audio sockets, in order."""
return [
i
for i in range(1, 3 + 1)
if reference_mode.get(f"reference_audio_{i}") is not None
]
def validate_seed_audio_inputs(
text_prompt: str,
mode: str,
audio_indices: list[int],
has_image: bool,
preset_voice: str | None = None,
) -> None:
validate_string(text_prompt, field_name="text_prompt", min_length=1, max_length=3000)
max_tag = max_audio_tag(text_prompt)
if mode == MODE_TEXT:
if max_tag:
raise ValueError(
f"The prompt references @Audio{max_tag}, but reference mode is '{MODE_TEXT}'. "
f"Switch to '{MODE_AUDIO}' and connect the reference clip(s)."
)
elif mode == MODE_AUDIO:
if not audio_indices:
raise ValueError(
f"Reference mode '{MODE_AUDIO}' requires at least one reference_audio input "
f"(or switch to '{MODE_TEXT}')."
)
if audio_indices != list(range(1, len(audio_indices) + 1)):
raise ValueError(
"Connect reference_audio inputs in order without gaps: reference_audio_1, then _2, then _3."
)
if max_tag > len(audio_indices):
raise ValueError(
f"The prompt references @Audio{max_tag}, but only {len(audio_indices)} "
f"reference audio(s) are connected."
)
elif mode == MODE_IMAGE:
if not has_image:
raise ValueError(f"Reference mode '{MODE_IMAGE}' requires a reference_image input.")
if max_tag:
raise ValueError(
f"@AudioN tags are not used in '{MODE_IMAGE}' mode; the prompt should contain "
f"only the text to synthesize."
)
elif mode == MODE_SPEAKER:
if not preset_voice or preset_voice not in SEED_AUDIO_VOICE_MAP:
raise ValueError(f"Reference mode '{MODE_SPEAKER}' requires selecting a preset voice.")
if max_tag > 1:
raise ValueError(
f"'{MODE_SPEAKER}' mode uses a single voice, so @Audio{max_tag} is out of range. "
f"Remove the @AudioN tags — the whole prompt is read in the selected voice."
)
else:
raise ValueError(f"Unknown reference mode: {mode!r}")
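
The two small rules that `validate_seed_audio_inputs` builds on, finding the highest `@AudioN` tag and requiring gap-free reference slots, can be tried in isolation. This is a minimal sketch with hypothetical names (`highest_audio_tag`, `is_gapless`); the regular expression is copied from `_AUDIO_TAG_RE` above.

```python
import re

# Same pattern as _AUDIO_TAG_RE in the removed code: case-insensitive @Audio<N>.
AUDIO_TAG_RE = re.compile(r"@Audio(\d+)", re.IGNORECASE)


def highest_audio_tag(prompt: str) -> int:
    """Largest N referenced as @AudioN in the prompt, or 0 if there is none."""
    numbers = [int(n) for n in AUDIO_TAG_RE.findall(prompt or "")]
    return max(numbers) if numbers else 0


def is_gapless(indices: list) -> bool:
    """True if indices are exactly 1..len(indices) with no gaps, in order."""
    return indices == list(range(1, len(indices) + 1))
```

Comparing `highest_audio_tag(prompt)` with the number of connected clips is enough to detect a prompt that mentions `@Audio3` when only two references are attached, and `is_gapless([1, 3])` is false because slot 2 was skipped.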
class ByteDanceSeedAudioNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="ByteDanceSeedAudio",
display_name="ByteDance Seed Audio 1.0",
category="partner/audio/ByteDance",
description=(
"Generate speech, music, sound effects and multi-speaker dialogue from a single prompt "
"with ByteDance Seed Audio 1.0. Describe the voice(s), emotion, ambience, background music "
"and sound effects in the prompt, and include the lines to speak. Optionally pick a built-in "
"preset voice, clone voices from up to 3 reference clips (tagged @Audio1-3 in the prompt), "
"or derive a voice from a character image. Up to 2 minutes of audio per run."
),
inputs=[
IO.String.Input(
"text_prompt",
multiline=True,
default="",
tooltip=(
"Describe the voice(s), emotion, pacing, ambience, background music and sound "
"effects, and include the lines to speak (name characters inline for dialogue). "
"In 'audio reference' mode, refer to connected clips by order as @Audio1, @Audio2, "
"@Audio3. Maximum 3000 characters."
),
),
IO.DynamicCombo.Input(
"reference_mode",
options=[
IO.DynamicCombo.Option(MODE_TEXT, []),
IO.DynamicCombo.Option(
MODE_AUDIO,
[
IO.Audio.Input(
"reference_audio_1",
optional=True,
tooltip="Reference clip for voice cloning, tagged @Audio1 in the prompt. "
"Up to 30s.",
),
IO.Audio.Input(
"reference_audio_2",
optional=True,
tooltip="Reference clip tagged @Audio2 in the prompt. Up to 30s.",
),
IO.Audio.Input(
"reference_audio_3",
optional=True,
tooltip="Reference clip tagged @Audio3 in the prompt. Up to 30s.",
),
],
),
IO.DynamicCombo.Option(
MODE_IMAGE,
[
IO.Image.Input(
"reference_image",
optional=True,
tooltip="A single character image; the model derives a voice from it. "
"Cannot be combined with reference audio.",
),
],
),
IO.DynamicCombo.Option(
MODE_SPEAKER,
[
IO.Combo.Input(
"preset_voice",
options=SEED_AUDIO_VOICE_OPTIONS,
default=SEED_AUDIO_VOICE_OPTIONS[0],
tooltip="A built-in TTS 2.0 voice that reads the prompt. No reference "
"clip needed, and @AudioN tags are not used in this mode.",
),
],
),
],
tooltip=(
"How to condition the voice: 'text only' (describe everything in the prompt), "
"'audio reference' (clone up to 3 voices, tagged @Audio1-3), 'image reference' "
"(derive a voice from one character image), or 'preset voice' (pick a built-in "
"named voice that reads the prompt)."
),
),
IO.Combo.Input(
"sample_rate",
options=["8000", "16000", "24000", "32000", "44100", "48000"],
default="24000",
tooltip="Output sample rate in Hz.",
),
IO.Int.Input(
"speech_rate",
default=0,
min=-50,
max=100,
tooltip="Speaking speed. 0 = normal, 100 = 2.0x, -50 = 0.5x.",
),
IO.Int.Input(
"loudness_rate",
default=0,
min=-50,
max=100,
tooltip="Loudness. 0 = normal, 100 = 2.0x, -50 = 0.5x.",
),
IO.Int.Input(
"pitch_rate",
default=0,
min=-12,
max=12,
tooltip="Pitch shift in semitones (-12 to 12).",
),
IO.Int.Input(
"seed",
default=42,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Audio.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd": 0.2145, "format":{"suffix":"/minute","approximate":true}}""",
),
)
@classmethod
async def execute(
cls,
text_prompt: str,
reference_mode: dict,
sample_rate: str,
speech_rate: int,
loudness_rate: int,
pitch_rate: int,
seed: int,
) -> IO.NodeOutput:
mode = reference_mode["reference_mode"]
audio_indices = connected_audio_indices(reference_mode)
image = reference_mode.get("reference_image")
preset_voice = reference_mode.get("preset_voice")
validate_seed_audio_inputs(text_prompt, mode, audio_indices, image is not None, preset_voice)
references: list[SeedAudioReference] | None = None
if mode == MODE_AUDIO:
references = []
for i in audio_indices:
clip = reference_mode[f"reference_audio_{i}"]
validate_audio_duration(clip, max_duration=30.0)
mp3_bytes = audio_input_to_mp3(clip).getvalue()
references.append(SeedAudioReference(audio_data=base64.b64encode(mp3_bytes).decode("utf-8")))
elif mode == MODE_IMAGE:
image = upscale_image_tensor_to_min_pixels(image, 160_000)
references = [SeedAudioReference(image_data=tensor_to_base64_string(image, mime_type="image/png"))]
elif mode == MODE_SPEAKER:
references = [SeedAudioReference(speaker=SEED_AUDIO_VOICE_MAP[preset_voice])]
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/byteplus/api/v3/tts/create", method="POST"),
response_model=SeedAudioResponse,
data=SeedAudioRequest(
text_prompt=text_prompt,
references=references,
audio_config=SeedAudioConfig(
sample_rate=int(sample_rate),
speech_rate=speech_rate,
loudness_rate=loudness_rate,
pitch_rate=pitch_rate,
),
),
)
if not response.audio:
raise Exception(
f"Seed Audio returned no audio (code={response.code}): {response.message}"
)
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response.audio)))
class ByteDanceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -2805,7 +2490,6 @@ class ByteDanceExtension(ComfyExtension):
ByteDance2ReferenceNode,
ByteDanceCreateImageAsset,
ByteDanceCreateVideoAsset,
ByteDanceSeedAudioNode,
]
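Note on the `speech_rate` and `loudness_rate` tooltips in the removed node: they give three anchor points (0 = normal, 100 = 2.0x, -50 = 0.5x). A minimal sketch of one mapping consistent with those points, assuming the scale is linear between them. The diff does not state this, and `rate_to_multiplier` is a hypothetical helper, not part of the node:

```python
def rate_to_multiplier(rate: int) -> float:
    """Map a rate widget value in [-50, 100] to a multiplier.

    Assumption: the mapping is linear, which matches the three documented
    anchor points (-50 -> 0.5x, 0 -> 1.0x, 100 -> 2.0x).
    """
    if not -50 <= rate <= 100:
        raise ValueError(f"rate must be between -50 and 100, got {rate}")
    return 1.0 + rate / 100.0


if __name__ == "__main__":
    for value in (-50, 0, 100):
        print(value, rate_to_multiplier(value))
```

The node itself passes the raw integer to the API, so any such conversion would happen on the service side.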

@ -13,7 +13,7 @@ import torch
from typing_extensions import override
import folder_paths
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl, Types
from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api_nodes.apis.gemini import (
GeminiContent,
GeminiFileData,
@ -37,7 +37,6 @@ from comfy_api_nodes.util import (
audio_to_base64_string,
bytesio_to_image_tensor,
download_url_to_image_tensor,
download_url_to_video_output,
get_number_of_images,
sync_op,
tensor_to_base64_string,
@ -46,7 +45,6 @@ from comfy_api_nodes.util import (
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_string,
validate_video_duration,
video_to_base64_string,
)
@ -231,29 +229,10 @@ async def get_image_from_response(response: GeminiGenerateContentResponse, thoug
return torch.cat(image_tensors, dim=0)
async def get_video_from_response(
response: GeminiGenerateContentResponse, cls: type[IO.ComfyNode] | None = None
) -> InputImpl.VideoFromFile:
parts = get_parts_by_type(response, "video/*")
for part in parts:
if part.inlineData and part.inlineData.data:
return InputImpl.VideoFromFile(BytesIO(base64.b64decode(part.inlineData.data)))
if part.fileData and part.fileData.fileUri:
return await download_url_to_video_output(part.fileData.fileUri, cls=cls)
model_message = get_text_from_response(response).strip()
if model_message:
raise ValueError(f"Gemini did not generate a video. Model response: {model_message}")
raise ValueError(
"Gemini did not generate a video. Try rephrasing your prompt, "
"shortening the requested duration, or reducing the number of input images/videos."
)
def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | None:
if not response.modelVersion:
return None
# Define prices (Cost per 1,000,000 tokens), see https://cloud.google.com/vertex-ai/generative-ai/pricing
output_video_tokens_price = 0.0
if response.modelVersion == "gemini-2.5-pro":
input_tokens_price = 1.25
output_text_tokens_price = 10.0
@ -286,11 +265,6 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
input_tokens_price = 0.25
output_text_tokens_price = 1.50
output_image_tokens_price = 30.0
elif response.modelVersion == "gemini-omni-flash-preview":
input_tokens_price = 2.145
output_text_tokens_price = 12.87
output_image_tokens_price = 0.0
output_video_tokens_price = 25.025
else:
return None
final_price = response.usageMetadata.promptTokenCount * input_tokens_price
@ -298,8 +272,6 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
for i in response.usageMetadata.candidatesTokensDetails:
if i.modality == Modality.IMAGE:
final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models
elif i.modality == Modality.VIDEO:
final_price += output_video_tokens_price * i.tokenCount # for Omni Flash
else:
final_price += output_text_tokens_price * i.tokenCount
if response.usageMetadata.thoughtsTokenCount:
@ -1559,149 +1531,6 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
)
OMNI_MAX_IMAGES = 14
OMNI_MAX_VIDEOS = 3
OMNI_MODELS: dict[str, str] = {
"Omni Flash": "gemini-omni-flash-preview",
}
def _omni_flash_inputs() -> list[Input]:
"""Per-model inputs for the Omni video DynamicCombo (prompt + reference media + sampling)."""
return [
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Describe the video to generate. Specify the length and aspect ratio directly in the "
'prompt, e.g. "a 6-second clip in 16:9". Length may be 3-10 seconds; the aspect ratio must be '
"16:9 (landscape) or 9:16 (portrait). The output is 720p, 24 FPS, with audio.",
),
IO.Autogrow.Input(
"images",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("image"),
names=[f"image_{i}" for i in range(1, OMNI_MAX_IMAGES + 1)],
min=0,
),
tooltip=f"Optional reference image(s) to guide or animate the video. Up to {OMNI_MAX_IMAGES} images.",
),
IO.Autogrow.Input(
"videos",
template=IO.Autogrow.TemplateNames(
IO.Video.Input("video"),
names=[f"video_{i}" for i in range(1, OMNI_MAX_VIDEOS + 1)],
min=0,
),
tooltip=f"Optional reference video(s) to guide or edit. Up to {OMNI_MAX_VIDEOS} videos, "
f"each up to 10 seconds long.",
),
IO.Float.Input(
"temperature",
default=1.0,
min=0.0,
max=2.0,
step=0.01,
tooltip="Controls randomness. Lower is more focused/deterministic, higher is more varied.",
advanced=True,
),
IO.Float.Input(
"top_p",
default=0.95,
min=0.0,
max=1.0,
step=0.01,
tooltip="Nucleus sampling: sample from the smallest token set whose cumulative probability reaches top_p.",
advanced=True,
),
]
class GeminiVideoOmni(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GeminiVideoOmni",
display_name="Google Gemini Omni (Video)",
category="partner/video/Gemini",
essentials_category="Video Generation",
description="Generate a video with audio from a text prompt using Google's Gemini Omni Flash model. "
"Optionally provide reference images and/or videos to guide or edit the result. Describe the desired "
"length (3-10s) and aspect ratio (16:9 or 9:16) directly in the prompt.",
inputs=[
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option("Omni Flash", _omni_flash_inputs()),
],
tooltip="The Gemini video model used to generate the video.",
),
IO.Int.Input(
"seed",
default=42,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[
IO.Video.Output(),
IO.String.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr='{"type":"usd","usd":0.146,"format":{"suffix":"/second","approximate":true}}'
),
)
@classmethod
async def execute(cls, model: dict, seed: int) -> IO.NodeOutput:
prompt = model.get("prompt") or ""
validate_string(prompt, strip_whitespace=True, min_length=1)
model_id = OMNI_MODELS[model["model"]]
images = [t for t in (model.get("images") or {}).values() if t is not None]
videos = [v for v in (model.get("videos") or {}).values() if v is not None]
if sum(get_number_of_images(t) for t in images) > OMNI_MAX_IMAGES:
raise ValueError(f"The current maximum number of supported images is {OMNI_MAX_IMAGES}.")
if len(videos) > OMNI_MAX_VIDEOS:
raise ValueError(f"The current maximum number of supported videos is {OMNI_MAX_VIDEOS}.")
for video in videos:
validate_video_duration(video, max_duration=10)
parts: list[GeminiPart] = []
if images or videos:
parts.extend(await build_gemini_media_parts(cls, images, [], videos))
parts.append(GeminiPart(text=prompt))
response = await sync_op(
cls,
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model_id}", method="POST"),
data=GeminiGenerateContentRequest(
contents=[GeminiContent(role=GeminiRole.user, parts=parts)],
generationConfig=GeminiGenerationConfig(
responseModalities=["TEXT", "VIDEO"],
temperature=model.get("temperature", 1.0),
topP=model.get("top_p", 0.95),
),
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(
await get_video_from_response(response, cls=cls),
get_text_from_response(response),
)
class GeminiExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -1712,7 +1541,6 @@ class GeminiExtension(ComfyExtension):
GeminiImage2,
GeminiNanoBanana2,
GeminiNanoBanana2V2,
GeminiVideoOmni,
GeminiInputFiles,
]
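Note on the pricing branch removed from `calculate_tokens_price`: the per-modality accounting can be sketched in isolation as below. The rates are copied from the removed `gemini-omni-flash-preview` branch. The final division by 1,000,000 is an assumption based on the "Cost per 1,000,000 tokens" comment, since that part of the function is not visible in the hunk. `ModalityTokens` and `estimate_price_usd` are hypothetical names, and thought tokens are left out:

```python
from dataclasses import dataclass

# USD per 1,000,000 tokens, taken from the removed branch in the diff.
INPUT_RATE = 2.145
OUTPUT_RATES = {"TEXT": 12.87, "IMAGE": 0.0, "VIDEO": 25.025}


@dataclass
class ModalityTokens:
    modality: str
    token_count: int


def estimate_price_usd(prompt_tokens: int, candidates: list[ModalityTokens]) -> float:
    """Sum input and per-modality output costs, then scale to USD."""
    total = prompt_tokens * INPUT_RATE
    for item in candidates:
        # Modalities other than IMAGE/VIDEO use the text rate, like the
        # `else` branch in the diff.
        rate = OUTPUT_RATES.get(item.modality, OUTPUT_RATES["TEXT"])
        total += rate * item.token_count
    # Assumption: rates are per million tokens, so scale down at the end.
    return total / 1_000_000
```

For instance, `estimate_price_usd(1_000, [ModalityTokens("VIDEO", 10_000)])` combines 1,000 input tokens with 10,000 video output tokens.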

@ -5,7 +5,9 @@ from PIL import Image
import numpy as np
import torch
from comfy_api_nodes.apis.ideogram import (
IdeogramGenerateRequest,
IdeogramGenerateResponse,
ImageRequest,
IdeogramV3Request,
IdeogramV3EditRequest,
IdeogramV4Request,
@ -19,6 +21,101 @@ from comfy_api_nodes.util import (
validate_string,
)
V1_V1_RES_MAP = {
"Auto":"AUTO",
"512 x 1536":"RESOLUTION_512_1536",
"576 x 1408":"RESOLUTION_576_1408",
"576 x 1472":"RESOLUTION_576_1472",
"576 x 1536":"RESOLUTION_576_1536",
"640 x 1024":"RESOLUTION_640_1024",
"640 x 1344":"RESOLUTION_640_1344",
"640 x 1408":"RESOLUTION_640_1408",
"640 x 1472":"RESOLUTION_640_1472",
"640 x 1536":"RESOLUTION_640_1536",
"704 x 1152":"RESOLUTION_704_1152",
"704 x 1216":"RESOLUTION_704_1216",
"704 x 1280":"RESOLUTION_704_1280",
"704 x 1344":"RESOLUTION_704_1344",
"704 x 1408":"RESOLUTION_704_1408",
"704 x 1472":"RESOLUTION_704_1472",
"720 x 1280":"RESOLUTION_720_1280",
"736 x 1312":"RESOLUTION_736_1312",
"768 x 1024":"RESOLUTION_768_1024",
"768 x 1088":"RESOLUTION_768_1088",
"768 x 1152":"RESOLUTION_768_1152",
"768 x 1216":"RESOLUTION_768_1216",
"768 x 1232":"RESOLUTION_768_1232",
"768 x 1280":"RESOLUTION_768_1280",
"768 x 1344":"RESOLUTION_768_1344",
"832 x 960":"RESOLUTION_832_960",
"832 x 1024":"RESOLUTION_832_1024",
"832 x 1088":"RESOLUTION_832_1088",
"832 x 1152":"RESOLUTION_832_1152",
"832 x 1216":"RESOLUTION_832_1216",
"832 x 1248":"RESOLUTION_832_1248",
"864 x 1152":"RESOLUTION_864_1152",
"896 x 960":"RESOLUTION_896_960",
"896 x 1024":"RESOLUTION_896_1024",
"896 x 1088":"RESOLUTION_896_1088",
"896 x 1120":"RESOLUTION_896_1120",
"896 x 1152":"RESOLUTION_896_1152",
"960 x 832":"RESOLUTION_960_832",
"960 x 896":"RESOLUTION_960_896",
"960 x 1024":"RESOLUTION_960_1024",
"960 x 1088":"RESOLUTION_960_1088",
"1024 x 640":"RESOLUTION_1024_640",
"1024 x 768":"RESOLUTION_1024_768",
"1024 x 832":"RESOLUTION_1024_832",
"1024 x 896":"RESOLUTION_1024_896",
"1024 x 960":"RESOLUTION_1024_960",
"1024 x 1024":"RESOLUTION_1024_1024",
"1088 x 768":"RESOLUTION_1088_768",
"1088 x 832":"RESOLUTION_1088_832",
"1088 x 896":"RESOLUTION_1088_896",
"1088 x 960":"RESOLUTION_1088_960",
"1120 x 896":"RESOLUTION_1120_896",
"1152 x 704":"RESOLUTION_1152_704",
"1152 x 768":"RESOLUTION_1152_768",
"1152 x 832":"RESOLUTION_1152_832",
"1152 x 864":"RESOLUTION_1152_864",
"1152 x 896":"RESOLUTION_1152_896",
"1216 x 704":"RESOLUTION_1216_704",
"1216 x 768":"RESOLUTION_1216_768",
"1216 x 832":"RESOLUTION_1216_832",
"1232 x 768":"RESOLUTION_1232_768",
"1248 x 832":"RESOLUTION_1248_832",
"1280 x 704":"RESOLUTION_1280_704",
"1280 x 720":"RESOLUTION_1280_720",
"1280 x 768":"RESOLUTION_1280_768",
"1280 x 800":"RESOLUTION_1280_800",
"1312 x 736":"RESOLUTION_1312_736",
"1344 x 640":"RESOLUTION_1344_640",
"1344 x 704":"RESOLUTION_1344_704",
"1344 x 768":"RESOLUTION_1344_768",
"1408 x 576":"RESOLUTION_1408_576",
"1408 x 640":"RESOLUTION_1408_640",
"1408 x 704":"RESOLUTION_1408_704",
"1472 x 576":"RESOLUTION_1472_576",
"1472 x 640":"RESOLUTION_1472_640",
"1472 x 704":"RESOLUTION_1472_704",
"1536 x 512":"RESOLUTION_1536_512",
"1536 x 576":"RESOLUTION_1536_576",
"1536 x 640":"RESOLUTION_1536_640",
}
V1_V2_RATIO_MAP = {
"1:1":"ASPECT_1_1",
"4:3":"ASPECT_4_3",
"3:4":"ASPECT_3_4",
"16:9":"ASPECT_16_9",
"9:16":"ASPECT_9_16",
"2:1":"ASPECT_2_1",
"1:2":"ASPECT_1_2",
"3:2":"ASPECT_3_2",
"2:3":"ASPECT_2_3",
"4:5":"ASPECT_4_5",
"5:4":"ASPECT_5_4",
}
V3_RATIO_MAP = {
"1:3":"1x3",
@ -132,6 +229,298 @@ async def download_and_process_images(image_urls):
return stacked_tensors
class IdeogramV1(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="IdeogramV1",
display_name="Ideogram V1",
category="partner/image/Ideogram",
description="Generates images using the Ideogram V1 model.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
IO.Boolean.Input(
"turbo",
default=False,
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
),
IO.Combo.Input(
"aspect_ratio",
options=list(V1_V2_RATIO_MAP.keys()),
default="1:1",
tooltip="The aspect ratio for image generation.",
optional=True,
),
IO.Combo.Input(
"magic_prompt_option",
options=["AUTO", "ON", "OFF"],
default="AUTO",
tooltip="Determine if MagicPrompt should be used in generation",
optional=True,
advanced=True,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
control_after_generate=True,
display_mode=IO.NumberDisplay.number,
optional=True,
),
IO.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Description of what to exclude from the image",
optional=True,
),
IO.Int.Input(
"num_images",
default=1,
min=1,
max=8,
step=1,
display_mode=IO.NumberDisplay.number,
optional=True,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]),
expr="""
(
$n := widgets.num_images;
$base := (widgets.turbo = true) ? 0.0286 : 0.0858;
{"type":"usd","usd": $round($base * $n, 2)}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt,
turbo=False,
aspect_ratio="1:1",
magic_prompt_option="AUTO",
seed=0,
negative_prompt="",
num_images=1,
):
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
# Determine the model based on turbo setting
model = "V_1_TURBO" if turbo else "V_1"
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
response_model=IdeogramGenerateResponse,
data=IdeogramGenerateRequest(
image_request=ImageRequest(
prompt=prompt,
model=model,
num_images=num_images,
seed=seed,
aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
negative_prompt=negative_prompt if negative_prompt else None,
)
),
max_retries=1,
)
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
return IO.NodeOutput(await download_and_process_images(image_urls))
class IdeogramV2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="IdeogramV2",
display_name="Ideogram V2",
category="partner/image/Ideogram",
description="Generates images using the Ideogram V2 model.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
IO.Boolean.Input(
"turbo",
default=False,
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
),
IO.Combo.Input(
"aspect_ratio",
options=list(V1_V2_RATIO_MAP.keys()),
default="1:1",
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
optional=True,
),
IO.Combo.Input(
"resolution",
options=list(V1_V1_RES_MAP.keys()),
default="Auto",
tooltip="The resolution for image generation. "
"If not set to AUTO, this overrides the aspect_ratio setting.",
optional=True,
),
IO.Combo.Input(
"magic_prompt_option",
options=["AUTO", "ON", "OFF"],
default="AUTO",
tooltip="Determine if MagicPrompt should be used in generation",
optional=True,
advanced=True,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
control_after_generate=True,
display_mode=IO.NumberDisplay.number,
optional=True,
),
IO.Combo.Input(
"style_type",
options=["NONE", "AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
default="NONE",
tooltip="Style type for generation (V2 only)",
optional=True,
advanced=True,
),
IO.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Description of what to exclude from the image",
optional=True,
),
IO.Int.Input(
"num_images",
default=1,
min=1,
max=8,
step=1,
display_mode=IO.NumberDisplay.number,
optional=True,
),
#"color_palette": (
# IO.STRING,
# {
# "multiline": False,
# "default": "",
# "tooltip": "Color palette preset name or hex colors with weights",
# },
#),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]),
expr="""
(
$n := widgets.num_images;
$base := (widgets.turbo = true) ? 0.0715 : 0.1144;
{"type":"usd","usd": $round($base * $n, 2)}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt,
turbo=False,
aspect_ratio="1:1",
resolution="Auto",
magic_prompt_option="AUTO",
seed=0,
style_type="NONE",
negative_prompt="",
num_images=1,
color_palette="",
):
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
resolution = V1_V1_RES_MAP.get(resolution, None)
# Determine the model based on turbo setting
model = "V_2_TURBO" if turbo else "V_2"
# Handle resolution vs aspect_ratio logic
# If resolution is not AUTO, it overrides aspect_ratio
final_resolution = None
final_aspect_ratio = None
if resolution != "AUTO":
final_resolution = resolution
else:
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
response_model=IdeogramGenerateResponse,
data=IdeogramGenerateRequest(
image_request=ImageRequest(
prompt=prompt,
model=model,
num_images=num_images,
seed=seed,
aspect_ratio=final_aspect_ratio,
resolution=final_resolution,
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
style_type=style_type if style_type != "NONE" else None,
negative_prompt=negative_prompt if negative_prompt else None,
color_palette=color_palette if color_palette else None,
)
),
max_retries=1,
)
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
return IO.NodeOutput(await download_and_process_images(image_urls))
class IdeogramV3(IO.ComfyNode):
@classmethod
@ -528,6 +917,8 @@ class IdeogramExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
IdeogramV1,
IdeogramV2,
IdeogramV3,
IdeogramV4,
]
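Note on `IdeogramV2.execute`: the resolution-versus-aspect-ratio precedence can be sketched on its own as below. The two maps are small excerpts of `V1_V2_RATIO_MAP` and `V1_V1_RES_MAP` from the diff, and `resolve_size` is a hypothetical helper used only for illustration:

```python
from typing import Optional

# Small excerpts of the lookup tables shown in the diff.
RATIO_MAP = {"1:1": "ASPECT_1_1", "16:9": "ASPECT_16_9"}
RES_MAP = {"Auto": "AUTO", "1024 x 1024": "RESOLUTION_1024_1024"}


def resolve_size(aspect_ratio: str, resolution: str) -> tuple[Optional[str], Optional[str]]:
    """Return (aspect_ratio, resolution) request values; at most one is set."""
    api_ratio = RATIO_MAP.get(aspect_ratio)
    api_resolution = RES_MAP.get(resolution)
    if api_resolution != "AUTO":
        # An explicit resolution wins and the aspect ratio is dropped.
        return None, api_resolution
    # With resolution on Auto, 1:1 is omitted (sent as None), mirroring the diff.
    return (api_ratio if api_ratio != "ASPECT_1_1" else None), None
```

As in the node, a resolution label missing from the map yields `None` for both fields, which leaves the choice to the API defaults.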

@ -0,0 +1,932 @@
from inspect import cleandoc
from typing import Optional
from typing_extensions import override
from comfy_api.latest import ComfyExtension, Input, IO
from comfy_api_nodes.apis.stability import (
StabilityUpscaleConservativeRequest,
StabilityUpscaleCreativeRequest,
StabilityAsyncResponse,
StabilityResultsGetResponse,
StabilityStable3_5Request,
StabilityStableUltraRequest,
StabilityStableUltraResponse,
StabilityAspectRatio,
Stability_SD3_5_Model,
Stability_SD3_5_GenerationMode,
get_stability_style_presets,
StabilityTextToAudioRequest,
StabilityAudioToAudioRequest,
StabilityAudioInpaintRequest,
StabilityAudioResponse,
)
from comfy_api_nodes.util import (
validate_audio_duration,
validate_string,
audio_input_to_mp3,
bytesio_to_image_tensor,
tensor_to_bytesio,
audio_bytes_to_audio_input,
sync_op,
poll_op,
ApiEndpoint,
)
import torch
import base64
from io import BytesIO
from enum import Enum
class StabilityPollStatus(str, Enum):
finished = "finished"
in_progress = "in_progress"
failed = "failed"
def get_async_dummy_status(x: StabilityResultsGetResponse):
if x.name is not None or x.errors is not None:
return StabilityPollStatus.failed
elif x.finish_reason is not None:
return StabilityPollStatus.finished
return StabilityPollStatus.in_progress
class StabilityStableImageUltraNode(IO.ComfyNode):
"""
Generates images synchronously based on prompt and resolution.
"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityStableImageUltraNode",
display_name="Stability AI Stable Image Ultra",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines " +
"elements, colors, and subjects will lead to better results. " +
"To control the weight of a given word use the format `(word:weight)`, " +
"where `word` is the word you'd like to control the weight of and `weight` " +
"is a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)` " +
"would convey a sky that was blue and green, but more green than blue.",
),
IO.Combo.Input(
"aspect_ratio",
options=StabilityAspectRatio,
default=StabilityAspectRatio.ratio_1_1,
tooltip="Aspect ratio of generated image.",
),
IO.Combo.Input(
"style_preset",
options=get_stability_style_presets(),
tooltip="Optional desired style of generated image.",
advanced=True,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
IO.Image.Input(
"image",
optional=True,
),
IO.String.Input(
"negative_prompt",
default="",
tooltip="A blurb of text describing what you do not wish to see in the output image. This is an advanced feature.",
force_input=True,
optional=True,
advanced=True,
),
IO.Float.Input(
"image_denoise",
default=0.5,
min=0.0,
max=1.0,
step=0.01,
tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
optional=True,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.08}""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
aspect_ratio: str,
style_preset: str,
seed: int,
image: Optional[torch.Tensor] = None,
negative_prompt: str = "",
image_denoise: Optional[float] = 0.5,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
# prepare image binary if image present
image_binary = None
if image is not None:
image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
else:
image_denoise = None
if not negative_prompt:
negative_prompt = None
if style_preset == "None":
style_preset = None
files = {
"image": image_binary
}
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"),
response_model=StabilityStableUltraResponse,
data=StabilityStableUltraRequest(
prompt=prompt,
negative_prompt=negative_prompt,
aspect_ratio=aspect_ratio,
seed=seed,
strength=image_denoise,
style_preset=style_preset,
),
files=files,
content_type="multipart/form-data",
)
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return IO.NodeOutput(returned_image)
class StabilityStableImageSD_3_5Node(IO.ComfyNode):
"""
Generates images synchronously based on prompt and resolution.
"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityStableImageSD_3_5Node",
display_name="Stability AI Stable Diffusion 3.5 Image",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
),
IO.Combo.Input(
"model",
options=Stability_SD3_5_Model,
),
IO.Combo.Input(
"aspect_ratio",
options=StabilityAspectRatio,
default=StabilityAspectRatio.ratio_1_1,
tooltip="Aspect ratio of generated image.",
),
IO.Combo.Input(
"style_preset",
options=get_stability_style_presets(),
tooltip="Optional desired style of generated image.",
advanced=True,
),
IO.Float.Input(
"cfg_scale",
default=4.0,
min=1.0,
max=10.0,
step=0.1,
tooltip="How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
IO.Image.Input(
"image",
optional=True,
),
IO.String.Input(
"negative_prompt",
default="",
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
force_input=True,
optional=True,
advanced=True,
),
IO.Float.Input(
"image_denoise",
default=0.5,
min=0.0,
max=1.0,
step=0.01,
tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
optional=True,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$contains(widgets.model,"large")
? {"type":"usd","usd":0.065}
: {"type":"usd","usd":0.035}
)
""",
),
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
aspect_ratio: str,
style_preset: str,
seed: int,
cfg_scale: float,
image: Optional[torch.Tensor] = None,
negative_prompt: str = "",
image_denoise: Optional[float] = 0.5,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
# prepare image binary if image present
image_binary = None
mode = Stability_SD3_5_GenerationMode.text_to_image
if image is not None:
image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
mode = Stability_SD3_5_GenerationMode.image_to_image
aspect_ratio = None
else:
image_denoise = None
if not negative_prompt:
negative_prompt = None
if style_preset == "None":
style_preset = None
files = {
"image": image_binary
}
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"),
response_model=StabilityStableUltraResponse,
data=StabilityStable3_5Request(
prompt=prompt,
negative_prompt=negative_prompt,
aspect_ratio=aspect_ratio,
seed=seed,
strength=image_denoise,
style_preset=style_preset,
cfg_scale=cfg_scale,
model=model,
mode=mode,
),
files=files,
content_type="multipart/form-data",
)
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return IO.NodeOutput(returned_image)
class StabilityUpscaleConservativeNode(IO.ComfyNode):
"""
Upscale image with minimal alterations to 4K resolution.
"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityUpscaleConservativeNode",
display_name="Stability AI Upscale Conservative",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("image"),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
),
IO.Float.Input(
"creativity",
default=0.35,
min=0.2,
max=0.5,
step=0.01,
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
IO.String.Input(
"negative_prompt",
default="",
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
force_input=True,
optional=True,
advanced=True,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.4}""",
),
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
prompt: str,
creativity: float,
seed: int,
negative_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
if not negative_prompt:
negative_prompt = None
files = {
"image": image_binary
}
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"),
response_model=StabilityStableUltraResponse,
data=StabilityUpscaleConservativeRequest(
prompt=prompt,
negative_prompt=negative_prompt,
creativity=round(creativity,2),
seed=seed,
),
files=files,
content_type="multipart/form-data",
)
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return IO.NodeOutput(returned_image)
class StabilityUpscaleCreativeNode(IO.ComfyNode):
"""
Upscale image with minimal alterations to 4K resolution.
"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityUpscaleCreativeNode",
display_name="Stability AI Upscale Creative",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("image"),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
),
IO.Float.Input(
"creativity",
default=0.3,
min=0.1,
max=0.5,
step=0.01,
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
),
IO.Combo.Input(
"style_preset",
options=get_stability_style_presets(),
tooltip="Optional desired style of generated image.",
advanced=True,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
IO.String.Input(
"negative_prompt",
default="",
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
force_input=True,
optional=True,
advanced=True,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.6}""",
),
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
prompt: str,
creativity: float,
style_preset: str,
seed: int,
negative_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
if not negative_prompt:
negative_prompt = None
if style_preset == "None":
style_preset = None
files = {
"image": image_binary
}
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"),
response_model=StabilityAsyncResponse,
data=StabilityUpscaleCreativeRequest(
prompt=prompt,
negative_prompt=negative_prompt,
creativity=round(creativity,2),
style_preset=style_preset,
seed=seed,
),
files=files,
content_type="multipart/form-data",
)
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"),
response_model=StabilityResultsGetResponse,
poll_interval=3,
status_extractor=lambda x: get_async_dummy_status(x),
)
if response_poll.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
image_data = base64.b64decode(response_poll.result)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return IO.NodeOutput(returned_image)
class StabilityUpscaleFastNode(IO.ComfyNode):
"""
Quickly upscales an image via Stability API call to 4x its original size; intended for upscaling low-quality/compressed images.
"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityUpscaleFastNode",
display_name="Stability AI Upscale Fast",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("image"),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.02}""",
),
)
@classmethod
async def execute(cls, image: torch.Tensor) -> IO.NodeOutput:
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
files = {
"image": image_binary
}
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"),
response_model=StabilityStableUltraResponse,
files=files,
content_type="multipart/form-data",
)
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return IO.NodeOutput(returned_image)
class StabilityTextToAudio(IO.ComfyNode):
"""Generates high-quality music and sound effects from text descriptions."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityTextToAudio",
display_name="Stability AI Text To Audio",
category="partner/audio/Stability AI",
essentials_category="Audio",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Combo.Input(
"model",
options=["stable-audio-2.5"],
),
IO.String.Input("prompt", multiline=True, default=""),
IO.Int.Input(
"duration",
default=190,
min=1,
max=190,
step=1,
tooltip="Controls the duration in seconds of the generated audio.",
optional=True,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for generation.",
optional=True,
),
IO.Int.Input(
"steps",
default=8,
min=4,
max=8,
step=1,
tooltip="Controls the number of sampling steps.",
optional=True,
advanced=True,
),
],
outputs=[
IO.Audio.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.2}""",
),
)
@classmethod
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput:
validate_string(prompt, max_length=10000)
payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"),
response_model=StabilityAudioResponse,
data=payload,
content_type="multipart/form-data",
)
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
class StabilityAudioToAudio(IO.ComfyNode):
"""Transforms existing audio samples into new high-quality compositions using text instructions."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityAudioToAudio",
display_name="Stability AI Audio To Audio",
category="partner/audio/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Combo.Input(
"model",
options=["stable-audio-2.5"],
),
IO.String.Input("prompt", multiline=True, default=""),
IO.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
IO.Int.Input(
"duration",
default=190,
min=1,
max=190,
step=1,
tooltip="Controls the duration in seconds of the generated audio.",
optional=True,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for generation.",
optional=True,
),
IO.Int.Input(
"steps",
default=8,
min=4,
max=8,
step=1,
tooltip="Controls the number of sampling steps.",
optional=True,
advanced=True,
),
IO.Float.Input(
"strength",
default=1,
min=0.01,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Controls how much influence the input audio has on the generated audio.",
optional=True,
),
],
outputs=[
IO.Audio.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.2}""",
),
)
@classmethod
async def execute(
cls, model: str, prompt: str, audio: Input.Audio, duration: int, seed: int, steps: int, strength: float
) -> IO.NodeOutput:
validate_string(prompt, max_length=10000)
validate_audio_duration(audio, 6, 190)
payload = StabilityAudioToAudioRequest(
prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
)
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"),
response_model=StabilityAudioResponse,
data=payload,
content_type="multipart/form-data",
files={"audio": audio_input_to_mp3(audio)},
)
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
class StabilityAudioInpaint(IO.ComfyNode):
"""Transforms part of existing audio sample using text instructions."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="StabilityAudioInpaint",
display_name="Stability AI Audio Inpaint",
category="partner/audio/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Combo.Input(
"model",
options=["stable-audio-2.5"],
),
IO.String.Input("prompt", multiline=True, default=""),
IO.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
IO.Int.Input(
"duration",
default=190,
min=1,
max=190,
step=1,
tooltip="Controls the duration in seconds of the generated audio.",
optional=True,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for generation.",
optional=True,
),
IO.Int.Input(
"steps",
default=8,
min=4,
max=8,
step=1,
tooltip="Controls the number of sampling steps.",
optional=True,
advanced=True,
),
IO.Int.Input(
"mask_start",
default=30,
min=0,
max=190,
step=1,
optional=True,
advanced=True,
),
IO.Int.Input(
"mask_end",
default=190,
min=0,
max=190,
step=1,
optional=True,
advanced=True,
),
],
outputs=[
IO.Audio.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.2}""",
),
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
audio: Input.Audio,
duration: int,
seed: int,
steps: int,
mask_start: int,
mask_end: int,
) -> IO.NodeOutput:
validate_string(prompt, max_length=10000)
if mask_end <= mask_start:
raise ValueError(f"Value of mask_end ({mask_end}) should be greater than mask_start ({mask_start})")
validate_audio_duration(audio, 6, 190)
payload = StabilityAudioInpaintRequest(
prompt=prompt,
model=model,
duration=duration,
seed=seed,
steps=steps,
mask_start=mask_start,
mask_end=mask_end,
)
response_api = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"),
response_model=StabilityAudioResponse,
data=payload,
content_type="multipart/form-data",
files={"audio": audio_input_to_mp3(audio)},
)
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
class StabilityExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
StabilityStableImageUltraNode,
StabilityStableImageSD_3_5Node,
StabilityUpscaleConservativeNode,
StabilityUpscaleCreativeNode,
StabilityUpscaleFastNode,
StabilityTextToAudio,
StabilityAudioToAudio,
StabilityAudioInpaint,
]
async def comfy_entrypoint() -> StabilityExtension:
return StabilityExtension()


@ -26,7 +26,6 @@ from .conversions import (
text_filepath_to_base64_string,
text_filepath_to_data_uri,
trim_video,
upscale_image_tensor_to_min_pixels,
upscale_video_to_min_pixels,
video_to_base64_string,
)
@ -100,7 +99,6 @@ __all__ = [
"text_filepath_to_base64_string",
"text_filepath_to_data_uri",
"trim_video",
"upscale_image_tensor_to_min_pixels",
"upscale_video_to_min_pixels",
"video_to_base64_string",
# Validation utilities


@ -448,15 +448,6 @@ def _compute_upscale_dims(src_w: int, src_h: int, total_pixels: int) -> tuple[in
return new_w, new_h
def upscale_image_tensor_to_min_pixels(image: torch.Tensor, total_pixels: int) -> torch.Tensor:
samples = image.movedim(-1, 1)
dims = _compute_upscale_dims(samples.shape[3], samples.shape[2], int(total_pixels))
if dims is None:
return image
new_w, new_h = dims
return common_upscale(samples, new_w, new_h, "lanczos", "disabled").movedim(1, -1)
def upscale_video_to_min_pixels(video: Input.Video, min_pixels: int) -> Input.Video:
"""Upscale a video to meet at least ``min_pixels`` (w * h), preserving aspect ratio.


@ -1,6 +1,5 @@
import asyncio
import bisect
import gc
import itertools
import psutil
import time
@ -529,38 +528,6 @@ class RAMPressureCache(LRUCache):
if psutil.virtual_memory().available >= target:
return
def remove_cache_key(key):
del self.cache[key]
self.used_generation.pop(key, None)
self.timestamps.pop(key, None)
self.children.pop(key, None)
def has_old_model_patcher(outputs):
if outputs is None:
return False
for output in outputs:
if isinstance(output, (list, tuple)):
if has_old_model_patcher(output):
return True
elif isinstance(output, ModelPatcher):
return True
return False
old_modelpatcher_keys = []
for key, cache_entry in self.cache.items():
if self.used_generation[key] == self.generation:
continue
if has_old_model_patcher(cache_entry.outputs):
old_modelpatcher_keys.append(key)
for key in old_modelpatcher_keys:
remove_cache_key(key)
if old_modelpatcher_keys:
gc.collect()
if psutil.virtual_memory().available >= target:
return
clean_list = []
for key, cache_entry in self.cache.items():
@ -578,17 +545,19 @@ class RAMPressureCache(LRUCache):
scan_list_for_ram_usage(output)
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
ram_usage += output.numel() * output.element_size()
elif isinstance(output, ModelPatcher) and self.used_generation[key] != self.generation:
#old ModelPatchers are the first to go
ram_usage = 1e30
scan_list_for_ram_usage(cache_entry.outputs)
oom_score *= ram_usage
#In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], ram_usage, key))
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
to_free = target - psutil.virtual_memory().available
while to_free > 0 and clean_list:
_, _, ram_usage, key = clean_list.pop()
remove_cache_key(key)
to_free -= ram_usage
gc.collect()
while psutil.virtual_memory().available < target and clean_list:
_, _, key = clean_list.pop()
del self.cache[key]
self.used_generation.pop(key, None)
self.timestamps.pop(key, None)
self.children.pop(key, None)


@ -8,8 +8,7 @@ class CLIPTextEncodeControlnet(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="CLIPTextEncodeControlnet",
display_name="CLIP Text Encode (Controlnet)",
category="model/conditioning",
category="experimental/conditioning",
inputs=[
io.Clip.Input("clip"),
io.Conditioning.Input("conditioning"),
@ -36,12 +35,11 @@ class T5TokenizerOptions(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="T5TokenizerOptions",
display_name="T5 Tokenizer Options",
category="model/conditioning",
category="experimental/conditioning",
inputs=[
io.Clip.Input("clip"),
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1),
io.Int.Input("min_length", default=0, min=0, max=10000, step=1),
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True),
io.Int.Input("min_length", default=0, min=0, max=10000, step=1, advanced=True),
],
outputs=[io.Clip.Output()],
is_experimental=True,


@ -1070,7 +1070,7 @@ class AddNoise(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="AddNoise",
category="model/sampling/noise",
category="experimental/custom_sampling/noise",
is_experimental=True,
inputs=[
io.Model.Input("model"),
@ -1120,7 +1120,7 @@ class ManualSigmas(io.ComfyNode):
return io.Schema(
node_id="ManualSigmas",
search_aliases=["custom noise schedule", "define sigmas"],
category="model/sampling/sigmas",
category="experimental/custom_sampling",
is_experimental=True,
inputs=[
io.String.Input("sigmas", default="1, 0.5", multiline=False)


@ -123,8 +123,7 @@ class PhotoMakerLoader(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="PhotoMakerLoader",
display_name="Load PhotoMaker Model",
category="model/loaders",
category="experimental/photomaker",
inputs=[
io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")),
],
@ -150,8 +149,7 @@ class PhotoMakerEncode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="PhotoMakerEncode",
display_name="PhotoMaker Encode",
category="model/conditioning/photomaker",
category="experimental/photomaker",
inputs=[
io.Photomaker.Input("photomaker"),
io.Image.Input("image"),


@ -119,7 +119,7 @@ class StableCascade_SuperResolutionControlnet(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StableCascade_SuperResolutionControlnet",
category="experimental/stable cascade",
category="experimental/stable_cascade",
is_experimental=True,
inputs=[
io.Image.Input("image"),


@ -143,7 +143,7 @@ class VAEDecodeTripoSplat(IO.ComfyNode):
return IO.Schema(
node_id="VAEDecodeTripoSplat",
display_name="TripoSplat Decode",
category="model/latent/triposplat",
category="3d/latent",
description="Decode the sampled TripoSplat latent into a 3D gaussian splat. "
"Modify the number of gaussians to vary the density.",
inputs=[
@ -188,7 +188,7 @@ class TripoSplatSamplingPreview(IO.ComfyNode):
return IO.Schema(
node_id="TripoSplatSamplingPreview",
display_name="TripoSplat Sampling Preview",
category="model/latent/triposplat",
category="3d/latent",
description="Patch the TripoSplat model for the standard Ksampler node to show a live decoded "
"gaussian splat preview at each step.",
inputs=[


@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.27.0"
__version__ = "0.26.0"


@ -264,59 +264,6 @@ def annotated_filepath(name: str) -> tuple[str, str | None]:
return name, base_dir
# Content types a browser may execute or render inline. File endpoints that
# serve user-controlled content must force these to download (and ideally set
# Content-Disposition: attachment) to avoid stored XSS. Centralised here so the
# /view and /userdata handlers can't drift apart. mimetypes.guess_type may
# return either the text/* or application/* spelling depending on platform, so
# both are listed.
DANGEROUS_CONTENT_TYPES = {
'text/html', 'text/html-sandboxed', 'application/xhtml+xml',
'text/javascript', 'application/javascript', 'application/x-javascript',
'application/ecmascript', 'text/css',
'image/svg+xml', 'application/xml', 'text/xml',
# message/rfc822 (.mht/.mhtml) can carry script in some browsers.
'message/rfc822',
}
def is_dangerous_content_type(content_type: str | None) -> bool:
"""Return True if a browser may execute or render `content_type` inline.
Normalises before matching so the check can't be slipped past with a
charset/boundary parameter (``text/html; charset=utf-8``) or casing
(``TEXT/HTML``). Any XML dialect (``*+xml`` or ``*/xml``) is treated as
dangerous because XML can carry inline script via stylesheet/entity tricks,
which also covers the ``application/{xslt,rss,atom,rdf}+xml`` family without
enumerating each one. Endpoints serving user-controlled content should route
a dangerous type to ``application/octet-stream`` + ``Content-Disposition:
attachment`` + ``X-Content-Type-Options: nosniff``.
"""
if not content_type:
return False
normalized = content_type.split(';', 1)[0].strip().lower()
if normalized in DANGEROUS_CONTENT_TYPES:
return True
return normalized.endswith('+xml') or normalized.endswith('/xml')
def is_within_directory(directory: str, target: str) -> bool:
"""Return True if `target` resolves to a path inside `directory`.
Uses realpath on both operands so that a symlink placed inside `directory`
that points elsewhere cannot escape the containment check at open time.
"""
try:
directory = os.path.realpath(directory)
target = os.path.realpath(target)
return os.path.commonpath((directory, target)) == directory
except ValueError:
# ValueError is raised by realpath() on a path with an embedded null
# byte, and by commonpath() on Windows when the paths are on different
# drives. In either case the target is not safely within the directory.
return False
def get_annotated_filepath(name: str, default_dir: str | None=None) -> str:
name, base_dir = annotated_filepath(name)
@ -326,12 +273,7 @@ def get_annotated_filepath(name: str, default_dir: str | None=None) -> str:
else:
base_dir = get_input_directory() # fallback path
filepath = os.path.abspath(os.path.join(base_dir, name))
# Prevent path traversal: the resolved path must stay within base_dir.
# repr() the name in the message so a crafted value can't inject log lines.
if not is_within_directory(base_dir, filepath):
raise ValueError("Invalid file path: {!r}".format(name))
return filepath
return os.path.join(base_dir, name)
def exists_annotated_filepath(name) -> bool:
@ -340,10 +282,7 @@ def exists_annotated_filepath(name) -> bool:
if base_dir is None:
base_dir = get_input_directory() # fallback path
filepath = os.path.abspath(os.path.join(base_dir, name))
# Treat traversal attempts as non-existent rather than probing the filesystem.
if not is_within_directory(base_dir, filepath):
return False
filepath = os.path.join(base_dir, name)
return os.path.exists(filepath)
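A minimal sketch of how the containment check in the hunks above behaves, assuming the `is_within_directory` helper exactly as it appears in this diff. `resolve_inside` is a hypothetical wrapper written for illustration only; it mirrors the join, normalise, and reject steps seen in `get_annotated_filepath` but is not a function in the codebase.

```python
import os
import tempfile


def is_within_directory(directory: str, target: str) -> bool:
    """Copy of the helper from the hunk above: True if target resolves inside directory."""
    try:
        # realpath on both sides so a symlink inside `directory` cannot escape.
        directory = os.path.realpath(directory)
        target = os.path.realpath(target)
        return os.path.commonpath((directory, target)) == directory
    except ValueError:
        # Raised for embedded null bytes, or different drives on Windows.
        return False


def resolve_inside(base_dir: str, name: str) -> str:
    """Hypothetical wrapper: join, normalise, then reject anything outside base_dir."""
    filepath = os.path.abspath(os.path.join(base_dir, name))
    if not is_within_directory(base_dir, filepath):
        # repr() keeps a crafted name from injecting extra log lines.
        raise ValueError("Invalid file path: {!r}".format(name))
    return filepath


if __name__ == "__main__":
    base = tempfile.mkdtemp()
    print(resolve_inside(base, "image.png"))
    try:
        resolve_inside(base, "../outside.txt")
    except ValueError as exc:
        print("rejected:", exc)
```

The plain `os.path.join(base_dir, name)` variant in the same hunk performs no such check, so a name containing `..` segments would resolve outside `base_dir`.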


@ -314,7 +314,7 @@ def prompt_worker(q, server_instance):
cache_ram = 0
cache_ram_inactive = 0
if not args.cache_classic and not args.cache_none and args.cache_lru <= 0:
cache_ram = min(10.0, max(1.5, comfy.model_management.total_ram * 0.05 / 1024.0))
cache_ram = min(10.0, max(2.0, comfy.model_management.total_ram * 0.10 / 1024.0))
cache_ram_inactive = min(96.0, comfy.model_management.total_ram / 1024.0)
if len(args.cache_ram) > 0:
cache_ram = args.cache_ram[0]
@ -403,7 +403,7 @@ def prompt_worker(q, server_instance):
hook_breaker_ac10a0.restore_functions()
if not asset_seeder.is_disabled():
asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=args.enable_asset_hashing)
asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
asset_seeder.resume()
@ -458,7 +458,7 @@ def setup_database():
if dependencies_available():
init_db()
if args.enable_assets:
if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=args.enable_asset_hashing):
if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=True):
logging.info("Background asset scan initiated for models, input, output")
except Exception as e:
if "database is locked" in str(e):

View File

@ -349,7 +349,7 @@ class VAEDecodeTiled:
RETURN_TYPES = ("IMAGE",)
FUNCTION = "decode"
CATEGORY = "model/latent"
CATEGORY = "experimental"
def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
if tile_size < overlap * 4:
@ -396,7 +396,7 @@ class VAEEncodeTiled:
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "model/latent"
CATEGORY = "experimental"
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
@ -514,7 +514,7 @@ class SaveLatent:
OUTPUT_NODE = True
CATEGORY = "model/latent"
CATEGORY = "experimental"
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
@ -559,7 +559,7 @@ class LoadLatent:
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
return {"required": {"latent": [sorted(files), ]}, }
CATEGORY = "model/latent"
CATEGORY = "experimental"
RETURN_TYPES = ("LATENT", )
FUNCTION = "load"
@ -2155,8 +2155,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"GLIGENTextBoxApply": "Apply GLIGEN Text Box",
"ConditioningZeroOut": "Conditioning Zero Out",
# Latent
"LoadLatent": "Load Latent",
"SaveLatent": "Save Latent",
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
"SetLatentNoiseMask": "Set Latent Noise Mask",
"VAEDecode": "VAE Decode",
@ -2191,6 +2189,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ImageSharpen": "Sharpen Image",
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
"GetImageSize": "Get Image Size",
# experimental
"VAEDecodeTiled": "VAE Decode (Tiled)",
"VAEEncodeTiled": "VAE Encode (Tiled)",
}


@ -230,6 +230,93 @@ components:
- base_version
- workflow_json
type: object
DownloadEnqueueRequest:
description: Request body for enqueuing a server-side model download.
properties:
allow_any_extension:
default: false
description: Permit a non-model file extension (default only allows known model extensions).
type: boolean
credential_id:
description: Explicit per-host credential to use; otherwise auto-resolved by host. Still subject to the per-hop host match.
nullable: true
type: string
expected_sha256:
description: Optional hub-provided SHA256 to verify the completed file against (fail-closed).
nullable: true
type: string
model_id:
description: Destination as "<directory>/<filename>", resolving to a registered model folder (e.g. "loras/my_lora.safetensors").
type: string
priority:
default: 0
description: Scheduling priority; higher is admitted first.
type: integer
url:
description: Source URL; must be on the allowlist (host + scheme + extension).
type: string
required:
- url
- model_id
type: object
DownloadStatus:
description: Current state and live progress of a single download.
properties:
bytes_done:
type: integer
created_at:
type: integer
download_id:
format: uuid
type: string
error:
nullable: true
type: string
eta_seconds:
nullable: true
type: number
model_id:
type: string
priority:
type: integer
progress:
description: Fraction in [0,1]; null until total size is known.
nullable: true
type: number
segments:
description: Per-segment progress (segmented downloads only).
items:
properties:
bytes_done:
type: integer
idx:
type: integer
length:
type: integer
type: object
nullable: true
type: array
speed_bps:
nullable: true
type: number
status:
enum:
- queued
- active
- paused
- verifying
- completed
- failed
- cancelled
type: string
total_bytes:
nullable: true
type: integer
updated_at:
type: integer
url:
type: string
type: object
ErrorResponse:
description: Standard error response with a machine-readable code and human-readable message.
properties:
@ -511,6 +598,78 @@ components:
required:
- history
type: object
HostCredentialUpsert:
description: Request body for upserting a per-host credential. The secret is write-only.
properties:
auth_scheme:
default: bearer
description: How the secret is attached to requests.
enum:
- bearer
- header
- query
type: string
enabled:
default: true
type: boolean
header_name:
description: Header name when auth_scheme=header (defaults to Authorization).
nullable: true
type: string
host:
description: Normalized hostname the key applies to (e.g. "civitai.com").
type: string
label:
description: User-friendly name for display.
nullable: true
type: string
match_subdomains:
default: false
description: Also match label-boundary subdomains of host (off by default; unsafe for hub CDNs).
type: boolean
query_param:
description: Query parameter name when auth_scheme=query.
nullable: true
type: string
secret:
description: The API key. Write-only — never returned by any endpoint.
type: string
required:
- host
- secret
type: object
HostCredentialView:
description: Masked, API-safe view of a stored credential. Never includes the secret.
properties:
auth_scheme:
type: string
created_at:
type: integer
enabled:
type: boolean
header_name:
nullable: true
type: string
host:
type: string
id:
format: uuid
type: string
label:
nullable: true
type: string
match_subdomains:
type: boolean
query_param:
nullable: true
type: string
secret_last4:
description: Last 4 characters of the secret, for masked display only.
nullable: true
type: string
updated_at:
type: integer
type: object
JobCancelResponse:
description: Response for POST /api/jobs/{job_id}/cancel. Returned on both fresh cancels and idempotent no-ops.
properties:
@ -2350,6 +2509,391 @@ paths:
summary: Get tag histogram for filtered assets
tags:
- file
/api/download:
get:
description: List all known downloads (queued, active, paused, and terminal) with live progress.
operationId: listDownloads
responses:
"200":
content:
application/json:
schema:
properties:
downloads:
items:
$ref: '#/components/schemas/DownloadStatus'
type: array
type: object
description: List of downloads
summary: List downloads
tags:
- download
/api/download/availability:
post:
description: |
Bulk per-id availability for a set of model_ids declared in a workflow.
Returns whether each model is available on disk, currently downloading
(with progress), or missing, plus whether its URL is on the allowlist.
operationId: getModelsAvailability
requestBody:
content:
application/json:
schema:
properties:
models:
additionalProperties:
type: string
description: Map of "<directory>/<filename>" model_id to its declared source URL.
type: object
type: object
responses:
"200":
content:
application/json:
schema:
properties:
models:
additionalProperties: true
type: object
type: object
description: Per-id availability map
summary: Bulk model availability + status
tags:
- download
/api/download/clear:
post:
description: |
Delete all terminal downloads (completed, failed, cancelled) from history
in one transaction, so the cleared history persists across reloads. Live
downloads (queued, active, paused, verifying) are skipped. Finished model
files on disk are never removed; only leftover .part temp files are cleaned up.
operationId: clearDownloads
responses:
"200":
content:
application/json:
schema:
properties:
deleted:
description: Number of history rows removed.
type: integer
type: object
description: History cleared
summary: Clear terminal downloads from history
tags:
- download
/api/download/credentials:
get:
description: List stored per-host credentials. Secrets are never returned; only masked metadata (last 4 chars, scheme, label).
operationId: listDownloadCredentials
responses:
"200":
content:
application/json:
schema:
properties:
credentials:
items:
$ref: '#/components/schemas/HostCredentialView'
type: array
type: object
description: Masked credential list
summary: List host credentials (masked)
tags:
- download
post:
description: |
Upsert (by host) a per-host API key used to authenticate downloads.
The secret is write-only: it is stored once here and never returned by any endpoint.
operationId: upsertDownloadCredential
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/HostCredentialUpsert'
responses:
"201":
content:
application/json:
schema:
$ref: '#/components/schemas/HostCredentialView'
description: Credential stored (masked view returned)
"400":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Invalid credential
summary: Upsert a host credential
tags:
- download
/api/download/credentials/{id}:
delete:
description: Delete a stored host credential.
operationId: deleteDownloadCredential
parameters:
- in: path
name: id
required: true
schema:
type: string
responses:
"200":
content:
application/json:
schema:
properties:
deleted:
type: boolean
type: object
description: Deleted
"404":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: No such credential
summary: Delete a host credential
tags:
- download
get:
description: Get a single host credential (masked; never includes the secret).
operationId: getDownloadCredential
parameters:
- in: path
name: id
required: true
schema:
type: string
responses:
"200":
content:
application/json:
schema:
$ref: '#/components/schemas/HostCredentialView'
description: Masked credential
"404":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: No such credential
summary: Get a host credential (masked)
tags:
- download
/api/download/enqueue:
post:
description: |
Enqueue a server-side model download. The URL must be on the allowlist
(host + scheme + extension) and the model_id must be "<directory>/<filename>"
resolving to a registered model folder. Returns immediately; track progress
via GET /api/download/{id} or the "download_progress" websocket event.
operationId: enqueueDownload
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/DownloadEnqueueRequest'
responses:
"202":
content:
application/json:
schema:
properties:
accepted:
type: boolean
download_id:
format: uuid
type: string
type: object
description: Download accepted and queued
"400":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Invalid request (bad URL, model_id, or not allowlisted)
"409":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Already on disk or already downloading
summary: Enqueue a model download
tags:
- download
/api/download/{id}:
delete:
description: |
Delete a single terminal download from history so it stays gone across
reloads. Refuses (409) to delete a live download (queued, active, paused,
verifying) — cancel it first. The finished model file on disk is never
removed; only a leftover .part temp file is cleaned up.
operationId: deleteDownload
parameters:
- in: path
name: id
required: true
schema:
format: uuid
type: string
responses:
"200":
content:
application/json:
schema:
properties:
deleted:
type: boolean
type: object
description: Deleted
"404":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: No such download
"409":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Download is still in progress
summary: Delete a download from history
tags:
- download
get:
description: Get the current status + progress of a single download.
operationId: getDownloadStatus
parameters:
- in: path
name: id
required: true
schema:
format: uuid
type: string
responses:
"200":
content:
application/json:
schema:
$ref: '#/components/schemas/DownloadStatus'
description: Download status
"404":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: No such download
summary: Get download status
tags:
- download
/api/download/{id}/cancel:
post:
description: Cancel a download. The partial file is removed.
operationId: cancelDownload
parameters:
- in: path
name: id
required: true
schema:
format: uuid
type: string
responses:
"200":
content:
application/json:
schema:
properties:
ok:
type: boolean
type: object
description: Cancelled
summary: Cancel a download
tags:
- download
/api/download/{id}/pause:
post:
description: Pause a download. The partial file and per-segment offsets are retained for resume.
operationId: pauseDownload
parameters:
- in: path
name: id
required: true
schema:
format: uuid
type: string
responses:
"200":
content:
application/json:
schema:
properties:
ok:
type: boolean
type: object
description: Paused
summary: Pause a download
tags:
- download
/api/download/{id}/priority:
post:
description: Set a download's scheduling priority. Higher priority is admitted first when a slot frees.
operationId: setDownloadPriority
parameters:
- in: path
name: id
required: true
schema:
format: uuid
type: string
requestBody:
content:
application/json:
schema:
properties:
priority:
type: integer
required:
- priority
type: object
responses:
"200":
content:
application/json:
schema:
properties:
ok:
type: boolean
type: object
description: Priority updated
summary: Set download priority
tags:
- download
/api/download/{id}/resume:
post:
description: Resume a paused (or failed) download from its persisted offsets.
operationId: resumeDownload
parameters:
- in: path
name: id
required: true
schema:
format: uuid
type: string
responses:
"200":
content:
application/json:
schema:
properties:
ok:
type: boolean
type: object
description: Resumed
summary: Resume a download
tags:
- download
/api/embeddings:
get:
description: Returns the list of text-encoder embeddings available on disk.
@@ -5103,3 +5647,5 @@ tags:
name: queue
- description: Job lifecycle queries
name: job
- description: Model download management
name: download
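
The enqueue and delete endpoints described above could be driven from a small client. This is a minimal sketch and not part of this change: the `send` transport callable is hypothetical, and the request field names `url` and `model_id` are assumed from the endpoint description, because the `DownloadEnqueueRequest` schema is not shown in this hunk.

```python
from typing import Callable, Optional, Tuple

# Hypothetical transport: (method, path, json_body) -> (status_code, json_response).
Send = Callable[[str, str, Optional[dict]], Tuple[int, dict]]


def enqueue_download(send: Send, url: str, model_id: str) -> str:
    """POST /api/download/enqueue and return the download_id on 202."""
    directory, sep, filename = model_id.partition("/")
    if not (sep and directory and filename):
        raise ValueError('model_id must look like "<directory>/<filename>"')
    # The field names "url" and "model_id" are assumptions (schema not shown).
    status, body = send(
        "POST", "/api/download/enqueue", {"url": url, "model_id": model_id}
    )
    if status == 202:
        return body["download_id"]
    if status == 409:
        raise FileExistsError("already on disk or already downloading")
    raise ValueError(f"enqueue rejected with HTTP {status}")


def delete_download(send: Send, download_id: str) -> bool:
    """DELETE /api/download/{id}; a 409 means the download is still live."""
    status, body = send("DELETE", f"/api/download/{download_id}", None)
    if status == 409:
        raise RuntimeError("download is still in progress; cancel it first")
    if status == 404:
        raise KeyError(download_id)
    return bool(body.get("deleted"))
```

Keeping the transport as a parameter means the status-code handling can be checked without a running server; a real client would pass a function that wraps its HTTP library of choice.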

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.27.0"
version = "0.26.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.45.20
comfyui-workflow-templates==0.11.2
comfyui-workflow-templates==0.10.7
comfyui-embedded-docs==0.5.6
torch
torchsde
@@ -22,7 +22,7 @@ alembic
SQLAlchemy>=2.0.0
filelock
av>=16.0.0
comfy-kitchen==0.2.16
comfy-kitchen==0.2.15
comfy-aimdo==0.4.10
requests
simpleeval>=1.0.0

@@ -45,6 +45,8 @@ from app.frontend_management import FrontendManager, parse_version
from comfy_api.internal import _ComfyNodeInternal
from app.assets.seeder import asset_seeder
from app.assets.api.routes import register_assets_routes
from app.model_downloader.api.routes import register_routes as register_model_downloader_routes
from app.model_downloader.manager import DOWNLOAD_MANAGER
from app.assets.services.ingest import register_file_in_place
from app.assets.services.asset_management import resolve_hash_to_path
@@ -127,7 +129,6 @@ def create_cors_middleware(allowed_origin: str):
return cors_middleware
def is_loopback(host):
if host is None:
return False
@@ -257,6 +258,7 @@ class PromptServer():
else:
register_assets_routes(self.app)
asset_seeder.disable()
register_model_downloader_routes(self.app)
routes = web.RouteTableDef()
self.routes = routes
self.last_node_id = None
@@ -617,30 +619,15 @@ class PromptServer():
or 'application/octet-stream'
)
# For security, force renderable/active types (HTML, JS,
# CSS, SVG, XML — anything that can carry inline <script>
# and execute in the page origin) to download instead of
# displaying inline, preventing stored XSS. The
# attachment disposition is the load-bearing guard: a
# bare filename= hint does not force a download per
# RFC 6266, so we only attach it on the dangerous branch
# to avoid breaking inline display of legitimate images.
# Escape backslash/quote per RFC 6266 quoted-string so a
# filename containing a double quote (which passes the
# ".."/leading-slash filter above) can't break out of the
# header's quoted-string and malform the disposition.
safe_filename = filename.replace("\\", "\\\\").replace('"', '\\"')
disposition = f"filename=\"{safe_filename}\""
if folder_paths.is_dangerous_content_type(content_type):
content_type = 'application/octet-stream'
disposition = f"attachment; filename=\"{safe_filename}\""
# For security, force certain mimetypes to download instead of display
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
content_type = 'application/octet-stream' # Forces download
return web.FileResponse(
file,
headers={
"Content-Disposition": disposition,
"Content-Type": content_type,
"X-Content-Type-Options": "nosniff"
"Content-Disposition": f"filename=\"{filename}\"",
"Content-Type": content_type
}
)
@@ -1198,6 +1185,29 @@ class PromptServer():
async def setup(self):
timeout = aiohttp.ClientTimeout(total=None) # no timeout
self.client_session = aiohttp.ClientSession(timeout=timeout)
await self._setup_model_downloader()
async def _setup_model_downloader(self):
"""Start the download manager: push progress over the websocket and
resume any downloads interrupted by a previous run."""
def _notify(download_id: str) -> None:
try:
view = DOWNLOAD_MANAGER.status_sync(download_id)
if view is not None:
# Drop the url field before broadcasting: the redacted URL
# (scheme + host + path) should not leak to every connected
# websocket client. download_id / model_id are sufficient to
# correlate progress on the frontend.
broadcast = {k: v for k, v in view.items() if k != "url"}
self.send_sync("download_progress", broadcast)
except Exception:
logging.debug("download progress notify failed", exc_info=True)
DOWNLOAD_MANAGER.set_notify(_notify)
try:
await DOWNLOAD_MANAGER.start()
except Exception as e:
logging.warning("Failed to start model download manager: %s", e)
def add_routes(self):
self.user_manager.add_routes(self.routes)

@@ -1,5 +1,3 @@
import contextlib
import json
import time
import uuid
from datetime import datetime
@@ -11,40 +9,6 @@ import requests
from helpers import get_asset_filename, trigger_sync_seed_assets
def test_download_svg_forced_to_attachment(http: requests.Session, api_base: str):
"""GHSA-779p-m5rp-r4h4 CISA-5 (sibling route): an uploaded SVG must never be
served inline from GET /api/assets/{id}/content, or an inline <script> runs
in the app origin (stored XSS). Even with disposition=inline requested, a
dangerous content type must be forced to application/octet-stream +
Content-Disposition: attachment + nosniff. Regression guard for the stale
inline blocklist that previously omitted image/svg+xml and ignored the
centralized folder_paths.is_dangerous_content_type check.
"""
svg = b'<svg xmlns="http://www.w3.org/2000/svg"><script>alert(1)</script></svg>'
files = {"file": ("evil.svg", svg, "image/svg+xml")}
form_data = {
"tags": json.dumps(["models", "checkpoints", "unit-tests", "svgxss"]),
"name": "evil.svg",
}
up = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
body = up.json()
assert up.status_code in (200, 201), body
aid = body["id"]
try:
r = http.get(f"{api_base}/api/assets/{aid}/content?disposition=inline", timeout=120)
r.content
assert r.status_code == 200
ct = r.headers.get("Content-Type", "").lower()
cd = r.headers.get("Content-Disposition", "").lower()
assert "svg" not in ct, f"SVG served with a renderable content type: {ct!r}"
assert ct.startswith("application/octet-stream"), f"expected octet-stream, got {ct!r}"
assert "attachment" in cd, f"inline disposition not overridden to attachment: {cd!r}"
assert r.headers.get("X-Content-Type-Options", "").lower() == "nosniff"
finally:
with contextlib.suppress(Exception):
http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
def test_download_attachment_and_inline(http: requests.Session, api_base: str, seeded_asset: dict):
aid = seeded_asset["id"]

@@ -53,11 +53,8 @@ def test_annotated_filepath():
def test_get_annotated_filepath():
default_dir = "/default/dir"
# get_annotated_filepath now normalizes with os.path.abspath (part of the
# GHSA-779p traversal hardening), so compare against the normalized form —
# on Windows abspath also prepends the current drive letter.
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.abspath(os.path.join(default_dir, "test.txt"))
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.abspath(os.path.join(folder_paths.get_output_directory(), "test.txt"))
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt")
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt")
def test_add_model_folder_path_append(clear_folder_paths):
folder_paths.add_model_folder_path("test_folder", "/default/path", is_default=True)

@@ -0,0 +1,90 @@
"""Shared fixtures for the model download manager tests.
These run in-process (no ComfyUI subprocess): a file-backed SQLite DB is
initialized once, a temp model folder is registered with ``folder_paths``, and
the shared aiohttp session is reset between tests so each async test gets a
session bound to its own event loop.
"""
from __future__ import annotations
import asyncio
import os
import tempfile
import pytest
def _drain_scheduler_tasks(scheduler) -> None:
"""Cancel and await live scheduler tasks so none outlive the test.
Uses the actual task handles rather than only clearing ``_tasks``: each
per-test event loop is created by ``asyncio.run``, so a task left behind by
a crashed/aborted test would otherwise keep its coroutine alive. We cancel
every live task and, when its loop is still usable, run it to completion to
let the cancellation propagate before dropping the reference.
"""
for task in list(scheduler._tasks.values()):
if task is None:
continue
loop = task.get_loop()
if task.done() or loop.is_closed():
continue
task.cancel()
if not loop.is_running():
try:
loop.run_until_complete(asyncio.gather(task, return_exceptions=True))
except Exception:
pass
scheduler._tasks.clear()
@pytest.fixture(scope="session", autouse=True)
def _init_db():
import app.database.db as db
from comfy.cli_args import args
fd, db_path = tempfile.mkstemp(suffix="-dlmgr-test.sqlite3")
os.close(fd)
args.database_url = f"sqlite:///{db_path}"
db.init_db()
yield
try:
os.remove(db_path)
except OSError:
pass
@pytest.fixture(autouse=True)
def _reset_runtime():
"""Reset module singletons that hold event-loop-bound or cross-test state."""
import app.model_downloader.net.session as ns
from app.model_downloader.scheduler import SCHEDULER
ns._session = None
_drain_scheduler_tasks(SCHEDULER)
SCHEDULER._jobs.clear()
SCHEDULER._backoff_until.clear()
SCHEDULER._started = False
yield
_drain_scheduler_tasks(SCHEDULER)
ns._session = None
@pytest.fixture
def model_root(tmp_path):
"""Register a temp 'loras' model folder and return its absolute path."""
import folder_paths
root = tmp_path / "loras"
root.mkdir(parents=True, exist_ok=True)
saved = folder_paths.folder_names_and_paths.get("loras")
folder_paths.folder_names_and_paths["loras"] = (
[str(root)],
{".safetensors", ".sft", ".ckpt", ".pt", ".pth"},
)
yield str(root)
if saved is not None:
folder_paths.folder_names_and_paths["loras"] = saved
else:
folder_paths.folder_names_and_paths.pop("loras", None)

@@ -0,0 +1,166 @@
"""Unit tests for the credential store and the per-hop credential resolver.
Covers the critical rule: a secret is only ever attached when the current
hop's host matches a stored credential, and never over a non-https hop.
"""
from __future__ import annotations
import asyncio
import pytest
from app.model_downloader.credentials import resolver
from app.model_downloader.credentials.store import (
CREDENTIAL_STORE,
CredentialValidationError,
normalize_host,
)
from app.model_downloader.database.models import HostCredential
# ----- pure host normalization + matching -----
@pytest.mark.parametrize(
"raw,expected",
[
("Civitai.com", "civitai.com"),
("HuggingFace.co:443", "huggingface.co"),
(" Example.COM ", "example.com"),
],
)
def test_normalize_host(raw, expected):
assert normalize_host(raw) == expected
def _cred(**kw) -> HostCredential:
base = dict(
id="x", host="civitai.com", match_subdomains=False, auth_scheme="bearer",
secret="SECRET", enabled=True,
)
base.update(kw)
return HostCredential(**base)
def test_matches_exact_only_by_default():
c = _cred(host="civitai.com")
assert resolver._matches(c, "civitai.com") is True
assert resolver._matches(c, "api.civitai.com") is False
assert resolver._matches(c, "evil-civitai.com") is False
def test_matches_subdomain_label_boundary():
c = _cred(host="example.com", match_subdomains=True)
assert resolver._matches(c, "api.example.com") is True
assert resolver._matches(c, "example.com") is True
# not a label boundary -> no match
assert resolver._matches(c, "evil-example.com") is False
def test_build_auth_shapes():
assert resolver._build_auth(_cred(auth_scheme="bearer")).headers == {
"Authorization": "Bearer SECRET"
}
assert resolver._build_auth(
_cred(auth_scheme="header", header_name="X-Api-Key")
).headers == {"X-Api-Key": "SECRET"}
q = resolver._build_auth(_cred(auth_scheme="query", query_param="token"))
assert q.query == {"token": "SECRET"}
assert q.apply_to_url("https://civitai.com/x") == "https://civitai.com/x?token=SECRET"
# ----- DB-backed store + resolver -----
def test_store_upsert_is_write_only_and_masked():
async def _run():
view = await CREDENTIAL_STORE.upsert("civitai.com", "abcd1234", label="my key")
# The view never carries the secret, only the last 4.
assert not hasattr(view, "secret")
assert view.secret_last4 == "1234"
assert view.host == "civitai.com"
listed = await CREDENTIAL_STORE.list()
assert any(v.host == "civitai.com" for v in listed)
await CREDENTIAL_STORE.delete(view.id)
asyncio.run(_run())
def test_query_scheme_requires_param():
async def _run():
with pytest.raises(CredentialValidationError):
await CREDENTIAL_STORE.upsert("civitai.com", "k", auth_scheme="query")
asyncio.run(_run())
def test_resolver_never_crosses_host_boundary():
async def _run():
view = await CREDENTIAL_STORE.upsert("huggingface.co", "hf_secret_key")
try:
# matching host over https -> attached
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
assert auth is not None
assert auth.headers["Authorization"] == "Bearer hf_secret_key"
# CDN redirect host -> dropped
assert await resolver.resolve_auth_for_hop("cdn-lfs.huggingface.co", "https") is None
# non-https hop -> never attached
assert await resolver.resolve_auth_for_hop("huggingface.co", "http") is None
finally:
await CREDENTIAL_STORE.delete(view.id)
asyncio.run(_run())
# ----- env-based HF token fallback -----
def test_env_token_fallback_attaches_when_no_db_credential(monkeypatch):
monkeypatch.setenv("HF_TOKEN", "env_hf_token")
async def _run():
# exact host over https -> env token attached
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
assert auth is not None
assert auth.headers["Authorization"] == "Bearer env_hf_token"
# non-https hop -> never attached
assert await resolver.resolve_auth_for_hop("huggingface.co", "http") is None
# CDN redirect host -> dropped (exact-host only)
assert await resolver.resolve_auth_for_hop("cdn-lfs.huggingface.co", "https") is None
asyncio.run(_run())
def test_env_token_secondary_var_is_honored(monkeypatch):
monkeypatch.delenv("HF_TOKEN", raising=False)
monkeypatch.setenv("HUGGING_FACE_HUB_TOKEN", "env_hub_token")
async def _run():
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
assert auth is not None
assert auth.headers["Authorization"] == "Bearer env_hub_token"
asyncio.run(_run())
def test_db_credential_takes_precedence_over_env(monkeypatch):
monkeypatch.setenv("HF_TOKEN", "env_hf_token")
async def _run():
view = await CREDENTIAL_STORE.upsert("huggingface.co", "db_secret_key")
try:
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
assert auth is not None
assert auth.headers["Authorization"] == "Bearer db_secret_key"
finally:
await CREDENTIAL_STORE.delete(view.id)
asyncio.run(_run())
def test_env_token_does_not_leak_into_explicit_path(monkeypatch):
monkeypatch.setenv("HF_TOKEN", "env_hf_token")
async def _run():
# An explicit credential id that doesn't resolve must stay None; the env
# fallback only applies to the auto-resolve branch.
auth = await resolver.resolve_auth_for_hop(
"huggingface.co", "https", explicit_credential_id="does-not-exist"
)
assert auth is None
asyncio.run(_run())

@@ -0,0 +1,136 @@
"""Unit tests for ``DownloadManager.delete`` and ``DownloadManager.clear``.
Deleting a terminal row must remove it from history for good (so it does not
reappear on the next ``list``), leave live rows untouched, and clean up any
leftover ``.part`` temp file without touching the finished model file.
``clear()`` is the bulk variant: it removes all terminal rows atomically, skips
live ones, and returns the count of rows deleted.
Async methods are driven via ``asyncio.run`` so no pytest-asyncio plugin is
required.
"""
from __future__ import annotations
import asyncio
import os
import pytest
from app.model_downloader.constants import DownloadStatus
from app.model_downloader.database import queries
from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError
def _insert(download_id: str, status: str, *, temp_path: str = "/tmp/none.part") -> None:
queries.insert_download(
{
"id": download_id,
"url": "https://huggingface.co/org/model.safetensors",
"model_id": "loras/model.safetensors",
"dest_path": "/tmp/model.safetensors",
"temp_path": temp_path,
"status": status,
"priority": 0,
}
)
def test_delete_removes_terminal_row_from_history():
_insert("done", DownloadStatus.COMPLETED)
asyncio.run(DOWNLOAD_MANAGER.delete("done"))
assert queries.get_download("done") is None
def test_delete_refuses_live_row():
_insert("live", DownloadStatus.QUEUED)
with pytest.raises(DownloadError) as excinfo:
asyncio.run(DOWNLOAD_MANAGER.delete("live"))
assert excinfo.value.code == "DOWNLOAD_ACTIVE"
assert queries.get_download("live") is not None
def test_delete_missing_row_raises_not_found():
with pytest.raises(DownloadError) as excinfo:
asyncio.run(DOWNLOAD_MANAGER.delete("nope"))
assert excinfo.value.code == "NOT_FOUND"
def test_delete_removes_leftover_temp_file(tmp_path):
partial = tmp_path / "model.safetensors.part"
partial.write_bytes(b"partial")
_insert("failed", DownloadStatus.FAILED, temp_path=str(partial))
asyncio.run(DOWNLOAD_MANAGER.delete("failed"))
assert not os.path.exists(partial)
assert queries.get_download("failed") is None
# ----- clear -----
def test_clear_removes_all_terminal_rows():
_insert("c-done", DownloadStatus.COMPLETED)
_insert("c-fail", DownloadStatus.FAILED)
_insert("c-canc", DownloadStatus.CANCELLED)
deleted = asyncio.run(DOWNLOAD_MANAGER.clear())
assert deleted == 3
assert queries.get_download("c-done") is None
assert queries.get_download("c-fail") is None
assert queries.get_download("c-canc") is None
def test_clear_skips_live_rows():
_insert("cl-queued", DownloadStatus.QUEUED)
_insert("cl-paused", DownloadStatus.PAUSED)
_insert("cl-done", DownloadStatus.COMPLETED)
deleted = asyncio.run(DOWNLOAD_MANAGER.clear())
assert deleted == 1
assert queries.get_download("cl-queued") is not None
assert queries.get_download("cl-paused") is not None
assert queries.get_download("cl-done") is None
def test_clear_returns_zero_when_nothing_to_delete():
_insert("cl-only-live", DownloadStatus.QUEUED)
deleted = asyncio.run(DOWNLOAD_MANAGER.clear())
assert deleted == 0
assert queries.get_download("cl-only-live") is not None
def test_clear_removes_leftover_temp_files(tmp_path):
partial = tmp_path / "clear_partial.part"
partial.write_bytes(b"partial data")
finished = tmp_path / "finished.safetensors"
finished.write_bytes(b"real model weights")
_insert("cl-part", DownloadStatus.FAILED, temp_path=str(partial))
# The finished file is not the temp_path; temp_path for a completed download
# no longer exists (already renamed), so use a non-existent path here to
# verify clear() tolerates a missing temp file without raising.
_insert("cl-comp", DownloadStatus.COMPLETED, temp_path=str(tmp_path / "gone.part"))
asyncio.run(DOWNLOAD_MANAGER.clear())
# Leftover .part from the failed download is cleaned up.
assert not partial.exists()
# Finished model file is never touched.
assert finished.exists()
def test_clear_empty_db_returns_zero():
deleted = asyncio.run(DOWNLOAD_MANAGER.clear())
assert deleted == 0
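
The contract these tests check can be modelled in memory. This is an illustrative sketch only, not `DownloadManager`: the class names are hypothetical and the status strings are assumed, while the error codes `NOT_FOUND` and `DOWNLOAD_ACTIVE` are the ones asserted above.

```python
import os

# Assumed string values for the terminal statuses.
TERMINAL = {"completed", "failed", "cancelled"}


class DownloadErrorSketch(Exception):
    def __init__(self, code: str) -> None:
        super().__init__(code)
        self.code = code


class HistorySketch:
    """In-memory stand-in for the download history table."""

    def __init__(self) -> None:
        self.rows = {}  # download_id -> {"status": ..., "temp_path": ...}

    def _drop(self, download_id: str) -> None:
        row = self.rows.pop(download_id)
        try:
            os.remove(row["temp_path"])  # only the leftover .part file
        except OSError:
            pass  # already gone, e.g. renamed into place on completion

    def delete(self, download_id: str) -> None:
        row = self.rows.get(download_id)
        if row is None:
            raise DownloadErrorSketch("NOT_FOUND")
        if row["status"] not in TERMINAL:
            raise DownloadErrorSketch("DOWNLOAD_ACTIVE")
        self._drop(download_id)

    def clear(self) -> int:
        """Remove every terminal row, skip live ones, return the count."""
        ids = [i for i, r in self.rows.items() if r["status"] in TERMINAL]
        for download_id in ids:
            self._drop(download_id)
        return len(ids)
```

The destination path is never passed to `os.remove`, which mirrors the rule that a finished model file on disk is left alone.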

@@ -0,0 +1,637 @@
"""Integration tests for the download engine against a local aiohttp server.
Covers single-stream and segmented transfers, deterministic resume from a
partial file, and cancel rollback. Async tests are driven via ``asyncio.run``
so no pytest-asyncio plugin is required.
"""
from __future__ import annotations
import asyncio
import json
import os
import struct
import uuid
import pytest
from aiohttp import web
from comfy.cli_args import args
from app.model_downloader.constants import DownloadStatus
from app.model_downloader.database import queries
from app.model_downloader.engine.job import DownloadJob, JobSpec
from app.model_downloader.net.session import close_session
from app.model_downloader.security import paths
PAYLOAD_ETAG = '"v1"'
def _payload(n: int) -> bytes:
return bytes((i * 37 + 11) % 256 for i in range(n))
def _safetensors_payload(total: int) -> bytes:
"""A structurally valid ``.safetensors`` blob of exactly ``total`` bytes.
Success-path tests download to ``.safetensors`` destinations, which the
engine now structurally validates before the atomic rename, so their
payloads must parse as real safetensors (header length + JSON header +
data region whose size matches the declared ``data_offsets``).
"""
def _header(data_len: int) -> bytes:
return json.dumps(
{"w": {"dtype": "U8", "shape": [data_len], "data_offsets": [0, data_len]}}
).encode("utf-8")
# The header's byte length depends on the digit count of ``data_len``, so
# iterate until ``total == 8 + len(header) + data_len`` is self-consistent.
data_len = total - 8 - len(_header(total))
for _ in range(8):
header = _header(data_len)
new_data_len = total - 8 - len(header)
if new_data_len == data_len:
break
data_len = new_data_len
assert data_len >= 0, "total too small for a safetensors payload"
header = _header(data_len)
body = bytes((i * 37 + 11) % 256 for i in range(data_len))
return struct.pack("<Q", len(header)) + header + body
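
The layout that `_safetensors_payload` builds (an 8-byte little-endian header length, a JSON header, then a data region whose size matches the declared `data_offsets`) can be checked with a short parser. This is a sketch of one possible structural check under those assumptions, not the engine's validator; `check_safetensors_layout` is a hypothetical name.

```python
import json
import struct


def check_safetensors_layout(blob: bytes) -> int:
    """Return the data-region size, or raise ValueError if the layout is inconsistent."""
    if len(blob) < 8:
        raise ValueError("too short for the 8-byte header length")
    (header_len,) = struct.unpack("<Q", blob[:8])
    if 8 + header_len > len(blob):
        raise ValueError("declared header length overruns the file")
    # Both UnicodeDecodeError and json.JSONDecodeError are ValueError subclasses.
    header = json.loads(blob[8 : 8 + header_len].decode("utf-8"))
    if not isinstance(header, dict):
        raise ValueError("header must be a JSON object")
    data_len = len(blob) - 8 - header_len
    declared_end = 0
    for name, entry in header.items():
        if name == "__metadata__":  # optional free-form metadata, no tensor data
            continue
        offsets = entry.get("data_offsets") if isinstance(entry, dict) else None
        if not (isinstance(offsets, list) and len(offsets) == 2):
            raise ValueError(f"tensor {name!r} has no valid data_offsets")
        begin, end = offsets
        if not (isinstance(begin, int) and isinstance(end, int) and 0 <= begin <= end):
            raise ValueError(f"tensor {name!r} has malformed data_offsets")
        declared_end = max(declared_end, end)
    if declared_end != data_len:
        raise ValueError("data region size does not match declared offsets")
    return data_len
```

A truncated download fails the final comparison, because the header still declares the full data length.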
def _range_handler(payload: bytes):
async def handler(request: web.Request) -> web.Response:
rng = request.headers.get("Range")
if rng:
spec = rng.split("=", 1)[1]
s, _, e = spec.partition("-")
start = int(s)
end = int(e) if e else len(payload) - 1
chunk = payload[start : end + 1]
return web.Response(
status=206,
body=chunk,
headers={
"Content-Range": f"bytes {start}-{end}/{len(payload)}",
"Accept-Ranges": "bytes",
"ETag": PAYLOAD_ETAG,
},
)
return web.Response(
status=200, body=payload, headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG}
)
return handler
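
The test handler above only understands the `bytes=start-end` and `bytes=start-` forms, which is all the engine under test needs. For reference, a slightly fuller single-range parser might look like the following sketch; `parse_single_range` is a hypothetical helper and is not used by these tests.

```python
def parse_single_range(header: str, size: int) -> tuple[int, int]:
    """Parse a single-range "bytes=..." header into an inclusive (start, end)."""
    unit, sep, spec = header.partition("=")
    if unit.strip() != "bytes" or not sep or "," in spec:
        raise ValueError("only a single bytes range is supported")
    first, dash, last = spec.strip().partition("-")
    if not dash:
        raise ValueError("missing '-' in range spec")
    if first == "":
        # Suffix form "bytes=-N": the last N bytes of the payload.
        count = int(last)
        if count <= 0:
            raise ValueError("suffix length must be positive")
        start, end = max(size - count, 0), size - 1
    else:
        start = int(first)
        # An open-ended or oversized end is clamped to the last byte.
        end = min(int(last), size - 1) if last else size - 1
    if start > end or start >= size:
        raise ValueError("range not satisfiable")
    return start, end
```

A server built on this would answer 416 where the sketch raises for an unsatisfiable range.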
def _content_disposition_handler(payload: bytes, filename: str):
"""A range-capable server that only reveals its filename via a header.
Models a Civitai-style ``/api/download/...`` endpoint: the URL path has no
extension, and the real filename (hence extension) lives in the response
``Content-Disposition`` header.
"""
async def handler(request: web.Request) -> web.Response:
headers = {
"Accept-Ranges": "bytes",
"ETag": PAYLOAD_ETAG,
"Content-Disposition": f'attachment; filename="{filename}"',
}
rng = request.headers.get("Range")
if rng:
spec = rng.split("=", 1)[1]
s, _, e = spec.partition("-")
start = int(s)
end = int(e) if e else len(payload) - 1
chunk = payload[start : end + 1]
return web.Response(
status=206,
body=chunk,
headers={**headers, "Content-Range": f"bytes {start}-{end}/{len(payload)}"},
)
return web.Response(status=200, body=payload, headers=headers)
return handler
def _noranges_handler(payload: bytes):
async def handler(request: web.Request) -> web.Response:
# Always full body, never advertises Accept-Ranges -> single-stream.
return web.Response(status=200, body=payload)
return handler
def _slow_handler(payload: bytes, chunk: int = 16384, delay: float = 0.01):
async def handler(request: web.Request) -> web.StreamResponse:
resp = web.StreamResponse(
status=200, headers={"Content-Length": str(len(payload))}
)
await resp.prepare(request)
for i in range(0, len(payload), chunk):
await resp.write(payload[i : i + chunk])
await asyncio.sleep(delay)
await resp.write_eof()
return resp
return handler
def _overflow_range_handler(payload: bytes, extra: int = 256 * 1024):
"""A non-conforming 206 server that returns MORE than the requested range."""
async def handler(request: web.Request) -> web.Response:
rng = request.headers.get("Range")
if rng:
spec = rng.split("=", 1)[1]
s, _, e = spec.partition("-")
start = int(s)
end = int(e) if e else len(payload) - 1
# Maliciously overrun: append extra bytes past the requested end.
body = payload[start : end + 1] + bytes(extra)
return web.Response(
status=206,
body=body,
headers={
"Content-Range": f"bytes {start}-{end}/{len(payload)}",
"Accept-Ranges": "bytes",
"ETag": PAYLOAD_ETAG,
},
)
return web.Response(
status=200, body=payload, headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG}
)
return handler
def _short_range_handler(payload: bytes, drop: int = 64 * 1024):
"""A 206 server that returns fewer bytes than requested for later segments.
Simulates a server cleanly closing a range connection early. The response
is internally consistent (Content-Length matches the short body), so the
client sees no error and the segment just ends short, leaving a zero-filled
hole in the preallocated file.
"""
async def handler(request: web.Request) -> web.Response:
rng = request.headers.get("Range")
if rng:
spec = rng.split("=", 1)[1]
s, _, e = spec.partition("-")
start = int(s)
end = int(e) if e else len(payload) - 1
chunk = payload[start : end + 1]
if start > 0 and len(chunk) > drop:
chunk = chunk[:-drop] # truncate a non-first segment
return web.Response(
status=206,
body=chunk,
headers={
"Content-Range": f"bytes {start}-{end}/{len(payload)}",
"Accept-Ranges": "bytes",
"ETag": PAYLOAD_ETAG,
},
)
return web.Response(
status=200, body=payload, headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG}
)
return handler
def _unbounded_handler(total: int, chunk: int = 16384):
"""A 200 stream with no Content-Length / Accept-Ranges (unknown length)."""
async def handler(request: web.Request) -> web.StreamResponse:
resp = web.StreamResponse(status=200)
await resp.prepare(request)
sent = 0
while sent < total:
await resp.write(bytes(min(chunk, total - sent)))
sent += chunk
await resp.write_eof()
return resp
return handler
async def _serve(handler):
app = web.Application()
app.router.add_route("*", "/{name:.*}", handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", 0)
await site.start()
port = site._server.sockets[0].getsockname()[1]
return runner, port
def _insert(model_id: str, url: str, status: str = DownloadStatus.QUEUED) -> tuple[str, str, str]:
final_path, temp_path = paths.resolve_destination(model_id)
download_id = str(uuid.uuid4())
queries.insert_download(
{
"id": download_id,
"url": url,
"model_id": model_id,
"dest_path": final_path,
"temp_path": temp_path,
"status": status,
}
)
return download_id, final_path, temp_path
# ----- single-stream -----
def test_single_stream_download(model_root):
payload = _safetensors_payload(300_000)
async def _run():
await close_session()
runner, port = await _serve(_noranges_handler(payload))
try:
url = f"http://127.0.0.1:{port}/model.safetensors"
did, final_path, _temp = _insert("loras/single.safetensors", url)
job = DownloadJob(JobSpec(
download_id=did, url=url, model_id="loras/single.safetensors",
dest_path=final_path, temp_path=_temp,
))
status = await job.run()
assert status == DownloadStatus.COMPLETED, queries.get_download(did).error
assert os.path.exists(final_path)
assert open(final_path, "rb").read() == payload
finally:
await runner.cleanup()
await close_session()
asyncio.run(_run())
# ----- segmented -----
def test_segmented_download(model_root):
payload = _safetensors_payload(4 * 1024 * 1024) # 4 MiB -> multiple segments
async def _run():
await close_session()
runner, port = await _serve(_range_handler(payload))
try:
url = f"http://127.0.0.1:{port}/model.safetensors"
did, final_path, temp = _insert("loras/seg.safetensors", url)
job = DownloadJob(JobSpec(
download_id=did, url=url, model_id="loras/seg.safetensors",
dest_path=final_path, temp_path=temp,
))
status = await job.run()
assert status == DownloadStatus.COMPLETED, queries.get_download(did).error
assert open(final_path, "rb").read() == payload
# More than one segment row was planned.
assert len(queries.list_segments(did)) > 1
finally:
await runner.cleanup()
await close_session()
asyncio.run(_run())
# ----- deterministic resume from a partial file -----
def test_resume_from_partial(model_root):
payload = _safetensors_payload(512 * 1024) # < 1 MiB -> single segment
async def _run():
await close_session()
runner, port = await _serve(_range_handler(payload))
try:
url = f"http://127.0.0.1:{port}/model.safetensors"
did, final_path, temp = _insert("loras/resume.safetensors", url)
# Simulate a prior partial: first 200 KiB already written, offset persisted.
prefix = 200 * 1024
os.makedirs(os.path.dirname(temp), exist_ok=True)
with open(temp, "wb") as f:
f.write(payload[:prefix])
queries.update_download(did, bytes_done=prefix, etag=PAYLOAD_ETAG)
job = DownloadJob(JobSpec(
download_id=did, url=url, model_id="loras/resume.safetensors",
dest_path=final_path, temp_path=temp, etag=PAYLOAD_ETAG,
))
status = await job.run()
assert status == DownloadStatus.COMPLETED, queries.get_download(did).error
assert open(final_path, "rb").read() == payload
finally:
await runner.cleanup()
await close_session()
asyncio.run(_run())
# ----- cancel rollback -----
def test_cancel_rollback(model_root, monkeypatch):
monkeypatch.setattr(args, "download_chunk_size", 16384, raising=False)
payload = _payload(1024 * 1024)
async def _run():
await close_session()
runner, port = await _serve(_slow_handler(payload))
try:
url = f"http://127.0.0.1:{port}/model.safetensors"
did, final_path, temp = _insert("loras/cancel.safetensors", url)
job = DownloadJob(JobSpec(
download_id=did, url=url, model_id="loras/cancel.safetensors",
dest_path=final_path, temp_path=temp,
))
task = asyncio.ensure_future(job.run())
# Wait until some bytes have been written, then cancel.
for _ in range(200):
await asyncio.sleep(0.01)
if job.state.bytes_done > 0:
break
job.request_cancel()
status = await task
assert status == DownloadStatus.CANCELLED
assert not os.path.exists(temp)
assert not os.path.exists(final_path)
finally:
await runner.cleanup()
await close_session()
asyncio.run(_run())
# ----- size-bound enforcement (malicious / non-conforming hosts) -----
def test_segment_overflow_aborts(model_root):
"""A 206 returning more than the requested range must not overrun."""
payload = _payload(4 * 1024 * 1024) # large enough to segment
async def _run():
await close_session()
runner, port = await _serve(_overflow_range_handler(payload))
try:
url = f"http://127.0.0.1:{port}/model.safetensors"
did, final_path, temp = _insert("loras/overflow.safetensors", url)
job = DownloadJob(JobSpec(
download_id=did, url=url, model_id="loras/overflow.safetensors",
dest_path=final_path, temp_path=temp,
))
status = await job.run()
assert status == DownloadStatus.FAILED
assert not os.path.exists(final_path)
assert not os.path.exists(temp)
finally:
await runner.cleanup()
await close_session()
asyncio.run(_run())
def test_short_segment_fails_closed(model_root):
    """A segment that ends short must fail, not be accepted as complete.

    The file is preallocated to total_bytes, so the on-disk size still equals
    total even with a zero-filled hole; completeness must be judged per-segment.
    """
    payload = _safetensors_payload(4 * 1024 * 1024)  # large enough to segment

    async def _run():
        await close_session()
        runner, port = await _serve(_short_range_handler(payload))
        try:
            url = f"http://127.0.0.1:{port}/model.safetensors"
            did, final_path, temp = _insert("loras/short.safetensors", url)
            job = DownloadJob(JobSpec(
                download_id=did, url=url, model_id="loras/short.safetensors",
                dest_path=final_path, temp_path=temp,
            ))
            status = await job.run()
            assert status == DownloadStatus.FAILED, queries.get_download(did).error
            assert "incomplete" in (queries.get_download(did).error or "")
            assert not os.path.exists(final_path)
            assert not os.path.exists(temp)
        finally:
            await runner.cleanup()
            await close_session()

    asyncio.run(_run())
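
Note on the short-segment test above: because the temp file is preallocated, its on-disk size cannot reveal a missing tail, so each segment has to be checked against its own byte range. The following is only a minimal sketch of that idea. `SegmentProgress` and `incomplete_segments` are hypothetical names chosen for illustration and are not taken from the downloader.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SegmentProgress:
    """Hypothetical per-segment bookkeeping row (illustrative only)."""
    start: int          # first byte offset of the segment, inclusive
    end: int            # last byte offset of the segment, inclusive
    bytes_written: int  # bytes actually received and written so far


def incomplete_segments(segments):
    """Return the segments that received fewer bytes than their range holds."""
    return [s for s in segments if s.bytes_written < s.end - s.start + 1]


segments = [
    SegmentProgress(0, 99, 100),
    SegmentProgress(100, 199, 60),  # body ended 40 bytes early
]
short = incomplete_segments(segments)
```

A file-size comparison would accept this download (the preallocated file is 200 bytes long either way), while the per-segment check flags the second range.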
def test_structural_validation_rejects_corrupt(model_root):
"""A correctly sized but structurally invalid file fails closed (not retried).
Regression for the dead structural gate: validation must key off the
destination extension, not the ``.part`` temp suffix.
"""
payload = _payload(300_000) # right size, but not a valid safetensors blob
async def _run():
await close_session()
runner, port = await _serve(_noranges_handler(payload))
try:
url = f"http://127.0.0.1:{port}/model.safetensors"
did, final_path, temp = _insert("loras/corrupt.safetensors", url)
job = DownloadJob(JobSpec(
download_id=did, url=url, model_id="loras/corrupt.safetensors",
dest_path=final_path, temp_path=temp,
))
status = await job.run()
assert status == DownloadStatus.FAILED, queries.get_download(did).error
assert not os.path.exists(final_path)
assert not os.path.exists(temp)
# Failed closed at first attempt, not re-queued as retryable.
assert queries.get_download(did).attempts == 0
finally:
await runner.cleanup()
await close_session()
asyncio.run(_run())
def test_rejects_oversized_known_download(model_root, monkeypatch):
"""A file whose advertised size exceeds the cap is rejected at probe."""
monkeypatch.setattr(args, "download_max_bytes", 100_000, raising=False)
payload = _payload(300_000)
async def _run():
await close_session()
runner, port = await _serve(_noranges_handler(payload))
try:
url = f"http://127.0.0.1:{port}/model.safetensors"
did, final_path, temp = _insert("loras/toobig.safetensors", url)
job = DownloadJob(JobSpec(
download_id=did, url=url, model_id="loras/toobig.safetensors",
dest_path=final_path, temp_path=temp,
))
status = await job.run()
assert status == DownloadStatus.FAILED
assert not os.path.exists(final_path)
finally:
await runner.cleanup()
await close_session()
asyncio.run(_run())
def test_unknown_length_capped_by_max_bytes(model_root, monkeypatch):
"""An unbounded unknown-length stream is capped by --download-max-bytes."""
monkeypatch.setattr(args, "download_max_bytes", 100_000, raising=False)
monkeypatch.setattr(args, "download_chunk_size", 16384, raising=False)
async def _run():
await close_session()
runner, port = await _serve(_unbounded_handler(2 * 1024 * 1024))
try:
url = f"http://127.0.0.1:{port}/model.safetensors"
did, final_path, temp = _insert("loras/unbounded.safetensors", url)
job = DownloadJob(JobSpec(
download_id=did, url=url, model_id="loras/unbounded.safetensors",
dest_path=final_path, temp_path=temp,
))
status = await job.run()
assert status == DownloadStatus.FAILED
assert not os.path.exists(final_path)
assert not os.path.exists(temp)
finally:
await runner.cleanup()
await close_session()
asyncio.run(_run())
# ----- manager + scheduler end-to-end -----
def test_manager_enqueue_to_completion(model_root):
payload = _safetensors_payload(2 * 1024 * 1024)
async def _run():
await close_session()
from app.model_downloader.manager import DOWNLOAD_MANAGER
runner, port = await _serve(_range_handler(payload))
try:
url = f"http://127.0.0.1:{port}/model.safetensors"
did = await DOWNLOAD_MANAGER.enqueue(url, "loras/e2e.safetensors")
# Wait for completion.
final_path, _ = paths.resolve_destination("loras/e2e.safetensors")
for _ in range(500):
await asyncio.sleep(0.02)
row = queries.get_download(did)
if row.status in DownloadStatus.TERMINAL:
break
row = queries.get_download(did)
assert row.status == DownloadStatus.COMPLETED, row.error
assert open(final_path, "rb").read() == payload
finally:
await runner.cleanup()
await close_session()
asyncio.run(_run())
def test_manager_rejects_disallowed_url(model_root):
async def _run():
from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError
with pytest.raises(DownloadError) as ei:
await DOWNLOAD_MANAGER.enqueue(
"https://evil.example.com/x.safetensors", "loras/bad.safetensors"
)
assert ei.value.code == "URL_NOT_ALLOWED"
asyncio.run(_run())
def test_manager_resolves_extensionless_url(model_root):
"""An allowlisted URL with no extension in its path is resolved from the
response, and the stored file adopts the resolved extension."""
payload = _safetensors_payload(1 * 1024 * 1024)
async def _run():
await close_session()
from app.model_downloader.manager import DOWNLOAD_MANAGER
runner, port = await _serve(
_content_disposition_handler(payload, "RealModel.safetensors")
)
try:
# No extension in the path (Civitai-style) and none in the model_id.
url = f"http://127.0.0.1:{port}/api/download/models/12345"
did = await DOWNLOAD_MANAGER.enqueue(url, "loras/my_civitai_model")
row = queries.get_download(did)
# The resolved extension was appended to the model_id + destination.
assert row.model_id == "loras/my_civitai_model.safetensors"
assert row.dest_path.endswith("my_civitai_model.safetensors")
final_path, _ = paths.resolve_destination(
"loras/my_civitai_model.safetensors"
)
for _ in range(500):
await asyncio.sleep(0.02)
row = queries.get_download(did)
if row.status in DownloadStatus.TERMINAL:
break
row = queries.get_download(did)
assert row.status == DownloadStatus.COMPLETED, row.error
assert open(final_path, "rb").read() == payload
finally:
await runner.cleanup()
await close_session()
asyncio.run(_run())
def test_manager_overrides_extension_from_resolution(model_root):
"""A model_id carrying a different known extension is corrected to match
the resolved URL's extension."""
payload = _safetensors_payload(256 * 1024)
async def _run():
await close_session()
from app.model_downloader.manager import DOWNLOAD_MANAGER
runner, port = await _serve(
_content_disposition_handler(payload, "weights.safetensors")
)
try:
url = f"http://127.0.0.1:{port}/api/download/models/777"
# Caller guessed .ckpt; resolution says .safetensors -> corrected.
did = await DOWNLOAD_MANAGER.enqueue(url, "loras/guessed.ckpt")
row = queries.get_download(did)
assert row.model_id == "loras/guessed.safetensors"
finally:
await runner.cleanup()
await close_session()
asyncio.run(_run())
def test_manager_rejects_non_model_resolution(model_root):
"""A URL that resolves to a non-model file is rejected, not downloaded."""
async def _run():
await close_session()
from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError
runner, port = await _serve(
_content_disposition_handler(b"not a model", "installer.zip")
)
try:
url = f"http://127.0.0.1:{port}/api/download/models/999"
with pytest.raises(DownloadError) as ei:
await DOWNLOAD_MANAGER.enqueue(url, "loras/whatever")
assert ei.value.code == "URL_NOT_ALLOWED"
finally:
await runner.cleanup()
await close_session()
asyncio.run(_run())


@ -0,0 +1,81 @@
"""Unit tests for the segment planner and structural safetensors validation."""
from __future__ import annotations

import json
import struct

import pytest

from app.model_downloader.engine.planner import (
    effective_segment_count,
    plan_segments,
)
from app.model_downloader.verify import structural


# ----- planner -----

def test_plan_segments_covers_full_range_contiguously():
    total = 1000
    plans = plan_segments(total, 4)
    assert len(plans) == 4
    assert plans[0].start == 0
    assert plans[-1].end == total - 1
    # contiguous, no gaps/overlaps
    for a, b in zip(plans, plans[1:]):
        assert b.start == a.end + 1
    assert sum(p.length for p in plans) == total


def test_effective_segment_count_falls_back_to_single():
    # No range support -> single
    assert effective_segment_count(10_000_000, False, 8) == 1
    # Unknown size -> single
    assert effective_segment_count(None, True, 8) == 1
    # Tiny file -> fewer segments than configured
    assert effective_segment_count(1024, True, 8) == 1
    # Large file with range support -> configured count
    assert effective_segment_count(1_000_000_000, True, 8) == 8
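
Note on the planner tests above: they pin down three invariants (first segment starts at 0, last ends at `total - 1`, neighbours are contiguous with inclusive ends). One way to satisfy them is sketched below. `SegmentPlan` and `plan_segments_sketch` are hypothetical stand-ins, and the real `plan_segments` may split the range differently.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SegmentPlan:
    start: int  # inclusive
    end: int    # inclusive

    @property
    def length(self) -> int:
        return self.end - self.start + 1


def plan_segments_sketch(total: int, count: int) -> list:
    """Split [0, total) into `count` contiguous inclusive ranges."""
    if count < 1 or total < count:
        raise ValueError("need at least one byte per segment")
    base, extra = divmod(total, count)
    plans, offset = [], 0
    for i in range(count):
        # The first `extra` segments absorb the remainder, one byte each.
        size = base + (1 if i < extra else 0)
        plans.append(SegmentPlan(offset, offset + size - 1))
        offset += size
    return plans
```

Spreading the remainder over the leading segments keeps every segment within one byte of the others, which is an arbitrary choice made for this sketch.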
# ----- structural -----
def _make_safetensors(tensor_data_len: int, *, corrupt_size: bool = False) -> bytes:
    # U8 keeps the header self-consistent: one byte per element, so the declared
    # shape and data_offsets both describe exactly tensor_data_len bytes.
    header = {"t": {"dtype": "U8", "shape": [tensor_data_len], "data_offsets": [0, tensor_data_len]}}
    header_bytes = json.dumps(header).encode("utf-8")
    body = b"\x00" * tensor_data_len
    if corrupt_size:
        body = body[:-1]  # truncate one byte
    return struct.pack("<Q", len(header_bytes)) + header_bytes + body


def test_structural_valid_safetensors(tmp_path):
    p = tmp_path / "ok.safetensors"
    p.write_bytes(_make_safetensors(256))
    structural.validate(str(p))  # no raise


def test_structural_detects_truncation(tmp_path):
    p = tmp_path / "bad.safetensors"
    p.write_bytes(_make_safetensors(256, corrupt_size=True))
    with pytest.raises(structural.StructuralError):
        structural.validate(str(p))


def test_structural_skips_unknown_extension(tmp_path):
    p = tmp_path / "weights.bin"
    p.write_bytes(b"anything")
    structural.validate(str(p))  # no structural check, no raise


def test_structural_detects_truncation_via_name_hint(tmp_path):
    # The downloader validates the opaque temp file (a ``.part`` path) but keys
    # the format check off the final destination name via ``name_hint``, so
    # truncation must still be detected instead of silently skipped.
    p = tmp_path / "bad.comfy-download.part"
    p.write_bytes(_make_safetensors(256, corrupt_size=True))
    with pytest.raises(structural.StructuralError):
        structural.validate(str(p), name_hint="model.safetensors")
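
Note on the structural tests above: a safetensors file begins with an 8-byte little-endian header length, followed by a JSON header whose tensor entries carry `data_offsets` relative to the start of the data section. A truncation check can therefore compare the file size with the largest declared end offset. The sketch below assumes that is roughly what `structural.validate` does; `check_safetensors_size` is a hypothetical name, it works on in-memory bytes for brevity, and it raises `ValueError` rather than the project's `StructuralError`.

```python
import json
import struct


def check_safetensors_size(blob: bytes) -> None:
    """Raise ValueError when the blob's size disagrees with its header."""
    if len(blob) < 8:
        raise ValueError("too short to hold the 8-byte header length")
    (header_len,) = struct.unpack("<Q", blob[:8])
    if 8 + header_len > len(blob):
        raise ValueError("declared header length exceeds the file size")
    header = json.loads(blob[8:8 + header_len])
    # "__metadata__" is a free-form string map, not a tensor entry.
    data_end = max(
        (entry["data_offsets"][1]
         for name, entry in header.items() if name != "__metadata__"),
        default=0,
    )
    expected = 8 + header_len + data_end
    if len(blob) != expected:
        raise ValueError(f"size {len(blob)} != expected {expected}")


hdr = json.dumps({"t": {"dtype": "U8", "shape": [4], "data_offsets": [0, 4]}}).encode()
good = struct.pack("<Q", len(hdr)) + hdr + b"\x00" * 4
check_safetensors_size(good)  # a complete blob passes silently
```

Dropping the last byte of `good` makes the check raise, which mirrors the `corrupt_size=True` fixture.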


@ -0,0 +1,231 @@
"""Unit tests for the security layer: allowlist, SSRF checks, path safety."""
from __future__ import annotations

import pytest

from app.model_downloader.security import allowlist, paths
from app.model_downloader.security.ssrf import (
    SSRFError,
    check_redirect_hop,
    is_blocked_ip,
)


# ----- allowlist -----

@pytest.mark.parametrize(
    "url,allowed",
    [
        ("https://huggingface.co/org/repo/resolve/main/model.safetensors", True),
        ("https://civitai.com/api/download/x/model.safetensors", True),
        ("http://localhost/model.safetensors", True),
        # off-list host
        ("https://evil.example.com/model.safetensors", False),
        # http to a non-loopback allowlisted host is not permitted (https only)
        ("http://huggingface.co/org/repo/resolve/main/model.safetensors", False),
        # bad extension on an allowed host
        ("https://huggingface.co/org/repo/resolve/main/config.json", False),
        # userinfo trick: real host is the metadata IP, not 127.0.0.1
        ("http://127.0.0.1@169.254.169.254/x.safetensors", False),
    ],
)
def test_is_url_allowed(url, allowed):
    assert allowlist.is_url_allowed(url) is allowed
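
Note on the userinfo case above: everything before `@` in the authority is credentials, so a naive prefix or substring match on the URL would see `127.0.0.1` while the client connects elsewhere. The standard library parser shows the split. This is only a demonstration with `urllib.parse`; it does not claim to be how `allowlist.is_url_allowed` parses URLs.

```python
from urllib.parse import urlsplit

parts = urlsplit("http://127.0.0.1@169.254.169.254/x.safetensors")
host = parts.hostname      # the host a client would actually connect to
userinfo = parts.username  # the text before "@", which is not a host
```

Here `host` is the link-local metadata address and `userinfo` is the decoy loopback string, so any host check should be made against `hostname` only.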


def test_allow_any_extension_relaxes_extension_only():
    url = "https://huggingface.co/org/repo/resolve/main/weights.bin"
    assert allowlist.is_url_allowed(url) is True  # .bin is in the known set
    odd = "https://huggingface.co/org/repo/resolve/main/weights.zip"
    assert allowlist.is_url_allowed(odd) is False
    assert allowlist.is_url_allowed(odd, allow_any_extension=True) is True


@pytest.mark.parametrize(
    "url,downloadable",
    [
        # known model extension in the path -> allowed
        ("https://civitai.com/x/model.safetensors", True),
        # no extension in the path (Civitai download API) -> allowed, resolved later
        ("https://civitai.com/api/download/models/3031464?fileId=2910346", True),
        ("https://civitai.com/api/download/models/3031464", True),
        # explicit non-model extension -> rejected even on an allowed host
        ("https://civitai.com/api/download/models/thing.zip", False),
        ("https://huggingface.co/org/repo/resolve/main/config.json", False),
        # off-list host is never downloadable
        ("https://evil.example.com/api/download/models/1", False),
        # http to a non-loopback allowlisted host is not permitted
        ("http://civitai.com/api/download/models/1", False),
    ],
)
def test_is_url_downloadable(url, downloadable):
    assert allowlist.is_url_downloadable(url) is downloadable


@pytest.mark.parametrize(
    "name,ext",
    [
        ("model.safetensors", ".safetensors"),
        ("model.SAFETENSORS", ".safetensors"),
        ("archive.tar.gz", ".gz"),
        ("noext", ""),
        (".safetensors", ""),  # leading-dot dotfile -> no extension
        ("a/b/c/model.ckpt", ".ckpt"),
    ],
)
def test_filename_extension(name, ext):
    assert allowlist.filename_extension(name) == ext
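
Note on the extension table above: every row matches what `posixpath.splitext` returns for the basename, lowercased, including the dotfile row (leading dots are not treated as an extension separator). The sketch below reproduces the table that way. `filename_extension_sketch` is a hypothetical name, and the real `allowlist.filename_extension` may be implemented differently.

```python
import posixpath


def filename_extension_sketch(name: str) -> str:
    """Lowercased final suffix of the basename, or "" when there is none."""
    return posixpath.splitext(posixpath.basename(name))[1].lower()


cases = {
    "model.safetensors": ".safetensors",
    "model.SAFETENSORS": ".safetensors",
    "archive.tar.gz": ".gz",
    "noext": "",
    ".safetensors": "",
    "a/b/c/model.ckpt": ".ckpt",
}
results = {name: filename_extension_sketch(name) for name in cases}
```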


# ----- SSRF: blocked IPs -----

@pytest.mark.parametrize(
    "ip,blocked",
    [
        ("169.254.169.254", True),  # cloud metadata / link-local
        ("127.0.0.1", True),
        ("10.0.0.5", True),
        ("192.168.1.1", True),
        ("172.16.0.1", True),
        ("::1", True),
        ("0.0.0.0", True),
        # IPv4-mapped IPv6: must see through the mapping even on CPython
        # versions predating the gh-113171 is_* property fix.
        ("::ffff:169.254.169.254", True),  # mapped cloud metadata
        ("::ffff:127.0.0.1", True),  # mapped loopback
        ("::ffff:10.0.0.1", True),  # mapped RFC1918
        ("::ffff:8.8.8.8", False),  # mapped public address stays allowed
        ("8.8.8.8", False),
        ("1.1.1.1", False),
        ("not-an-ip", True),  # unparseable -> refuse
    ],
)
def test_is_blocked_ip(ip, blocked):
    assert is_blocked_ip(ip) is blocked
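
Note on the mapped-address rows above: one way to "see through" an IPv4-mapped IPv6 address is to unwrap `ipv4_mapped` before consulting the `is_*` properties, so the verdict does not depend on how a given interpreter classifies the IPv6 form. The sketch below takes that approach. `is_blocked_ip_sketch` is a hypothetical name and the exact set of flags is an assumption, not the project's `is_blocked_ip`.

```python
import ipaddress


def is_blocked_ip_sketch(text: str) -> bool:
    """True when the address is internal, special-purpose, or unparseable."""
    try:
        ip = ipaddress.ip_address(text)
    except ValueError:
        return True  # fail closed on anything that is not an IP literal
    # IPv6Address exposes ipv4_mapped; IPv4Address has no such attribute.
    mapped = getattr(ip, "ipv4_mapped", None)
    if mapped is not None:
        ip = mapped
    return (
        ip.is_private
        or ip.is_loopback
        or ip.is_link_local
        or ip.is_multicast
        or ip.is_reserved
        or ip.is_unspecified
    )
```

Returning `True` on a parse failure matches the `"not-an-ip"` row: an unparseable value is refused rather than trusted.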


# ----- SSRF: redirect hop validation -----

def test_check_redirect_hop_rejects_bad_scheme_and_userinfo():
    with pytest.raises(SSRFError):
        check_redirect_hop("ftp://huggingface.co/x.safetensors")
    with pytest.raises(SSRFError):
        check_redirect_hop("https://user:pass@cdn.example.com/x")
    # A CDN host that is NOT on the allowlist is allowed as a redirect target
    # (private-IP protection is the resolver's job; credential leak is prevented
    # by exact host matching).
    assert check_redirect_hop("https://cdn-lfs.huggingface.co/abc") is not None


def test_check_redirect_hop_http_only_for_loopback():
    # Plain http to an external host is rejected (no plaintext downgrade).
    with pytest.raises(SSRFError):
        check_redirect_hop("http://cdn-lfs.huggingface.co/abc")
    # http is honored for loopback only on the initial user-supplied URL (the
    # "download a local model" feature).
    assert (
        check_redirect_hop("http://localhost/x.safetensors", is_initial_url=True)
        is not None
    )
    assert (
        check_redirect_hop("http://127.0.0.1/x.safetensors", is_initial_url=True)
        is not None
    )


def test_check_redirect_hop_blocks_loopback_and_ip_literals_on_redirect():
    # A redirect (is_initial_url=False, the default) must never reach loopback,
    # whether by hostname or by IP literal, nor any other internal IP literal.
    for target in (
        "http://localhost/x.safetensors",
        "http://127.0.0.1/x.safetensors",
        "https://[::1]/x.safetensors",
        "https://169.254.169.254/x.safetensors",  # cloud metadata
        "https://10.0.0.5/x.safetensors",  # RFC1918
    ):
        with pytest.raises(SSRFError):
            check_redirect_hop(target)
    # Off-allowlist public CDN hosts (hostnames) remain valid redirect targets;
    # their resolved IPs are screened by the connector's resolver.
    assert check_redirect_hop("https://cdn-lfs.huggingface.co/abc") is not None


# ----- path safety -----

def test_parse_model_id_valid(model_root):
    directory, filename = paths.parse_model_id("loras/my_lora.safetensors")
    assert directory == "loras"
    assert filename == "my_lora.safetensors"


@pytest.mark.parametrize(
    "model_id",
    [
        "loras/../etc/passwd.safetensors",  # traversal
        "loras/sub/dir.safetensors",  # nested
        "unknownfolder/x.safetensors",  # unknown folder
        "loras/model.txt",  # bad extension
        "noslash.safetensors",  # missing directory
        "loras/",  # empty filename
    ],
)
def test_parse_model_id_rejects(model_root, model_id):
    with pytest.raises(paths.InvalidModelId):
        paths.parse_model_id(model_id)


def test_resolve_destination_stays_in_root(model_root):
    final_path, temp_path = paths.resolve_destination("loras/x.safetensors")
    assert final_path.startswith(model_root)
    assert temp_path.startswith(model_root)
    assert temp_path != final_path


@pytest.mark.parametrize(
    "model_id,ext,expected",
    [
        # no extension -> append the resolved one
        ("loras/my_civitai_model", ".safetensors", "loras/my_civitai_model.safetensors"),
        # different known extension -> replace it
        ("loras/mymodel.ckpt", ".safetensors", "loras/mymodel.safetensors"),
        # same extension -> unchanged
        ("loras/mymodel.safetensors", ".safetensors", "loras/mymodel.safetensors"),
        # non-model suffix is treated as a stem, extension appended
        ("loras/my.model.v2", ".safetensors", "loras/my.model.v2.safetensors"),
        # malformed (no slash) is returned untouched for parse_model_id to reject
        ("noslash", ".safetensors", "noslash"),
    ],
)
def test_apply_extension(model_id, ext, expected):
    assert paths.apply_extension(model_id, ext) == expected
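
Note on the `apply_extension` table above: the rows imply a simple rule, namely strip the current suffix only when it is a known model extension, then append the resolved one. The sketch below follows that rule. `apply_extension_sketch` and the `KNOWN_MODEL_EXTENSIONS` set are assumptions made for illustration; the project's actual extension list is not shown in this diff.

```python
import posixpath

# Assumed set for this sketch only; the real list lives in the project.
KNOWN_MODEL_EXTENSIONS = {".safetensors", ".ckpt", ".pt", ".bin"}


def apply_extension_sketch(model_id: str, ext: str) -> str:
    """Make `model_id` end in `ext`, replacing only a known model suffix."""
    if "/" not in model_id:
        return model_id  # malformed; left for the parser to reject
    directory, _, filename = model_id.rpartition("/")
    stem, current = posixpath.splitext(filename)
    if current.lower() in KNOWN_MODEL_EXTENSIONS:
        filename = stem  # drop the guessed extension before appending
    return f"{directory}/{filename}{ext}"
```

Treating an unknown suffix such as `.v2` as part of the stem avoids silently truncating version-style names.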


# ----- Content-Disposition filename parsing -----

@pytest.mark.parametrize(
    "header,expected",
    [
        ('attachment; filename="model.safetensors"', "model.safetensors"),
        ("attachment; filename=model.ckpt", "model.ckpt"),
        # RFC 5987 form is preferred and percent-decoded
        (
            "attachment; filename=\"fallback.bin\"; filename*=UTF-8''my%20model.safetensors",
            "my model.safetensors",
        ),
        # directory components in a hostile header are stripped to the basename
        ('attachment; filename="../../etc/passwd"', "passwd"),
        ('attachment; filename="a\\\\b\\\\model.pt"', "model.pt"),
        ("inline", None),
        (None, None),
    ],
)
def test_filename_from_content_disposition(header, expected):
    from app.model_downloader.net.http import filename_from_content_disposition
    assert filename_from_content_disposition(header) == expected
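
Note on the Content-Disposition table above: the expected values can be reproduced by preferring the extended `filename*` parameter, percent-decoding it, and reducing whatever name results to its basename across both slash styles. The sketch below does that with regular expressions. `filename_from_cd_sketch` is a hypothetical name, it assumes the extended value is UTF-8 regardless of the declared charset, and it is not the project's parser.

```python
import re
from urllib.parse import unquote

# filename*=charset'lang'percent-encoded-value (charset/lang ignored here)
_EXT = re.compile(r"filename\*\s*=\s*[^']*'[^']*'([^;]+)", re.I)
# filename="quoted value"  or  filename=token
_PLAIN = re.compile(r'filename\s*=\s*(?:"((?:\\.|[^"\\])*)"|([^;]+))', re.I)


def filename_from_cd_sketch(header):
    if not header:
        return None
    match = _EXT.search(header)
    if match:
        name = unquote(match.group(1).strip())  # assumes UTF-8
    else:
        match = _PLAIN.search(header)
        if not match:
            return None
        if match.group(1) is not None:
            # Undo quoted-string backslash escapes.
            name = re.sub(r"\\(.)", r"\1", match.group(1))
        else:
            name = match.group(2).strip()
    # Strip directory components, treating "\" like "/".
    name = name.replace("\\", "/").rsplit("/", 1)[-1]
    return name or None
```

Reducing to the basename last means a hostile header cannot smuggle directory components through either parameter form.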


@ -1,192 +0,0 @@
"""CI unit tests for FIX #2 of GHSA-779p-m5rp-r4h4.
Path traversal / hardening in app/model_manager.py get_model_preview
(route /experiment/models/preview/{folder}/{path_index}/{filename:.*}).
Reference: https://github.com/Comfy-Org/ComfyUI/security/advisories/GHSA-779p-m5rp-r4h4
"""
import pytest
import yarl
from io import BytesIO
from PIL import Image
from aiohttp import web
from unittest.mock import patch
from app.model_manager import ModelFileManager
pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module
@pytest.fixture
def model_manager():
return ModelFileManager()
@pytest.fixture
def app(model_manager):
app = web.Application()
routes = web.RouteTableDef()
model_manager.add_routes(routes)
app.add_routes(routes)
return app
async def test_legit_preview_returns_200(aiohttp_client, app, tmp_path):
"""Sanity: a real preview PNG inside the model folder is served as webp 200."""
img = Image.new('RGB', (16, 16), color=(255, 0, 128))
img.save(tmp_path / "test_model.png", format='PNG')
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(tmp_path)], None)
}):
client = await aiohttp_client(app)
response = await client.get('/experiment/models/preview/test_folder/0/test_model.png')
assert response.status == 200
assert response.content_type == 'image/webp'
img_bytes = BytesIO(await response.read())
served = Image.open(img_bytes)
assert served.format
assert served.format.lower() == 'webp'
served.close()
async def test_non_integer_path_index_returns_400(aiohttp_client, app, tmp_path):
"""A non-integer path_index segment must be rejected with 400."""
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(tmp_path)], None)
}):
client = await aiohttp_client(app)
response = await client.get('/experiment/models/preview/test_folder/abc/test_model.png')
assert response.status == 400
async def test_out_of_range_path_index_returns_404(aiohttp_client, app, tmp_path):
"""A path_index beyond the configured folder list must return 404."""
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(tmp_path)], None)
}):
client = await aiohttp_client(app)
response = await client.get('/experiment/models/preview/test_folder/99/test_model.png')
assert response.status == 404
async def test_empty_filename_returns_400(aiohttp_client, app, tmp_path):
"""The "{filename:.*}" capture also matches the empty string (trailing
slash). It would resolve to the folder itself and must be rejected with 400."""
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(tmp_path)], None)
}):
client = await aiohttp_client(app)
response = await client.get('/experiment/models/preview/test_folder/0/')
assert response.status == 400
async def test_path_traversal_in_filename_returns_403(aiohttp_client, app, tmp_path):
"""Path traversal in {filename} must be rejected with 403 and must NOT read
a file outside the configured model directory.
GOTCHA: aiohttp/yarl collapses literal ``../`` dot-segments out of the URL
path before it reaches the handler, which would make this test vacuously
pass (the request would hit a different/non-existent route). We percent-encode
the dots and slashes (``%2e%2e%2f``) and send the URL with
``yarl.URL(..., encoded=True)`` so the bytes survive client-side normalization
untouched; aiohttp's router then percent-decodes them into ``match_info``,
delivering the literal ``../`` traversal to the handler's ``{filename:.*}``
capture.
Without the fix the handler computes
``os.path.normpath(os.path.join(folder, "../../../../etc/hosts"))``, which
escapes ``tmp_path`` and would be passed straight to get_model_previews ->
Image.open, serving bytes from outside the model dir (200/served bytes). The
is_within_directory() containment check is the load-bearing fix that turns
that escape into a 403.
"""
# Sanity-anchor: a legit preview exists inside tmp_path, so a 200 path is
# genuinely reachable — proving the 403 below is the containment check
# firing, not an unrelated 404.
img = Image.new('RGB', (16, 16), color=(255, 0, 128))
img.save(tmp_path / "test_model.png", format='PNG')
# Percent-encoded "../../../../etc/hosts" so yarl does not collapse the
# dot-segments before the request leaves the client.
encoded_traversal = '%2e%2e%2f' * 4 + 'etc%2fhosts'
raw_path = '/experiment/models/preview/test_folder/0/' + encoded_traversal
url = yarl.URL(raw_path, encoded=True)
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(tmp_path)], None)
}):
client = await aiohttp_client(app)
response = await client.get(url)
# Confirm the traversal actually reached the handler intact: a 200 here
# would mean either normalization stripped the ``../`` (vacuous pass) or
# the containment check failed open and served outside-dir bytes.
assert response.status == 403, (
f"expected 403 from is_within_directory() containment check, "
f"got {response.status}; traversal may have been normalized away "
f"or the fix failed open"
)
body = await response.read()
assert body == b"", "403 response must not carry any file bytes"
async def test_symlink_companion_preview_returns_403(aiohttp_client, app, tmp_path):
"""A companion preview file is selected by a glob inside get_model_previews
and then opened. If that companion is a symlink whose path is in-dir but
whose target escapes the model folder, it must be rejected with 403 — not
served. The requested path itself stays in-dir (so the first containment
check passes); the load-bearing fix is the SECOND is_within_directory check
on the file actually opened.
"""
model_dir = tmp_path / "models"
model_dir.mkdir()
secret_dir = tmp_path / "secret"
secret_dir.mkdir()
# A real image OUTSIDE the model dir — valid, so without the fix Image.open
# would succeed and its bytes would be served (200).
secret = secret_dir / "secret.png"
Image.new('RGB', (8, 8), color=(0, 0, 0)).save(secret, format='PNG')
# Companion preview, in-dir by name but a symlink escaping the model dir.
# (No real model file is needed — get_model_previews globs companions by
# basename, and omitting a .safetensors avoids the metadata-header read.)
companion = model_dir / "model.preview.png"
try:
companion.symlink_to(secret)
except (OSError, NotImplementedError):
pytest.skip("symlinks not supported on this platform/filesystem")
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(model_dir)], None)
}):
client = await aiohttp_client(app)
response = await client.get('/experiment/models/preview/test_folder/0/model.safetensors')
assert response.status == 403, (
f"expected 403 — the globbed companion preview is a symlink resolving "
f"outside the model dir and must not be served; got {response.status}"
)
assert await response.read() == b""
async def test_null_byte_in_filename_no_500(aiohttp_client, app, tmp_path):
"""A NUL byte in the filename must yield a clean client rejection, not a 500
from an uncaught ValueError in is_within_directory's realpath() call."""
raw_path = '/experiment/models/preview/test_folder/0/' + 'a%00b'
url = yarl.URL(raw_path, encoded=True)
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(tmp_path)], None)
}):
client = await aiohttp_client(app)
response = await client.get(url)
assert response.status != 500, (
f"NUL byte produced a 500 (uncaught ValueError); expected a clean "
f"4xx rejection, got {response.status}"
)
assert 400 <= response.status < 500


@ -1,165 +0,0 @@
"""Security tests for GHSA-779p-m5rp-r4h4 — FIX #3.
Path traversal in folder_paths.get_annotated_filepath / exists_annotated_filepath,
plus the shared is_within_directory() containment helper.
These are pure-function tests (no running server). The input/output/temp
directories are pointed at tmp_path via the folder_paths setters, so a crafted
name containing `../`, an absolute path, or a symlink that escapes the base
directory must be rejected.
Reference: https://github.com/Comfy-Org/ComfyUI/security/advisories/GHSA-779p-m5rp-r4h4
"""
import os
import pytest
import folder_paths
from comfy.options import enable_args_parsing
enable_args_parsing()
@pytest.fixture
def sandbox(tmp_path):
"""Point folder_paths' input/output/temp dirs at a real temp sandbox.
Yields the realpath'd base, input, output and temp directories. The original
directory values are restored afterward so tests stay isolated.
"""
base = os.path.realpath(str(tmp_path))
input_dir = os.path.join(base, "input")
output_dir = os.path.join(base, "output")
temp_dir = os.path.join(base, "temp")
for d in (input_dir, output_dir, temp_dir):
os.makedirs(d, exist_ok=True)
orig_input = folder_paths.get_input_directory()
orig_output = folder_paths.get_output_directory()
orig_temp = folder_paths.get_temp_directory()
folder_paths.set_input_directory(input_dir)
folder_paths.set_output_directory(output_dir)
folder_paths.set_temp_directory(temp_dir)
yield {
"base": base,
"input": input_dir,
"output": output_dir,
"temp": temp_dir,
}
folder_paths.set_input_directory(orig_input)
folder_paths.set_output_directory(orig_output)
folder_paths.set_temp_directory(orig_temp)
# ---------------------------------------------------------------------------
# is_within_directory() — the shared containment helper
# ---------------------------------------------------------------------------

def test_is_within_directory_legit_child(sandbox):
    base = sandbox["input"]
    child = os.path.join(base, "sub", "image.png")
    assert folder_paths.is_within_directory(base, child) is True


def test_is_within_directory_dotdot_escape(sandbox):
    base = sandbox["input"]
    escape = os.path.join(base, "..", "..", "etc", "passwd")
    assert folder_paths.is_within_directory(base, escape) is False


def test_is_within_directory_symlink_escape(sandbox):
    """A symlink created INSIDE base that points OUTSIDE base must not pass.

    This is the key new hardening: is_within_directory realpath()s both operands,
    so a symlink planted in the base directory can't be used to read files
    elsewhere. We create a real on-disk symlink and a real secret target to
    verify the check actually resolves the link.
    """
    base = sandbox["input"]
    # A directory living outside the base, holding a secret file.
    outside = os.path.join(sandbox["base"], "outside_secret_dir")
    os.makedirs(outside, exist_ok=True)
    secret = os.path.join(outside, "secret.txt")
    with open(secret, "w") as f:
        f.write("top secret")
    # Plant a symlink inside base that points at the outside directory.
    # symlink creation can require elevated privileges / Developer Mode on
    # Windows, so skip cleanly where it isn't available (same guard as the
    # sibling test in test_ghsa_779p_02_preview_traversal.py).
    link = os.path.join(base, "escape_link")
    try:
        os.symlink(outside, link)
    except (OSError, NotImplementedError):
        pytest.skip("symlinks not supported on this platform/filesystem")
    # Accessing the secret "through" the in-base symlink must be rejected.
    target_via_link = os.path.join(link, "secret.txt")
    assert folder_paths.is_within_directory(base, target_via_link) is False
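
Note on the containment tests above: the docstring states that `is_within_directory` applies `realpath()` to both operands. A minimal sketch of such a check follows. `is_within_directory_sketch` is a hypothetical name, and the use of `os.path.commonpath` plus the `ValueError` guard are assumptions made here, not details confirmed by this diff.

```python
import os
import tempfile


def is_within_directory_sketch(base: str, target: str) -> bool:
    """True when `target`, with symlinks resolved, lies at or under `base`."""
    try:
        base_real = os.path.realpath(base)
        target_real = os.path.realpath(target)
        # commonpath compares whole components, so "/a/input2" is not
        # mistaken for a child of "/a/input" as a startswith() test would.
        return os.path.commonpath([base_real, target_real]) == base_real
    except ValueError:
        # Raised for paths on different drives and, depending on the
        # platform, for paths containing NUL bytes; fail closed.
        return False


with tempfile.TemporaryDirectory() as tmp:
    base = os.path.join(tmp, "input")
    os.makedirs(base)
    inside = is_within_directory_sketch(base, os.path.join(base, "sub", "image.png"))
    escaped = is_within_directory_sketch(
        base, os.path.join(base, "..", "..", "etc", "passwd")
    )
```

Because both paths are resolved before comparison, a symlink planted inside `base` that points elsewhere resolves to its real location and falls outside the common path.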
# ---------------------------------------------------------------------------
# get_annotated_filepath()
# ---------------------------------------------------------------------------


def test_get_annotated_filepath_legit_name(sandbox):
    result = folder_paths.get_annotated_filepath("image.png")
    assert result == os.path.join(sandbox["input"], "image.png")
    assert folder_paths.is_within_directory(sandbox["input"], result)


def test_get_annotated_filepath_input_annotation(sandbox):
    result = folder_paths.get_annotated_filepath("image.png [input]")
    assert result == os.path.join(sandbox["input"], "image.png")


def test_get_annotated_filepath_output_annotation(sandbox):
    result = folder_paths.get_annotated_filepath("image.png [output]")
    assert result == os.path.join(sandbox["output"], "image.png")


def test_get_annotated_filepath_temp_annotation(sandbox):
    result = folder_paths.get_annotated_filepath("image.png [temp]")
    assert result == os.path.join(sandbox["temp"], "image.png")


def test_get_annotated_filepath_dotdot_raises(sandbox):
    with pytest.raises(ValueError):
        folder_paths.get_annotated_filepath("../etc/passwd")


def test_get_annotated_filepath_dotdot_with_annotation_raises(sandbox):
    with pytest.raises(ValueError):
        folder_paths.get_annotated_filepath("../../etc/passwd [output]")


def test_get_annotated_filepath_absolute_escape_raises(sandbox):
    with pytest.raises(ValueError):
        folder_paths.get_annotated_filepath("/etc/passwd")


# ---------------------------------------------------------------------------
# exists_annotated_filepath()
# ---------------------------------------------------------------------------


def test_exists_annotated_filepath_existing_legit_file(sandbox):
    real = os.path.join(sandbox["input"], "real.png")
    with open(real, "w") as f:
        f.write("data")
    assert folder_paths.exists_annotated_filepath("real.png") is True


def test_exists_annotated_filepath_traversal_returns_false(sandbox):
    """A traversal name must return False without raising and without probing
    outside the base directory (must never reach os.path.exists for the escape).
    """
    # /etc/passwd exists on POSIX; the function must still report False because
    # the resolved path escapes the input directory.
    assert folder_paths.exists_annotated_filepath("../../../../../../etc/passwd") is False


def test_exists_annotated_filepath_absolute_returns_false(sandbox):
    assert folder_paths.exists_annotated_filepath("/etc/passwd") is False
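The containment rule these tests pin can be sketched roughly as follows. This is a minimal illustration and not the `folder_paths` implementation: `is_within` and `resolve_annotated` are hypothetical names, the `dirs` mapping is an assumption, and the ` [input]` / ` [output]` / ` [temp]` suffix handling is inferred only from the test inputs above.

```python
import os


def is_within(base: str, target: str) -> bool:
    """Return True if target resolves to base itself or to a path beneath it."""
    base_real = os.path.realpath(base)
    target_real = os.path.realpath(target)
    try:
        return os.path.commonpath([base_real, target_real]) == base_real
    except ValueError:
        # Raised on Windows when the two paths live on different drives.
        return False


def resolve_annotated(name: str, dirs: dict, default: str = "input") -> str:
    """Resolve 'file [kind]' against dirs[kind]; reject anything that escapes."""
    kind = default
    for candidate in dirs:
        suffix = f" [{candidate}]"
        if name.endswith(suffix):
            name, kind = name[: -len(suffix)], candidate
            break
    base = dirs[kind]
    # os.path.join discards base when name is absolute, so "/etc/passwd"
    # is caught by the containment check just like "../" traversal.
    path = os.path.join(base, name)
    if not is_within(base, path):
        raise ValueError(f"path escapes the {kind} directory")
    return path
```

An existence probe in the same spirit would call the containment check first and return False on an escape, so that the filesystem is never queried for a path outside the base directory.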


@@ -1,147 +0,0 @@
"""
CI unit tests for FIX #4 of GHSA-779p-m5rp-r4h4.

Stored-XSS hardening on GET /userdata/{file} in app/user_manager.py.

User data files are arbitrary user-supplied content and must never render
inline in the app origin. The getuserdata handler:
  - forces Content-Type to application/octet-stream for any type in
    folder_paths.DANGEROUS_CONTENT_TYPES (text/html, image/svg+xml,
    text/javascript, ...),
  - sets X-Content-Type-Options: nosniff,
  - sets Content-Disposition: attachment.

These tests pre-create files in tmp_path and GET them back, asserting the
secure response headers. They mirror the aiohttp_client pattern in
tests-unit/prompt_server_test/user_manager_test.py.
"""

import pytest
import os
from aiohttp import web
from app.user_manager import UserManager

pytestmark = (
    pytest.mark.asyncio
)  # This applies the asyncio mark to all test functions in the module


@pytest.fixture
def user_manager(tmp_path):
    um = UserManager()
    um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
        tmp_path, file
    ) if file else tmp_path
    return um


@pytest.fixture
def app(user_manager):
    app = web.Application()
    routes = web.RouteTableDef()
    user_manager.add_routes(routes)
    app.add_routes(routes)
    return app


async def test_html_served_as_octet_stream(aiohttp_client, app, tmp_path):
    (tmp_path / "evil.html").write_text(
        "<script>console.log('xss-marker-ghsa-779p')</script>"
    )
    client = await aiohttp_client(app)
    resp = await client.get("/userdata/evil.html")
    assert resp.status == 200
    ct = resp.headers.get("Content-Type", "")
    # The load-bearing assertion: a .html file must NOT be served as text/html.
    assert "text/html" not in ct.lower(), (
        f"Content-Type {ct!r} would let a browser render/execute the file (stored XSS)."
    )
    assert ct == "application/octet-stream"
    assert resp.headers.get("X-Content-Type-Options") == "nosniff"
    assert "attachment" in resp.headers.get("Content-Disposition", "")


async def test_svg_served_as_octet_stream(aiohttp_client, app, tmp_path):
    (tmp_path / "evil.svg").write_text(
        '<?xml version="1.0"?>'
        '<svg xmlns="http://www.w3.org/2000/svg">'
        '<script>console.log("xss-marker-ghsa-779p")</script>'
        "</svg>"
    )
    client = await aiohttp_client(app)
    resp = await client.get("/userdata/evil.svg")
    assert resp.status == 200
    ct = resp.headers.get("Content-Type", "")
    # SVG can carry inline <script>; it must not be served as image/svg+xml.
    assert "svg" not in ct.lower(), (
        f"Content-Type {ct!r} would let a browser render the SVG and execute embedded scripts."
    )
    assert ct == "application/octet-stream"
    assert resp.headers.get("X-Content-Type-Options") == "nosniff"
    assert "attachment" in resp.headers.get("Content-Disposition", "")


async def test_js_served_as_octet_stream(aiohttp_client, app, tmp_path):
    (tmp_path / "evil.js").write_text("alert('xss-marker-ghsa-779p')")
    client = await aiohttp_client(app)
    resp = await client.get("/userdata/evil.js")
    assert resp.status == 200
    ct = resp.headers.get("Content-Type", "").lower()
    # Must not be served as any executable JavaScript content type.
    assert "javascript" not in ct, (
        f"Content-Type {ct!r} is an executable JS type."
    )
    assert "ecmascript" not in ct, (
        f"Content-Type {ct!r} is an executable JS type."
    )
    assert ct == "application/octet-stream"
    assert resp.headers.get("X-Content-Type-Options") == "nosniff"
    assert "attachment" in resp.headers.get("Content-Disposition", "")


async def test_xml_dialect_served_as_octet_stream(aiohttp_client, app, tmp_path):
    """An XML dialect outside the original blocklist (.xslt -> application/xslt+xml)
    must still be forced to download. This pins the normalised *+xml family rule
    in folder_paths.is_dangerous_content_type(); a plain set-membership test would
    have served this inline."""
    (tmp_path / "evil.xslt").write_text(
        '<?xml version="1.0"?>'
        '<xsl:stylesheet version="1.0" '
        'xmlns:xsl="http://www.w3.org/1999/XSL/Transform">'
        "<!-- xss-marker-ghsa-779p -->"
        "</xsl:stylesheet>"
    )
    client = await aiohttp_client(app)
    resp = await client.get("/userdata/evil.xslt")
    assert resp.status == 200
    ct = resp.headers.get("Content-Type", "")
    assert ct == "application/octet-stream", (
        f"Content-Type {ct!r}: an *+xml dialect must be forced to octet-stream "
        f"(it can carry inline script via stylesheet/entity tricks)."
    )
    assert resp.headers.get("X-Content-Type-Options") == "nosniff"
    assert "attachment" in resp.headers.get("Content-Disposition", "")


async def test_benign_txt_still_served(aiohttp_client, app, tmp_path):
    (tmp_path / "note.txt").write_text("just a harmless note")
    client = await aiohttp_client(app)
    resp = await client.get("/userdata/note.txt")
    assert resp.status == 200
    assert await resp.text() == "just a harmless note"
    ct = resp.headers.get("Content-Type", "")
    # text/plain is not in the dangerous set, so it is acceptable here. The
    # defence-in-depth headers must still be present regardless.
    assert "text/plain" in ct.lower()
    assert resp.headers.get("X-Content-Type-Options") == "nosniff"
    assert "attachment" in resp.headers.get("Content-Disposition", "")
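The header policy asserted by the tests above can be summarised in a small helper. This is a hedged sketch rather than the `getuserdata` handler itself: `hardened_headers` is a hypothetical name, `_ACTIVE_TYPES` is a stand-in subset and not the real `folder_paths.DANGEROUS_CONTENT_TYPES`, and guessing the type with `mimetypes.guess_type` is an assumption about how the content type is obtained.

```python
import mimetypes

# Stand-in blocklist for illustration only; the real set lives in folder_paths.
_ACTIVE_TYPES = {
    "text/html",
    "image/svg+xml",
    "text/javascript",
    "application/javascript",
}


def hardened_headers(filename: str) -> dict:
    """Build response headers that prevent inline rendering of user files."""
    guessed, _ = mimetypes.guess_type(filename)
    content_type = guessed or "application/octet-stream"
    if content_type.lower() in _ACTIVE_TYPES:
        # Active content is downgraded so a browser will not render or run it.
        content_type = "application/octet-stream"
    return {
        "Content-Type": content_type,
        # Applied to every response, including benign types such as text/plain.
        "X-Content-Type-Options": "nosniff",
        "Content-Disposition": "attachment",
    }
```

The two defence-in-depth headers are set unconditionally, which is why the benign `.txt` test still expects them even though its Content-Type is left alone.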


@@ -1,138 +0,0 @@
"""CI unit guard for FIX #5 of GHSA-779p-m5rp-r4h4 — the /view forced-download set.

Vuln #5 was stored XSS via SVG upload: the /view endpoint's Content-Type
blocklist covered text/html, text/javascript, etc. but was missing
image/svg+xml, so an uploaded SVG carrying an inline <script> was served as
image/svg+xml and executed in the page origin when rendered.

The /view forced-download decision lives in the view_image closure registered by
server.PromptServer.add_routes (server.py ~line 596), which calls
`folder_paths.is_dangerous_content_type(content_type)` (a normalising check that
strips charset/boundary parameters and casing and folds in the whole */xml and
*+xml dialect family) rather than a bypassable raw
`content_type in folder_paths.DANGEROUS_CONTENT_TYPES` membership test. On a match
it rewrites the response to application/octet-stream with a
Content-Disposition: attachment header. server.py cannot be imported in a unit
test (importing it spins up the full PromptServer/aiohttp app and its global side
effects), so these tests pin the underlying dangerous-content data
(folder_paths.DANGEROUS_CONTENT_TYPES) and the normalising is_dangerous_content_type()
helper that the closure actually calls.

The end-to-end /view assertion (upload an SVG, GET /view, confirm the response
is not served as image/svg+xml) lives in the live POC at
.security/pocs/test_security_ghsa_779p.py::TestViewSvgContentType, which
requires a running server. This file is the fast, server-free CI guard on the
set contents so the blocklist can't silently regress.
"""

import folder_paths


# Active/renderable content types that must be forced to download. Each of these
# can carry an inline <script> (or otherwise execute) in the page origin if a
# browser renders it. image/svg+xml is the original missing item that caused
# vuln #5.
DANGEROUS = [
    'image/svg+xml',
    'application/xml',
    'text/xml',
    'text/html',
    'text/html-sandboxed',
    'application/xhtml+xml',
    'text/javascript',
    'application/javascript',
    'application/x-javascript',
    'application/ecmascript',
    'text/css',
]

# Benign image types that browsers display inline and that must keep rendering;
# forcing these to download would break legitimate previews.
BENIGN_INLINE_IMAGES = [
    'image/png',
    'image/jpeg',
    'image/webp',
    'image/gif',
]


def test_dangerous_content_types_is_a_set():
    assert isinstance(folder_paths.DANGEROUS_CONTENT_TYPES, set)


def test_svg_is_in_the_blocklist():
    """The specific item whose absence caused vuln #5."""
    assert 'image/svg+xml' in folder_paths.DANGEROUS_CONTENT_TYPES, (
        "image/svg+xml missing from DANGEROUS_CONTENT_TYPES; this is exactly "
        "the regression that reopens GHSA-779p-m5rp-r4h4 vuln #5 (stored XSS "
        "via SVG upload on /view)."
    )


def test_all_dangerous_types_present():
    missing = [ct for ct in DANGEROUS if ct not in folder_paths.DANGEROUS_CONTENT_TYPES]
    assert not missing, (
        f"DANGEROUS_CONTENT_TYPES is missing required active/renderable types: "
        f"{missing}. The /view closure only forces a download for content types "
        f"in this set; anything missing here is served inline and can execute."
    )


def test_benign_inline_image_types_absent():
    leaked = [ct for ct in BENIGN_INLINE_IMAGES if ct in folder_paths.DANGEROUS_CONTENT_TYPES]
    assert not leaked, (
        f"Benign inline-displayable image types found in DANGEROUS_CONTENT_TYPES: "
        f"{leaked}. Forcing these to download would break legitimate image "
        f"previews in /view; they must keep rendering inline."
    )


# ---------------------------------------------------------------------------
# is_dangerous_content_type(): the normalising check the /view and /userdata
# handlers now call instead of a raw `in DANGEROUS_CONTENT_TYPES` membership
# test. An exact-string membership test was bypassable with a charset parameter
# or odd casing, and missed the wider XML dialect family; these tests pin the
# normalisation so that bypass can't reopen.
# ---------------------------------------------------------------------------


def test_function_matches_plain_dangerous_types():
    for ct in DANGEROUS:
        assert folder_paths.is_dangerous_content_type(ct) is True, ct


def test_function_strips_parameters_and_casing():
    """A charset/boundary parameter or casing must not slip a type past the check.
    This is the bypass surfaced by review: the /view blake3 branch can serve an
    attacker-controlled, unvalidated asset mime_type like 'text/html; charset=utf-8',
    which an exact-string set test missed.
    """
    for ct in (
        'text/html; charset=utf-8',
        'TEXT/HTML',
        'Text/HTML; charset=UTF-8',
        'image/svg+xml; charset=utf-8',
        ' text/html ',
    ):
        assert folder_paths.is_dangerous_content_type(ct) is True, ct


def test_function_covers_xml_dialect_family():
    """Any *+xml / */xml dialect is dangerous without enumerating each one."""
    for ct in (
        'application/xslt+xml',
        'application/rss+xml',
        'application/atom+xml',
        'application/rdf+xml',
        'application/mathml+xml',
        'message/rfc822',
    ):
        assert folder_paths.is_dangerous_content_type(ct) is True, ct


def test_function_allows_benign_and_empty():
    for ct in BENIGN_INLINE_IMAGES + ['application/octet-stream', 'text/plain']:
        assert folder_paths.is_dangerous_content_type(ct) is False, ct
    # None / empty (mimetypes.guess_type miss) must not be treated as dangerous.
    assert folder_paths.is_dangerous_content_type(None) is False
    assert folder_paths.is_dangerous_content_type('') is False
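The normalisation these tests describe (drop parameters, ignore case and surrounding whitespace, then match either an explicit set or the XML family) might look roughly like the sketch below. It is an assumption-laden illustration, not the real `folder_paths.is_dangerous_content_type()`: `looks_dangerous` and `EXPLICIT_TYPES` are hypothetical names, the set is an illustrative subset, and the sketch does not model whatever rule makes the real helper flag `message/rfc822`.

```python
from typing import Optional

# Illustrative subset; not the contents of folder_paths.DANGEROUS_CONTENT_TYPES.
EXPLICIT_TYPES = frozenset({
    "text/html",
    "text/html-sandboxed",
    "text/javascript",
    "application/javascript",
    "application/x-javascript",
    "application/ecmascript",
    "text/css",
})


def looks_dangerous(content_type: Optional[str]) -> bool:
    """Return True if the media type could render or execute in a browser."""
    if not content_type:
        # None or '' (for example a mimetypes.guess_type miss) is not dangerous.
        return False
    # Keep only the type/subtype part: drop "; charset=..." style parameters,
    # surrounding whitespace, and casing differences.
    essence = content_type.split(";", 1)[0].strip().lower()
    if essence in EXPLICIT_TYPES:
        return True
    # Fold in the whole XML family without enumerating every dialect.
    return essence.endswith("+xml") or essence.endswith("/xml")
```

Matching on the normalised essence rather than the raw header string is what closes the `'text/html; charset=utf-8'` bypass, since an exact-string set lookup would treat that value as unknown and serve it inline.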