mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-04 19:48:33 +08:00
Compare commits
21 Commits
model_down
...
temp_pr
| Author | SHA1 | Date | |
|---|---|---|---|
| a78019266f | |||
| f5c4bb1f02 | |||
| 1073a74976 | |||
| de1b8f3e8d | |||
| 77917ed3a6 | |||
| a04ebe05c2 | |||
| 9764381998 | |||
| 1e04ced089 | |||
| 96e0e3585b | |||
| 35c1470935 | |||
| 694815f498 | |||
| 92594ca84c | |||
| 2c935de1b1 | |||
| dd17debce5 | |||
| 50e5270b86 | |||
| bb131be9e8 | |||
| 6fca64780c | |||
| 6e11828d10 | |||
| b70944e710 | |||
| 1c59659a2f | |||
| d395813bcd |
294
AGENTS.md
Normal file
294
AGENTS.md
Normal file
@ -0,0 +1,294 @@
|
||||
## Engineering Style
|
||||
|
||||
- Keep changes small and direct. Most fixes should touch the narrowest code path
|
||||
that explains the bug, performance issue, dtype issue, model-format issue, or
|
||||
user-facing behavior.
|
||||
- Change the least amount of files possible. A change that touches many files is
|
||||
more likely to be a bad change than a good one unless the broader scope is
|
||||
directly required.
|
||||
- Prefer practical fixes over broad architecture work. Add abstractions only
|
||||
when they remove real repeated logic or match an existing ComfyUI pattern.
|
||||
- Prefer fewer dependencies. Do not add new dependencies to ComfyUI unless they
|
||||
are absolutely necessary.
|
||||
- Delete obsolete code aggressively when newer infrastructure makes it useless.
|
||||
Remove dead fallbacks, migration paths, unused options, debug prints, and
|
||||
compatibility branches that are no longer needed. Do not leave dead branches,
|
||||
unreachable code, or functions that are never called. If code is not
|
||||
necessary for the current behavior, remove it.
|
||||
- Revert or disable problematic behavior quickly when it breaks users. It is
|
||||
better to remove a broken feature path than keep a complicated partial fix.
|
||||
- Preserve existing APIs, node names, model-loading behavior, file layout, and
|
||||
workflow compatibility unless the change is explicitly about replacing them.
|
||||
- Code must look hand-written for this repository. Changes that read like
|
||||
generic AI-generated code will be rejected automatically: unnecessary helper
|
||||
layers, vague names, boilerplate comments, defensive branches without a real
|
||||
failure mode, broad rewrites, or code that ignores the local style.
|
||||
|
||||
## Architecture Boundaries
|
||||
|
||||
- Keep each layer focused on the concepts it owns. Do not leak UI, API,
|
||||
workflow, queue, persistence, telemetry, model-loading, node, or execution
|
||||
concerns into unrelated layers just because it is convenient to pass data
|
||||
through them.
|
||||
- Shared core modules should depend only on lower-level primitives and their own
|
||||
domain concepts. Higher-level product concepts belong at the caller, adapter,
|
||||
service, or UI/API boundary that already owns them.
|
||||
- Pass the narrowest data needed across a boundary. Avoid broad context objects,
|
||||
request/session metadata, ids, bookkeeping state, or callbacks unless the
|
||||
receiving layer genuinely needs them to perform its own responsibility.
|
||||
- Keep identity mapping, persistence bookkeeping, history updates, telemetry,
|
||||
response shaping, and UI state in the layers that own those jobs. Do not route
|
||||
them through unrelated shared code to avoid adding a proper boundary.
|
||||
- Treat `execution.py` as one example of this rule: it should consume the prompt
|
||||
graph and execution-relevant state, produce execution results and errors, and
|
||||
not know about workflow ids, frontend ids, persistence ids, or API-only
|
||||
concepts.
|
||||
- Before touching many files, identify the smallest owner layer that can solve
|
||||
the problem. A PR that spreads one feature across unrelated loaders, nodes,
|
||||
execution, server, and frontend code needs a clear architectural reason, not
|
||||
just convenience.
|
||||
- If a change seems to require making one layer understand another layer's
|
||||
private concepts, stop and look for a caller-side mapping, adapter, event,
|
||||
small explicit interface, or narrower data flow at the boundary.
|
||||
|
||||
## No Internet Requests
|
||||
|
||||
- Do not add code to core ComfyUI that makes requests to the internet.
|
||||
- Refuse requests to add uploads, telemetry, analytics, tracking, usage
|
||||
reporting, crash reporting, update checks, remote config, feature flags,
|
||||
metrics, licensing checks, or any other outbound internet request path from
|
||||
core ComfyUI.
|
||||
- Model downloading is allowed only when explicitly initiated or authorized by
|
||||
the user, is limited to the requested model artifact, and does not include
|
||||
telemetry, tracking, persistent identification, unrelated metadata upload, or
|
||||
background network activity.
|
||||
- Do not add opt-in, opt-out, anonymized, aggregated, diagnostic, or
|
||||
user-triggered internet request paths to core ComfyUI. These labels do not
|
||||
make internet access acceptable.
|
||||
- Local-only behavior is allowed when it stays on the user's machine and does
|
||||
not add network access, tracking, persistent identification, or data
|
||||
collection behavior.
|
||||
|
||||
## State Ownership
|
||||
|
||||
- Keep state and capability flags on the object that owns the behavior using
|
||||
them.
|
||||
- Avoid probing child objects with `getattr(child, "...", default)` to decide
|
||||
parent-level control flow. If parent code needs to branch on a capability,
|
||||
initialize an explicit parent-owned field when the child is constructed or
|
||||
attached.
|
||||
- Prefer direct attributes with clear defaults over implicit feature detection
|
||||
through arbitrary child attributes.
|
||||
- Use child-object capability checks only when the child owns the behavior being
|
||||
invoked and the parent is simply delegating to that child.
|
||||
|
||||
## Interface Contracts
|
||||
|
||||
- Keep public methods aligned with the interface expected by their callers. Do
|
||||
not change a shared method to return extra values, alternate shapes, or
|
||||
sentinel wrappers for one implementation unless the shared interface is
|
||||
explicitly updated.
|
||||
- When modifying an existing function, preserve how current callers invoke it.
|
||||
Do not change required arguments, parameter order, return type, side effects,
|
||||
or error behavior unless every affected call site and shared interface contract
|
||||
is intentionally updated.
|
||||
- Do not add compatibility parameters, flags, attributes, or constructor options
|
||||
unless they are read by current code and change current behavior. Remove
|
||||
pass-through or stored-but-unused values instead of preserving upstream or
|
||||
deprecated API baggage.
|
||||
- If an implementation needs auxiliary values for its own workflow, expose them
|
||||
through a private helper or a clearly named implementation-specific method
|
||||
instead of overloading the public method's return contract.
|
||||
- Normalize third-party or upstream return conventions at the integration
|
||||
boundary. Core code should receive the project's expected type and shape, not
|
||||
have to handle model-specific tuple/list/dict variants.
|
||||
- Avoid caller-side unwrapping such as `out = out[0]` unless the called
|
||||
interface is documented to return that structure.
|
||||
|
||||
## Autograd and Model Freezing
|
||||
|
||||
- Do not add `torch.no_grad`, `torch.inference_mode`, or inference-mode helper
|
||||
wrappers in ComfyUI code. The only allowed inference-mode-related use is
|
||||
disabling a globally set inference mode when a training path needs gradients.
|
||||
- Do not add freeze, unfreeze, or trainability toggles to model classes. ComfyUI
|
||||
models are always treated as frozen for inference, so explicit freeze
|
||||
functionality is redundant and should not be added.
|
||||
- Remove training-only behavior such as dropout from inference model code, but
|
||||
preserve checkpoint and state-dict compatibility when doing so. If deleting a
|
||||
module would change state-dict keys, module ordering, or checkpoint loading
|
||||
behavior, replace it with a no-op such as `nn.Identity` instead of removing the
|
||||
slot outright.
|
||||
|
||||
## Python Style
|
||||
|
||||
- Keep imports at module scope. Avoid inline imports unless they are already part
|
||||
of an established optional-backend probe or are needed to avoid an import
|
||||
cycle.
|
||||
- Do not add unnecessary `try`/`except` blocks. Use them for optional dependency,
|
||||
platform, or backend capability detection only when the program has a useful
|
||||
fallback. Prefer specific exception types when changing new code.
|
||||
- Remove any workarounds for PyTorch versions that ComfyUI no longer officially
|
||||
supports. Deprecated workarounds include catching an exception and rerunning
|
||||
the same op with the input cast to float. If a workaround does not have a
|
||||
comment naming the exact PyTorch version or versions that still need it,
|
||||
remove it.
|
||||
- Let unsupported model formats, invalid quantization metadata, and bad states
|
||||
fail with clear errors instead of silently producing lower quality output.
|
||||
- Match the existing local style in the file you edit. This codebase tolerates
|
||||
long lines, simple helper functions, module-level state, and direct tensor
|
||||
operations when they make the code easier to follow.
|
||||
- Keep comments sparse and useful. Strip useless comments that restate the code
|
||||
or describe obvious behavior. Short TODOs are fine when they name the concrete
|
||||
missing follow-up.
|
||||
|
||||
## Model, Device, and Memory Behavior
|
||||
|
||||
- Treat dtype, device placement, VRAM usage, and offloading behavior as core
|
||||
correctness concerns. Check CPU, CUDA, ROCm, MPS, DirectML, XPU, NPU, and low
|
||||
VRAM implications when touching shared execution or loading code.
|
||||
- Prefer native ComfyUI formats and existing quantization/offload helpers over
|
||||
adding parallel code paths. Use `comfy.quant_ops`, `comfy.model_management`,
|
||||
`comfy.memory_management`, `comfy.pinned_memory`, `comfy_aimdo`, and
|
||||
`comfy-kitchen` helpers where they already solve the problem.
|
||||
- Use optimized comfy-kitchen ops in places where they improve performance
|
||||
without changing the expected dtype, device, memory, or interface behavior.
|
||||
- All models should use the optimized attention function selected by ComfyUI.
|
||||
Treat optimized backend functions, dispatch helpers, and capability-selected
|
||||
callables as opaque. Higher-level code must not inspect function identity,
|
||||
names, modules, or implementation details to decide behavior.
|
||||
- Apply the same opacity rule to similar patterns beyond attention: callers
|
||||
should depend on the documented interface and result contract, not on which
|
||||
backend implementation was selected underneath.
|
||||
- Do not use custom inference ops that only duplicate an existing op while
|
||||
upcasting to float32, such as custom RMSNorm variants. Use the generic ComfyUI
|
||||
ops and/or native torch ops instead.
|
||||
- If a model class `__init__` has an `operations` parameter, assume
|
||||
`operations` is never `None`. Do not add fallback branches or default torch
|
||||
ops for a missing `operations` object.
|
||||
- Do not add unnecessary parameters to model, model block, or model ops related
|
||||
classes. Constructor and forward signatures should carry only values that are
|
||||
actually needed by that object for inference.
|
||||
- Reuse existing model classes, blocks, ops, and helper modules when appropriate.
|
||||
Before implementing a new version of a model component, search the existing
|
||||
model code for a class or helper that already provides the behavior.
|
||||
- Model detection code that inspects linear weight shapes should only use the
|
||||
first dimension. The second dimension may be half the original size for
|
||||
NVFP4 or other 4-bit quantized models.
|
||||
- Avoid adding `einops` usage in core inference code. Use native torch tensor
|
||||
ops such as `reshape`, `view`, `permute`, `transpose`, `flatten`, `unflatten`,
|
||||
`unsqueeze`, and `squeeze` instead.
|
||||
- Do not use tensors as general-purpose Python data structures. Keep metadata,
|
||||
bookkeeping, counters, flags, shape math, padding math, index planning, memory
|
||||
estimates, and control-flow decisions in plain Python values unless the data
|
||||
must participate directly in tensor computation. Do not create tensors for
|
||||
structural metadata that is only used for Python-side control flow. Sequence
|
||||
lengths, cumulative offsets, split indices, window counts, slice boundaries,
|
||||
and repeat counts should be kept as Python ints/lists from the point they are
|
||||
computed. Do not build them as CPU/GPU tensors and then cast, move, validate,
|
||||
or convert them back to Python for `split`, `tensor_split`, indexing plans,
|
||||
loops, or cache keys. Avoid creating temporary tensors just to use tensor
|
||||
methods for scalar or structural calculations.
|
||||
- Avoid unnecessary casts and transfers. Preserve the intended compute dtype,
|
||||
storage dtype, bias dtype, and original tensor shape metadata.
|
||||
- Keep model-native latent layout handling inside the model or latent-format
|
||||
owner, not in helper nodes. Do not collapse, expand, pack, or unpack latent
|
||||
dimensions in nodes or other caller-side adapters just to satisfy a model
|
||||
forward; the model path should consume and return the native latent shape for
|
||||
that model family.
|
||||
- Assume inputs to the main model forward are already in the compute dtype by
|
||||
default, except integer inputs such as some model timestep tensors. Do not add
|
||||
defensive or convenience casts in model code; it is better for invalid dtype
|
||||
plumbing to error clearly than to hide it with unnecessary casts.
|
||||
- Raw model parameters that are not owned by an op and may be initialized in a
|
||||
dtype different from the compute dtype should be cast at use in forward or
|
||||
inference code with `comfy.ops.cast_to_input` or
|
||||
`comfy.model_management.cast_to` to avoid dtype mismatches.
|
||||
- Model code should not care what dtype it is initialized in, and model
|
||||
`__init__` methods should not contain workarounds for specific dtypes. Dtype
|
||||
workaround code, such as making a model work with fp16 compute, belongs in the
|
||||
execution or model-management layer that owns compute policy.
|
||||
- Model code should not perform unnecessary device-to-CPU or CPU-to-device
|
||||
transfers. New allocations must be created on the correct device and dtype;
|
||||
never allocate on CPU and then move to GPU, or allocate in one dtype and then
|
||||
convert to another.
|
||||
- Model code itself should not perform memory management. Loading, unloading,
|
||||
offloading, device movement, VRAM policy, cache lifetime, and cleanup belong
|
||||
in the relevant model-management and execution layers, not inside model
|
||||
implementations.
|
||||
- Do not add global, module-level, class-level, singleton, or model-owned stores
|
||||
for tensors or other large memory that persist across executions. Temporary
|
||||
caches must be scoped to a single execution or forward/encode/decode call:
|
||||
allocate them in the owning top-level call, pass them explicitly through the
|
||||
call stack, and let them be discarded when that call returns.
|
||||
- Follow the Wan VAE temporal cache pattern for temporary caches: create a local
|
||||
cache such as `feat_map` for the encode/decode operation, pass it into the
|
||||
blocks that need it, and do not retain it on the model or in global state.
|
||||
- In model init code, prefer `torch.empty` for parameter/buffer placeholders
|
||||
that are populated from the model state dict instead of zero-initializing with
|
||||
`torch.zeros` or similar. If an allocation is not loaded from the state dict
|
||||
and is useless for inference, do not include it.
|
||||
- `nn.Parameter` tensors that are stored in and populated from the model state
|
||||
dict should be initialized with `torch.empty`, not with zero, random, or
|
||||
otherwise meaningful initialization.
|
||||
- Model initialization should describe module structure, not fabricate
|
||||
checkpoint-owned tensor contents. Parameters and buffers that are loaded from
|
||||
the state dict must not be manually initialized, reassigned, or filled with
|
||||
fallback values unless that value is actually used when no checkpoint key
|
||||
exists.
|
||||
- When slicing large tensors, copy the slice if the sliced tensor's lifetime
|
||||
exceeds the current function scope. Do not keep a long-lived view into a large
|
||||
backing tensor when a smaller copy would release memory sooner.
|
||||
- Use fused or compound torch operations such as `addcmul` when they naturally
|
||||
match the math. Reducing Python and torch dispatch overhead is a valid
|
||||
optimization when it does not obscure the code or change dtype/device
|
||||
behavior.
|
||||
- Avoid caches that persist across different executions as much as possible.
|
||||
Persistent caches are acceptable only when they use a very minimal amount of
|
||||
memory and have a clear ownership and invalidation story.
|
||||
- When optimizing, favor small measurable changes: fewer allocations, fewer
|
||||
device transfers, less peak memory, better batching, or use of a faster
|
||||
existing backend op.
|
||||
|
||||
## Nodes and User-Facing Behavior
|
||||
|
||||
- Follow existing node conventions: `INPUT_TYPES`, `RETURN_TYPES`, `FUNCTION`,
|
||||
`CATEGORY`, and registration through the local mapping used by that file.
|
||||
- Keep node changes backward compatible by default. Add inputs with sensible
|
||||
defaults and avoid changing output types unless the request requires it.
|
||||
- Model implementations should add the minimal number of ComfyUI nodes required
|
||||
to run the model. Reuse existing nodes as much as possible; adapting the model
|
||||
to work with existing nodes is strongly preferred over creating new nodes.
|
||||
- Nodes should output only values they own. Do not add pass-through outputs for
|
||||
workflow convenience unless the node is explicitly an output node. Existing
|
||||
models, latents, conditioning, or other inputs should flow directly to the
|
||||
next consumer instead of being re-emitted unchanged.
|
||||
- Nodes should expose only inputs they actually read to produce current
|
||||
behavior. Do not add placeholder, pass-through, compatibility, or
|
||||
workflow-shaping inputs that are ignored or could flow directly to another
|
||||
node.
|
||||
- Node-level code must not patch model code directly. Any node behavior that
|
||||
modifies, wraps, hooks, or changes model behavior must go through the model
|
||||
patcher class instead of reaching into model internals.
|
||||
- The official mascot of ComfyUI is a very cute anime girl with massive fennec
|
||||
ears, a big fluffy tail, long blonde wavy hair, and blue eyes. Feel free to
|
||||
use her in ComfyUI materials, UI text, examples, tests, generated assets, or
|
||||
comments, but do not disrespect her.
|
||||
- Warning and info messages should be short and actionable. Remove noisy or
|
||||
misleading messages rather than adding more logging.
|
||||
- Documentation and README edits should be concise, factual, and tied to the
|
||||
changed behavior.
|
||||
|
||||
## Commit and Review Habits
|
||||
|
||||
- If asked to write commit messages, use short direct subjects like the existing
|
||||
history: `Fix ...`, `Add ...`, `Support ...`, `Remove ...`, `Update ...`,
|
||||
`Make ...`, `Use ...`, `Disable ...`, `Bump ...`, or `Revert ...`.
|
||||
- Keep PR descriptions short and reviewable. State the problem, the behavioral
|
||||
change, and the tests run; avoid long narrative explanations, implementation
|
||||
diaries, or exhaustive file-by-file summaries unless the reviewer explicitly
|
||||
needs that context.
|
||||
- Prefer one coherent behavioral change per commit. Dependency pins, tests, and
|
||||
the code that needs them may be in the same commit when they are inseparable.
|
||||
- In reviews, prioritize real user impact: crashes, wrong dtype/device behavior,
|
||||
memory regressions, broken model loading, workflow incompatibility, and noisy
|
||||
or misleading user-facing output.
|
||||
@ -1,115 +0,0 @@
|
||||
"""
|
||||
Download manager schema.
|
||||
|
||||
Adds the three tables that back the server-side model download manager
|
||||
: transient job/queue state (``downloads`` + per-segment
|
||||
``download_segments``) and one-API-key-per-host auth (``host_credentials``).
|
||||
|
||||
Revision ID: 0005_download_manager
|
||||
Revises: 0004_drop_tag_type
|
||||
Create Date: 2026-06-27
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "0005_download_manager"
|
||||
down_revision = "0004_drop_tag_type"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"downloads",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("url", sa.Text(), nullable=False),
|
||||
sa.Column("final_url", sa.Text(), nullable=True),
|
||||
sa.Column("model_id", sa.String(length=1024), nullable=False),
|
||||
sa.Column("dest_path", sa.Text(), nullable=False),
|
||||
sa.Column("temp_path", sa.Text(), nullable=False),
|
||||
sa.Column("status", sa.String(length=16), nullable=False),
|
||||
sa.Column("priority", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("total_bytes", sa.BigInteger(), nullable=True),
|
||||
sa.Column("bytes_done", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.Column("etag", sa.String(length=512), nullable=True),
|
||||
sa.Column("last_modified", sa.String(length=128), nullable=True),
|
||||
sa.Column(
|
||||
"accept_ranges", sa.Boolean(), nullable=False, server_default=sa.text("false")
|
||||
),
|
||||
sa.Column("expected_sha256", sa.String(length=64), nullable=True),
|
||||
sa.Column("credential_id", sa.String(length=36), nullable=True),
|
||||
sa.Column(
|
||||
"allow_any_extension",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
sa.Column("attempts", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("error", sa.Text(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
||||
sa.CheckConstraint("bytes_done >= 0", name="ck_downloads_bytes_done_nonneg"),
|
||||
sa.CheckConstraint(
|
||||
"total_bytes IS NULL OR total_bytes >= 0",
|
||||
name="ck_downloads_total_bytes_nonneg",
|
||||
),
|
||||
)
|
||||
op.create_index("ix_downloads_status", "downloads", ["status"])
|
||||
op.create_index("ix_downloads_priority", "downloads", ["priority"])
|
||||
op.create_index("ix_downloads_model_id", "downloads", ["model_id"])
|
||||
|
||||
op.create_table(
|
||||
"download_segments",
|
||||
sa.Column(
|
||||
"download_id",
|
||||
sa.String(length=36),
|
||||
sa.ForeignKey("downloads.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("idx", sa.Integer(), nullable=False),
|
||||
sa.Column("start_offset", sa.BigInteger(), nullable=False),
|
||||
sa.Column("end_offset", sa.BigInteger(), nullable=False),
|
||||
sa.Column("bytes_done", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.PrimaryKeyConstraint("download_id", "idx", name="pk_download_segments"),
|
||||
sa.CheckConstraint("bytes_done >= 0", name="ck_segments_bytes_done_nonneg"),
|
||||
sa.CheckConstraint("end_offset >= start_offset", name="ck_segments_range"),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"host_credentials",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("host", sa.String(length=255), nullable=False),
|
||||
sa.Column(
|
||||
"match_subdomains",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
sa.Column("label", sa.String(length=255), nullable=True),
|
||||
sa.Column(
|
||||
"auth_scheme", sa.String(length=16), nullable=False, server_default="bearer"
|
||||
),
|
||||
sa.Column("header_name", sa.String(length=255), nullable=True),
|
||||
sa.Column("query_param", sa.String(length=255), nullable=True),
|
||||
sa.Column("secret", sa.Text(), nullable=False),
|
||||
sa.Column("secret_last4", sa.String(length=4), nullable=True),
|
||||
sa.Column("enabled", sa.Boolean(), nullable=False, server_default=sa.text("true")),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
||||
)
|
||||
op.create_index(
|
||||
"uq_host_credentials_host", "host_credentials", ["host"], unique=True
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("uq_host_credentials_host", table_name="host_credentials")
|
||||
op.drop_table("host_credentials")
|
||||
|
||||
op.drop_table("download_segments")
|
||||
|
||||
op.drop_index("ix_downloads_model_id", table_name="downloads")
|
||||
op.drop_index("ix_downloads_priority", table_name="downloads")
|
||||
op.drop_index("ix_downloads_status", table_name="downloads")
|
||||
op.drop_table("downloads")
|
||||
@ -306,12 +306,15 @@ async def download_asset_content(request: web.Request) -> web.Response:
|
||||
404, "FILE_NOT_FOUND", "Underlying file not found on disk."
|
||||
)
|
||||
|
||||
_DANGEROUS_MIME_TYPES = {
|
||||
"text/html", "text/html-sandboxed", "application/xhtml+xml",
|
||||
"text/javascript", "text/css",
|
||||
}
|
||||
if content_type in _DANGEROUS_MIME_TYPES:
|
||||
# User-controlled asset content must never render inline in the app origin
|
||||
# (stored XSS via SVG/HTML/XML). Force dangerous types to download and
|
||||
# override any requested inline disposition. Centralised through
|
||||
# folder_paths.is_dangerous_content_type so this can't drift from /view and
|
||||
# /userdata (the previous inline set here omitted image/svg+xml and missed
|
||||
# the charset/casing/+xml-dialect bypasses).
|
||||
if folder_paths.is_dangerous_content_type(content_type):
|
||||
content_type = "application/octet-stream"
|
||||
disposition = "attachment"
|
||||
|
||||
safe_name = (filename or "").replace("\r", "").replace("\n", "")
|
||||
encoded = urllib.parse.quote(safe_name)
|
||||
|
||||
@ -4,11 +4,7 @@ import shutil
|
||||
from app.logger import log_startup_warning
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
from filelock import FileLock, Timeout
|
||||
# NOTE: import the module (not `from ... import args`) so we always read the
|
||||
# live `args` object. Tests reload `comfy.cli_args`, which replaces the module
|
||||
# global; a bound `args` reference would go stale and point at the default
|
||||
# database URL instead of the one configured for the test.
|
||||
import comfy.cli_args
|
||||
from comfy.cli_args import args
|
||||
|
||||
_DB_AVAILABLE = False
|
||||
Session = None
|
||||
@ -25,7 +21,6 @@ try:
|
||||
|
||||
from app.database.models import Base
|
||||
import app.assets.database.models # noqa: F401 — register models with Base.metadata
|
||||
import app.model_downloader.database.models # noqa: F401 — register models with Base.metadata
|
||||
|
||||
_DB_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
@ -62,13 +57,13 @@ def get_alembic_config():
|
||||
|
||||
config = Config(config_path)
|
||||
config.set_main_option("script_location", scripts_path)
|
||||
config.set_main_option("sqlalchemy.url", comfy.cli_args.args.database_url)
|
||||
config.set_main_option("sqlalchemy.url", args.database_url)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_db_path():
|
||||
url = comfy.cli_args.args.database_url
|
||||
url = args.database_url
|
||||
if url.startswith("sqlite:///"):
|
||||
return url.split("///")[1]
|
||||
else:
|
||||
@ -102,7 +97,7 @@ def _is_memory_db(db_url):
|
||||
|
||||
|
||||
def init_db():
|
||||
db_url = comfy.cli_args.args.database_url
|
||||
db_url = args.database_url
|
||||
logging.debug(f"Database URL: {db_url}")
|
||||
|
||||
if _is_memory_db(db_url):
|
||||
|
||||
@ -1,220 +0,0 @@
|
||||
"""aiohttp routes for the download manager.
|
||||
|
||||
Endpoint surface (all under ``/api/download``), mirroring the response
|
||||
envelope used by ``app/assets/api/routes.py``:
|
||||
|
||||
POST /api/download/enqueue
|
||||
GET /api/download
|
||||
POST /api/download/availability
|
||||
POST /api/download/clear
|
||||
POST /api/download/credentials
|
||||
GET /api/download/credentials
|
||||
GET /api/download/credentials/{id}
|
||||
DELETE /api/download/credentials/{id}
|
||||
GET /api/download/{id}
|
||||
DELETE /api/download/{id}
|
||||
POST /api/download/{id}/pause
|
||||
POST /api/download/{id}/resume
|
||||
POST /api/download/{id}/cancel
|
||||
POST /api/download/{id}/priority
|
||||
|
||||
Note on ordering: the static ``credentials`` routes are registered before the
|
||||
dynamic ``/api/download/{id}`` route so a request to ``.../credentials`` is not
|
||||
captured as ``id == "credentials"``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from aiohttp import web
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from app.model_downloader.api import schemas_in, schemas_out
|
||||
from app.model_downloader.credentials.store import (
|
||||
CREDENTIAL_STORE,
|
||||
CredentialValidationError,
|
||||
)
|
||||
from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
|
||||
|
||||
def register_routes(app: web.Application) -> None:
|
||||
"""Wire the download-manager routes into the running aiohttp app."""
|
||||
app.add_routes(ROUTES)
|
||||
|
||||
|
||||
# ----- envelope helpers (same shape as app/assets/api/routes.py) -----
|
||||
|
||||
|
||||
def _error(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
|
||||
return web.json_response(
|
||||
{"error": {"code": code, "message": message, "details": details or {}}},
|
||||
status=status,
|
||||
)
|
||||
|
||||
|
||||
def _ok(payload, status: int = 200) -> web.Response:
|
||||
return web.json_response(payload, status=status)
|
||||
|
||||
|
||||
async def _parse(request: web.Request, model: type[BaseModel]):
|
||||
try:
|
||||
raw = await request.json()
|
||||
except json.JSONDecodeError:
|
||||
return _error(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
try:
|
||||
return model.model_validate(raw)
|
||||
except ValidationError as ve:
|
||||
return _error(400, "INVALID_BODY", "Validation failed.", {"errors": json.loads(ve.json())})
|
||||
|
||||
|
||||
def _from_download_error(e: DownloadError) -> web.Response:
|
||||
return _error(e.http_status, e.code, e.message)
|
||||
|
||||
|
||||
# ----- downloads: collection + enqueue + availability -----
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/enqueue")
|
||||
async def enqueue(request: web.Request) -> web.Response:
|
||||
parsed = await _parse(request, schemas_in.EnqueueRequest)
|
||||
if isinstance(parsed, web.Response):
|
||||
return parsed
|
||||
try:
|
||||
download_id = await DOWNLOAD_MANAGER.enqueue(
|
||||
parsed.url,
|
||||
parsed.model_id,
|
||||
priority=parsed.priority,
|
||||
expected_sha256=parsed.expected_sha256,
|
||||
allow_any_extension=parsed.allow_any_extension,
|
||||
credential_id=parsed.credential_id,
|
||||
)
|
||||
except DownloadError as e:
|
||||
return _from_download_error(e)
|
||||
return _ok({"download_id": download_id, "accepted": True}, status=202)
|
||||
|
||||
|
||||
@ROUTES.get("/api/download")
|
||||
async def list_downloads(request: web.Request) -> web.Response:
|
||||
return _ok({"downloads": await DOWNLOAD_MANAGER.list()})
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/availability")
|
||||
async def availability(request: web.Request) -> web.Response:
|
||||
parsed = await _parse(request, schemas_in.AvailabilityRequest)
|
||||
if isinstance(parsed, web.Response):
|
||||
return parsed
|
||||
return _ok({"models": await DOWNLOAD_MANAGER.availability(parsed.models)})
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/clear")
|
||||
async def clear(request: web.Request) -> web.Response:
|
||||
deleted = await DOWNLOAD_MANAGER.clear()
|
||||
return _ok({"deleted": deleted})
|
||||
|
||||
|
||||
# ----- credentials (secrets are write-only) — must precede /{id} -----
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/credentials")
|
||||
async def upsert_credential(request: web.Request) -> web.Response:
|
||||
parsed = await _parse(request, schemas_in.CredentialUpsertRequest)
|
||||
if isinstance(parsed, web.Response):
|
||||
return parsed
|
||||
try:
|
||||
view = await CREDENTIAL_STORE.upsert(
|
||||
parsed.host,
|
||||
parsed.secret,
|
||||
auth_scheme=parsed.auth_scheme,
|
||||
header_name=parsed.header_name,
|
||||
query_param=parsed.query_param,
|
||||
label=parsed.label,
|
||||
match_subdomains=parsed.match_subdomains,
|
||||
enabled=parsed.enabled,
|
||||
)
|
||||
except CredentialValidationError as e:
|
||||
return _error(400, "INVALID_CREDENTIAL", str(e))
|
||||
return _ok(schemas_out.credential_to_dict(view), status=201)
|
||||
|
||||
|
||||
@ROUTES.get("/api/download/credentials")
|
||||
async def list_credentials(request: web.Request) -> web.Response:
|
||||
views = await CREDENTIAL_STORE.list()
|
||||
return _ok({"credentials": [schemas_out.credential_to_dict(v) for v in views]})
|
||||
|
||||
|
||||
@ROUTES.get("/api/download/credentials/{id}")
|
||||
async def get_credential(request: web.Request) -> web.Response:
|
||||
view = await CREDENTIAL_STORE.get(request.match_info["id"])
|
||||
if view is None:
|
||||
return _error(404, "NOT_FOUND", "No such credential.")
|
||||
return _ok(schemas_out.credential_to_dict(view))
|
||||
|
||||
|
||||
@ROUTES.delete("/api/download/credentials/{id}")
|
||||
async def delete_credential(request: web.Request) -> web.Response:
|
||||
deleted = await CREDENTIAL_STORE.delete(request.match_info["id"])
|
||||
if not deleted:
|
||||
return _error(404, "NOT_FOUND", "No such credential.")
|
||||
return _ok({"deleted": True})
|
||||
|
||||
|
||||
# ----- single download by id (dynamic; registered last) -----
|
||||
|
||||
|
||||
@ROUTES.get("/api/download/{id}")
|
||||
async def get_download(request: web.Request) -> web.Response:
|
||||
view = await DOWNLOAD_MANAGER.status(request.match_info["id"])
|
||||
if view is None:
|
||||
return _error(404, "NOT_FOUND", "No such download.")
|
||||
return _ok(view)
|
||||
|
||||
|
||||
@ROUTES.delete("/api/download/{id}")
|
||||
async def delete_download(request: web.Request) -> web.Response:
|
||||
try:
|
||||
await DOWNLOAD_MANAGER.delete(request.match_info["id"])
|
||||
except DownloadError as e:
|
||||
return _from_download_error(e)
|
||||
return _ok({"deleted": True})
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/{id}/pause")
|
||||
async def pause(request: web.Request) -> web.Response:
|
||||
try:
|
||||
await DOWNLOAD_MANAGER.pause(request.match_info["id"])
|
||||
except DownloadError as e:
|
||||
return _from_download_error(e)
|
||||
return _ok({"ok": True})
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/{id}/resume")
|
||||
async def resume(request: web.Request) -> web.Response:
|
||||
try:
|
||||
await DOWNLOAD_MANAGER.resume(request.match_info["id"])
|
||||
except DownloadError as e:
|
||||
return _from_download_error(e)
|
||||
return _ok({"ok": True})
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/{id}/cancel")
|
||||
async def cancel(request: web.Request) -> web.Response:
|
||||
try:
|
||||
await DOWNLOAD_MANAGER.cancel(request.match_info["id"])
|
||||
except DownloadError as e:
|
||||
return _from_download_error(e)
|
||||
return _ok({"ok": True})
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/{id}/priority")
|
||||
async def set_priority(request: web.Request) -> web.Response:
|
||||
parsed = await _parse(request, schemas_in.PriorityRequest)
|
||||
if isinstance(parsed, web.Response):
|
||||
return parsed
|
||||
try:
|
||||
await DOWNLOAD_MANAGER.set_priority(request.match_info["id"], parsed.priority)
|
||||
except DownloadError as e:
|
||||
return _from_download_error(e)
|
||||
return _ok({"ok": True})
|
||||
@ -1,51 +0,0 @@
|
||||
"""Request schemas for the download manager API.
|
||||
|
||||
Pydantic enforces shape at the boundary; handlers operate only on validated
|
||||
values past that point.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.model_downloader.constants import AUTH_SCHEME_BEARER
|
||||
|
||||
|
||||
class EnqueueRequest(BaseModel):
|
||||
url: str
|
||||
model_id: str
|
||||
priority: int = 0
|
||||
expected_sha256: Optional[str] = None
|
||||
allow_any_extension: bool = False
|
||||
credential_id: Optional[str] = None
|
||||
|
||||
|
||||
class PriorityRequest(BaseModel):
|
||||
priority: int
|
||||
|
||||
|
||||
class AvailabilityRequest(BaseModel):
|
||||
"""``{model_id: url}`` — the URLs declared in the workflow JSON."""
|
||||
|
||||
models: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class CredentialUpsertRequest(BaseModel):
|
||||
host: str
|
||||
secret: str
|
||||
auth_scheme: str = AUTH_SCHEME_BEARER
|
||||
header_name: Optional[str] = None
|
||||
query_param: Optional[str] = None
|
||||
label: Optional[str] = None
|
||||
match_subdomains: bool = False
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
__all__ = [
|
||||
"EnqueueRequest",
|
||||
"PriorityRequest",
|
||||
"AvailabilityRequest",
|
||||
"CredentialUpsertRequest",
|
||||
]
|
||||
@ -1,26 +0,0 @@
|
||||
"""Response helpers for the download manager API.
|
||||
|
||||
The download/status read models are plain dicts produced by the manager. This
|
||||
module only needs to mask credentials for output (the secret is never returned).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.model_downloader.credentials.store import CredentialView
|
||||
|
||||
|
||||
def credential_to_dict(view: CredentialView) -> dict:
|
||||
"""API-safe credential representation — never includes the secret."""
|
||||
return {
|
||||
"id": view.id,
|
||||
"host": view.host,
|
||||
"auth_scheme": view.auth_scheme,
|
||||
"header_name": view.header_name,
|
||||
"query_param": view.query_param,
|
||||
"label": view.label,
|
||||
"match_subdomains": view.match_subdomains,
|
||||
"enabled": view.enabled,
|
||||
"secret_last4": view.secret_last4,
|
||||
"created_at": view.created_at,
|
||||
"updated_at": view.updated_at,
|
||||
}
|
||||
@ -1,47 +0,0 @@
|
||||
"""Shared constants for the download manager.
|
||||
|
||||
Status values are persisted as TEXT in the ``downloads`` table; keep them
|
||||
stable. The lifecycle is:
|
||||
|
||||
queued -> active -> verifying -> completed
|
||||
| |-> paused -> (resume) -> active
|
||||
| |-> failed (network, retryable) -> queued (backoff)
|
||||
|-> cancelled
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Auth schemes for HostCredential
|
||||
AUTH_SCHEME_BEARER = "bearer"
|
||||
AUTH_SCHEME_HEADER = "header"
|
||||
AUTH_SCHEME_QUERY = "query"
|
||||
AUTH_SCHEMES = (AUTH_SCHEME_BEARER, AUTH_SCHEME_HEADER, AUTH_SCHEME_QUERY)
|
||||
|
||||
# Hosts for which a bearer token can be sourced from the environment when no
|
||||
# stored credential matches. Values are the env var names to try, in order.
|
||||
# Only consulted during auto-resolve for an exact host match over https, so the
|
||||
# same per-hop boundary rules apply (e.g. the token is dropped on a redirect to
|
||||
# a CDN host). Kept here so the host->env-var mapping lives in one place.
|
||||
ENV_TOKEN_HOSTS = {
|
||||
"huggingface.co": ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"),
|
||||
}
|
||||
|
||||
|
||||
class DownloadStatus:
|
||||
QUEUED = "queued"
|
||||
ACTIVE = "active"
|
||||
PAUSED = "paused"
|
||||
VERIFYING = "verifying"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
#: States from which a worker is doing (or about to do) network I/O.
|
||||
LIVE = (QUEUED, ACTIVE, VERIFYING)
|
||||
#: Terminal states — the job will not transition again on its own.
|
||||
TERMINAL = (COMPLETED, FAILED, CANCELLED)
|
||||
|
||||
|
||||
# Default temp-file suffix. Distinctive so the startup orphan sweep only
|
||||
# removes files THIS subsystem created, never unrelated *.tmp files.
|
||||
TMP_SUFFIX = ".comfy-download.part"
|
||||
@ -1,111 +0,0 @@
|
||||
"""Turn a stored credential into a per-hop request modifier (PRD section 9.4.2).
|
||||
|
||||
The critical rule: a credential is only ever attached when *the current hop's
|
||||
host* matches a stored credential, and only over https. This is recomputed
|
||||
from scratch on every redirect hop, so a token bound to ``huggingface.co`` is
|
||||
silently dropped when the request is redirected to a presigned CDN host —
|
||||
which is exactly what these hubs expect.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode, urlsplit, urlunsplit
|
||||
|
||||
from app.model_downloader.constants import (
|
||||
AUTH_SCHEME_BEARER,
|
||||
AUTH_SCHEME_HEADER,
|
||||
AUTH_SCHEME_QUERY,
|
||||
ENV_TOKEN_HOSTS,
|
||||
)
|
||||
from app.model_downloader.credentials.store import normalize_host
|
||||
from app.model_downloader.database import queries
|
||||
from app.model_downloader.database.models import HostCredential
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestAuth:
|
||||
"""How to modify a single request to carry a credential."""
|
||||
|
||||
headers: dict[str, str] = field(default_factory=dict)
|
||||
query: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def apply_to_url(self, url: str) -> str:
|
||||
if not self.query:
|
||||
return url
|
||||
parts = urlsplit(url)
|
||||
# Append only the credential params, leaving the original query string
|
||||
# (including any repeated keys and existing encoding) untouched.
|
||||
creds = urlencode(self.query)
|
||||
query = f"{parts.query}&{creds}" if parts.query else creds
|
||||
return urlunsplit(parts._replace(query=query))
|
||||
|
||||
|
||||
def _matches(cred: HostCredential, hop_host: str) -> bool:
|
||||
cred_host = cred.host
|
||||
if hop_host == cred_host:
|
||||
return True
|
||||
if cred.match_subdomains:
|
||||
# Label-boundary suffix: api.example.com matches example.com, but
|
||||
# evil-example.com does NOT.
|
||||
return hop_host.endswith("." + cred_host)
|
||||
return False
|
||||
|
||||
|
||||
def _build_auth(cred: HostCredential) -> RequestAuth:
|
||||
if cred.auth_scheme == AUTH_SCHEME_BEARER:
|
||||
return RequestAuth(headers={"Authorization": f"Bearer {cred.secret}"})
|
||||
if cred.auth_scheme == AUTH_SCHEME_HEADER:
|
||||
name = cred.header_name or "Authorization"
|
||||
return RequestAuth(headers={name: cred.secret})
|
||||
if cred.auth_scheme == AUTH_SCHEME_QUERY and cred.query_param:
|
||||
return RequestAuth(query={cred.query_param: cred.secret})
|
||||
return RequestAuth()
|
||||
|
||||
|
||||
def _resolve_sync(
|
||||
host: str, scheme: str, explicit_credential_id: Optional[str]
|
||||
) -> Optional[RequestAuth]:
|
||||
# Never attach a secret over a non-https hop (PRD section 9.4.2).
|
||||
if scheme.lower() != "https":
|
||||
return None
|
||||
hop_host = normalize_host(host)
|
||||
if not hop_host:
|
||||
return None
|
||||
|
||||
if explicit_credential_id is not None:
|
||||
cred = queries.get_credential(explicit_credential_id)
|
||||
# An explicit credential is still subject to the per-hop host check —
|
||||
# it is not forced onto a non-matching host.
|
||||
if cred is None or not cred.enabled or not _matches(cred, hop_host):
|
||||
return None
|
||||
return _build_auth(cred)
|
||||
|
||||
# Auto-resolve: exact host first, then any subdomain-matching credential.
|
||||
cred = queries.get_credential_by_host(hop_host)
|
||||
if cred is not None and cred.enabled:
|
||||
return _build_auth(cred)
|
||||
for sub in queries.list_subdomain_credentials():
|
||||
if sub.enabled and _matches(sub, hop_host):
|
||||
return _build_auth(sub)
|
||||
|
||||
# Env fallback: only for an exact host match, and only after the DB lookups
|
||||
# miss, so a user-set credential always takes precedence. The token is never
|
||||
# persisted; it is read fresh from the environment on each hop.
|
||||
for var in ENV_TOKEN_HOSTS.get(hop_host, ()):
|
||||
token = os.environ.get(var)
|
||||
if token:
|
||||
return RequestAuth(headers={"Authorization": f"Bearer {token}"})
|
||||
return None
|
||||
|
||||
|
||||
async def resolve_auth_for_hop(
|
||||
host: str, scheme: str, *, explicit_credential_id: Optional[str] = None
|
||||
) -> Optional[RequestAuth]:
|
||||
"""Resolve the credential (if any) to attach for one request hop."""
|
||||
return await asyncio.to_thread(
|
||||
_resolve_sync, host, scheme, explicit_credential_id
|
||||
)
|
||||
@ -1,141 +0,0 @@
|
||||
"""The credential store: one API key per host.
|
||||
|
||||
Secrets are write-only over the API — :class:`CredentialView` carries only
|
||||
masked metadata (``secret_last4`` + scheme + label), never the secret itself.
|
||||
At-rest protection for v1 is filesystem permissions on the shared DB (the DB
|
||||
is the trust boundary); encryption-at-rest is a noted future seam.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
from app.model_downloader.constants import (
|
||||
AUTH_SCHEME_BEARER,
|
||||
AUTH_SCHEME_HEADER,
|
||||
AUTH_SCHEME_QUERY,
|
||||
AUTH_SCHEMES,
|
||||
)
|
||||
from app.model_downloader.database import queries
|
||||
from app.model_downloader.database.models import HostCredential
|
||||
|
||||
|
||||
def normalize_host(host: str) -> str:
|
||||
"""Lowercase, strip port, IDNA-encode."""
|
||||
if not host:
|
||||
return ""
|
||||
host = host.strip()
|
||||
if "://" in host: # a full URL was pasted — extract just the host
|
||||
host = urlsplit(host).hostname or ""
|
||||
host = host.lower()
|
||||
if host.startswith("[") and "]" in host: # bracketed IPv6 literal
|
||||
host = host[1 : host.index("]")]
|
||||
elif host.count(":") == 1: # host:port (not IPv6)
|
||||
host = host.split(":", 1)[0]
|
||||
try:
|
||||
host = host.encode("idna").decode("ascii")
|
||||
except (UnicodeError, ValueError):
|
||||
pass
|
||||
return host
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CredentialView:
|
||||
"""Masked, API-safe view of a credential — never includes the secret."""
|
||||
|
||||
id: str
|
||||
host: str
|
||||
auth_scheme: str
|
||||
header_name: Optional[str]
|
||||
query_param: Optional[str]
|
||||
label: Optional[str]
|
||||
match_subdomains: bool
|
||||
enabled: bool
|
||||
secret_last4: Optional[str]
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
|
||||
def _to_view(row: HostCredential) -> CredentialView:
|
||||
return CredentialView(
|
||||
id=row.id,
|
||||
host=row.host,
|
||||
auth_scheme=row.auth_scheme,
|
||||
header_name=row.header_name,
|
||||
query_param=row.query_param,
|
||||
label=row.label,
|
||||
match_subdomains=row.match_subdomains,
|
||||
enabled=row.enabled,
|
||||
secret_last4=row.secret_last4,
|
||||
created_at=row.created_at,
|
||||
updated_at=row.updated_at,
|
||||
)
|
||||
|
||||
|
||||
class CredentialValidationError(ValueError):
|
||||
"""A credential upsert had inconsistent fields."""
|
||||
|
||||
|
||||
class CredentialStore:
|
||||
"""Async facade over the ``host_credentials`` table.
|
||||
|
||||
DB access is synchronous (SQLite) and offloaded via ``asyncio.to_thread``.
|
||||
"""
|
||||
|
||||
async def upsert(
|
||||
self,
|
||||
host: str,
|
||||
secret: str,
|
||||
*,
|
||||
auth_scheme: str = AUTH_SCHEME_BEARER,
|
||||
header_name: Optional[str] = None,
|
||||
query_param: Optional[str] = None,
|
||||
label: Optional[str] = None,
|
||||
match_subdomains: bool = False,
|
||||
enabled: bool = True,
|
||||
) -> CredentialView:
|
||||
host = normalize_host(host)
|
||||
if not host:
|
||||
raise CredentialValidationError("host is required")
|
||||
if not secret:
|
||||
raise CredentialValidationError("secret is required")
|
||||
if auth_scheme not in AUTH_SCHEMES:
|
||||
raise CredentialValidationError(
|
||||
f"auth_scheme must be one of {AUTH_SCHEMES}, got {auth_scheme!r}"
|
||||
)
|
||||
if auth_scheme == AUTH_SCHEME_HEADER and not header_name:
|
||||
header_name = "Authorization"
|
||||
if auth_scheme == AUTH_SCHEME_QUERY and not query_param:
|
||||
raise CredentialValidationError(
|
||||
"query_param is required when auth_scheme='query'"
|
||||
)
|
||||
values = {
|
||||
"host": host,
|
||||
"secret": secret,
|
||||
"secret_last4": secret[-4:] if len(secret) > 4 else None,
|
||||
"auth_scheme": auth_scheme,
|
||||
"header_name": header_name,
|
||||
"query_param": query_param,
|
||||
"label": label,
|
||||
"match_subdomains": match_subdomains,
|
||||
"enabled": enabled,
|
||||
}
|
||||
row = await asyncio.to_thread(queries.upsert_credential, values)
|
||||
return _to_view(row)
|
||||
|
||||
async def list(self) -> list[CredentialView]:
|
||||
rows = await asyncio.to_thread(queries.list_credentials)
|
||||
return [_to_view(r) for r in rows]
|
||||
|
||||
async def get(self, credential_id: str) -> Optional[CredentialView]:
|
||||
row = await asyncio.to_thread(queries.get_credential, credential_id)
|
||||
return _to_view(row) if row is not None else None
|
||||
|
||||
async def delete(self, credential_id: str) -> bool:
|
||||
return await asyncio.to_thread(queries.delete_credential, credential_id)
|
||||
|
||||
|
||||
CREDENTIAL_STORE = CredentialStore()
|
||||
@ -1,173 +0,0 @@
|
||||
"""SQLAlchemy models for the download manager.
|
||||
|
||||
Three tables:
|
||||
|
||||
- ``downloads`` one row per requested file (job + queue state).
|
||||
- ``download_segments`` per-segment byte progress, for segmented resume.
|
||||
- ``host_credentials`` one API key per host, reused across downloads.
|
||||
|
||||
On completion a finished file is registered into the assets catalog;
|
||||
``downloads`` is kept only as job history.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Boolean,
|
||||
CheckConstraint,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database.models import Base
|
||||
|
||||
|
||||
def _uuid() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def _now() -> int:
|
||||
return int(time.time())
|
||||
|
||||
|
||||
class Download(Base):
|
||||
__tablename__ = "downloads"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
|
||||
# Original requested URL and the final URL after validated redirects.
|
||||
url: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
final_url: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
# Canonical "<directory>/<filename>" identifier (resolved via folder_paths).
|
||||
model_id: Mapped[str] = mapped_column(String(1024), nullable=False)
|
||||
# Final on-disk location and the .part write target.
|
||||
dest_path: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
temp_path: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
|
||||
status: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
priority: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
total_bytes: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||
bytes_done: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
|
||||
etag: Mapped[str | None] = mapped_column(String(512), nullable=True)
|
||||
last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
accept_ranges: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
# Optional hub-provided checksum to verify against (NOT the dedup key).
|
||||
expected_sha256: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
|
||||
# Explicit credential override; otherwise auto-resolved by host.
|
||||
# RESTRICT keeps a credential from being deleted while a download references it.
|
||||
credential_id: Mapped[str | None] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("host_credentials.id", ondelete="RESTRICT"),
|
||||
nullable=True,
|
||||
)
|
||||
allow_any_extension: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
# How many retryable failures we have seen (for backoff capping).
|
||||
attempts: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
error: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[int] = mapped_column(BigInteger, nullable=False, default=_now)
|
||||
updated_at: Mapped[int] = mapped_column(
|
||||
BigInteger, nullable=False, default=_now, onupdate=_now
|
||||
)
|
||||
|
||||
segments: Mapped[list[DownloadSegment]] = relationship(
|
||||
"DownloadSegment",
|
||||
back_populates="download",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
order_by="DownloadSegment.idx",
|
||||
)
|
||||
|
||||
credential: Mapped[HostCredential | None] = relationship(
|
||||
"HostCredential", back_populates="downloads"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_downloads_status", "status"),
|
||||
Index("ix_downloads_priority", "priority"),
|
||||
Index("ix_downloads_model_id", "model_id"),
|
||||
CheckConstraint("bytes_done >= 0", name="ck_downloads_bytes_done_nonneg"),
|
||||
CheckConstraint(
|
||||
"total_bytes IS NULL OR total_bytes >= 0",
|
||||
name="ck_downloads_total_bytes_nonneg",
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Download id={self.id} model_id={self.model_id!r} status={self.status}>"
|
||||
|
||||
|
||||
class DownloadSegment(Base):
|
||||
__tablename__ = "download_segments"
|
||||
|
||||
download_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("downloads.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
idx: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
start_offset: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
end_offset: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
bytes_done: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
|
||||
download: Mapped[Download] = relationship("Download", back_populates="segments")
|
||||
|
||||
__table_args__ = (
|
||||
CheckConstraint("bytes_done >= 0", name="ck_segments_bytes_done_nonneg"),
|
||||
CheckConstraint("end_offset >= start_offset", name="ck_segments_range"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<DownloadSegment {self.download_id}#{self.idx} "
|
||||
f"{self.start_offset}-{self.end_offset} done={self.bytes_done}>"
|
||||
)
|
||||
|
||||
|
||||
class HostCredential(Base):
|
||||
__tablename__ = "host_credentials"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
|
||||
# Normalized lowercase hostname, e.g. "civitai.com".
|
||||
host: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
match_subdomains: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
label: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
auth_scheme: Mapped[str] = mapped_column(
|
||||
String(16), nullable=False, default="bearer"
|
||||
)
|
||||
header_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
query_param: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
# The API key itself. Write-only over the API; never returned. See PRD 9.4.4.
|
||||
secret: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
secret_last4: Mapped[str | None] = mapped_column(String(4), nullable=True)
|
||||
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
created_at: Mapped[int] = mapped_column(BigInteger, nullable=False, default=_now)
|
||||
updated_at: Mapped[int] = mapped_column(
|
||||
BigInteger, nullable=False, default=_now, onupdate=_now
|
||||
)
|
||||
|
||||
downloads: Mapped[list[Download]] = relationship(
|
||||
"Download", back_populates="credential"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("uq_host_credentials_host", "host", unique=True),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<HostCredential id={self.id} host={self.host!r} scheme={self.auth_scheme}>"
|
||||
@ -1,272 +0,0 @@
|
||||
"""Synchronous DB access for the download manager.
|
||||
|
||||
All functions open their own short-lived session via ``create_session`` and
|
||||
commit before returning, mirroring ``app/assets`` usage. They are blocking
|
||||
(SQLite) and should be called from async code through ``asyncio.to_thread``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from app.database.db import create_session
|
||||
from app.model_downloader.constants import DownloadStatus
|
||||
from app.model_downloader.database.models import (
|
||||
Download,
|
||||
DownloadSegment,
|
||||
HostCredential,
|
||||
)
|
||||
|
||||
|
||||
# ----- downloads -----
|
||||
|
||||
|
||||
def insert_download(values: dict) -> None:
|
||||
with create_session() as session:
|
||||
session.add(Download(**values))
|
||||
session.commit()
|
||||
|
||||
|
||||
def get_download(download_id: str) -> Optional[Download]:
|
||||
with create_session() as session:
|
||||
row = session.get(Download, download_id)
|
||||
if row is not None:
|
||||
session.expunge_all()
|
||||
return row
|
||||
|
||||
|
||||
def list_downloads() -> list[Download]:
|
||||
with create_session() as session:
|
||||
rows = list(
|
||||
session.execute(
|
||||
select(Download).order_by(Download.created_at.desc())
|
||||
).scalars()
|
||||
)
|
||||
session.expunge_all()
|
||||
return rows
|
||||
|
||||
|
||||
def list_segments(download_id: str) -> list[DownloadSegment]:
|
||||
with create_session() as session:
|
||||
rows = list(
|
||||
session.execute(
|
||||
select(DownloadSegment)
|
||||
.where(DownloadSegment.download_id == download_id)
|
||||
.order_by(DownloadSegment.idx)
|
||||
).scalars()
|
||||
)
|
||||
session.expunge_all()
|
||||
return rows
|
||||
|
||||
|
||||
def update_download(download_id: str, **fields) -> None:
|
||||
if not fields:
|
||||
return
|
||||
fields.setdefault("updated_at", int(time.time()))
|
||||
with create_session() as session:
|
||||
row = session.get(Download, download_id)
|
||||
if row is None:
|
||||
return
|
||||
for key, value in fields.items():
|
||||
setattr(row, key, value)
|
||||
session.commit()
|
||||
|
||||
|
||||
def delete_download(download_id: str) -> None:
|
||||
with create_session() as session:
|
||||
row = session.get(Download, download_id)
|
||||
if row is not None:
|
||||
session.delete(row)
|
||||
session.commit()
|
||||
|
||||
|
||||
def delete_downloads(download_ids: list[str]) -> int:
|
||||
"""Delete many downloads in one transaction; returns the number removed.
|
||||
|
||||
Uses a bulk ``DELETE ... WHERE id IN (...)``. Segment rows are removed by
|
||||
the ``ON DELETE CASCADE`` foreign key (SQLite ``PRAGMA foreign_keys=ON`` is
|
||||
set in ``app/database/db.py``), so this stays consistent without loading the
|
||||
ORM relationship.
|
||||
"""
|
||||
if not download_ids:
|
||||
return 0
|
||||
with create_session() as session:
|
||||
result = session.execute(
|
||||
delete(Download).where(Download.id.in_(download_ids))
|
||||
)
|
||||
session.commit()
|
||||
return result.rowcount or 0
|
||||
|
||||
|
||||
def replace_segments(download_id: str, segments: list[dict]) -> None:
|
||||
"""Atomically replace the segment plan for a download."""
|
||||
with create_session() as session:
|
||||
session.query(DownloadSegment).filter(
|
||||
DownloadSegment.download_id == download_id
|
||||
).delete()
|
||||
for seg in segments:
|
||||
session.add(DownloadSegment(download_id=download_id, **seg))
|
||||
session.commit()
|
||||
|
||||
|
||||
def update_segment_progress(download_id: str, idx: int, bytes_done: int) -> None:
|
||||
with create_session() as session:
|
||||
row = session.get(DownloadSegment, {"download_id": download_id, "idx": idx})
|
||||
if row is None:
|
||||
return
|
||||
row.bytes_done = bytes_done
|
||||
session.commit()
|
||||
|
||||
|
||||
def list_queued_downloads() -> list[Download]:
|
||||
"""Queued rows ordered for admission (priority desc, then FIFO)."""
|
||||
with create_session() as session:
|
||||
rows = list(
|
||||
session.execute(
|
||||
select(Download)
|
||||
.where(Download.status == DownloadStatus.QUEUED)
|
||||
.order_by(Download.priority.desc(), Download.created_at.asc())
|
||||
).scalars()
|
||||
)
|
||||
session.expunge_all()
|
||||
return rows
|
||||
|
||||
|
||||
def reconcile_live_downloads() -> list[Download]:
|
||||
"""Reset any ``active``/``verifying`` rows left by a previous run.
|
||||
|
||||
On a clean restart there can be no live worker, so anything still marked
|
||||
live is stale. Move it back to ``queued`` (offsets are preserved on the
|
||||
segment rows) so the scheduler re-admits it. Returns the rows that should
|
||||
be re-queued by the scheduler (queued + paused).
|
||||
"""
|
||||
with create_session() as session:
|
||||
stale = list(
|
||||
session.execute(
|
||||
select(Download).where(
|
||||
Download.status.in_([DownloadStatus.ACTIVE, DownloadStatus.VERIFYING])
|
||||
)
|
||||
).scalars()
|
||||
)
|
||||
now = int(time.time())
|
||||
for row in stale:
|
||||
row.status = DownloadStatus.QUEUED
|
||||
row.updated_at = now
|
||||
session.commit()
|
||||
|
||||
resumable = list(
|
||||
session.execute(
|
||||
select(Download)
|
||||
.where(Download.status == DownloadStatus.QUEUED)
|
||||
.order_by(Download.priority.desc(), Download.created_at.asc())
|
||||
).scalars()
|
||||
)
|
||||
session.expunge_all()
|
||||
return resumable
|
||||
|
||||
|
||||
# ----- host credentials -----
|
||||
|
||||
|
||||
def get_credential(credential_id: str) -> Optional[HostCredential]:
|
||||
with create_session() as session:
|
||||
row = session.get(HostCredential, credential_id)
|
||||
if row is not None:
|
||||
session.expunge_all()
|
||||
return row
|
||||
|
||||
|
||||
def get_credential_by_host(host: str) -> Optional[HostCredential]:
|
||||
with create_session() as session:
|
||||
row = (
|
||||
session.execute(
|
||||
select(HostCredential).where(HostCredential.host == host).limit(1)
|
||||
)
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
if row is not None:
|
||||
session.expunge_all()
|
||||
return row
|
||||
|
||||
|
||||
def list_credentials() -> list[HostCredential]:
|
||||
with create_session() as session:
|
||||
rows = list(
|
||||
session.execute(
|
||||
select(HostCredential).order_by(HostCredential.host)
|
||||
).scalars()
|
||||
)
|
||||
session.expunge_all()
|
||||
return rows
|
||||
|
||||
|
||||
def list_subdomain_credentials() -> list[HostCredential]:
|
||||
"""Credentials that opted into subdomain matching, for suffix checks."""
|
||||
with create_session() as session:
|
||||
rows = list(
|
||||
session.execute(
|
||||
select(HostCredential).where(HostCredential.match_subdomains.is_(True))
|
||||
).scalars()
|
||||
)
|
||||
session.expunge_all()
|
||||
return rows
|
||||
|
||||
|
||||
def upsert_credential(values: dict) -> HostCredential:
|
||||
"""Insert or update a credential keyed by ``host``.
|
||||
|
||||
Callers can target the same host concurrently (each runs in its own
|
||||
short-lived session on a separate connection), so the read-then-write here
|
||||
can race: two callers both see no existing row and both attempt an insert.
|
||||
The ``host`` column is uniquely indexed, so the loser's insert raises
|
||||
``IntegrityError``. We recover by rolling back and retrying, at which point
|
||||
the now-committed row is found and updated in place, letting concurrent
|
||||
calls converge instead of failing or creating duplicates.
|
||||
"""
|
||||
host = values["host"]
|
||||
now = int(time.time())
|
||||
last_error: IntegrityError | None = None
|
||||
for _ in range(2):
|
||||
with create_session() as session:
|
||||
row = (
|
||||
session.execute(
|
||||
select(HostCredential).where(HostCredential.host == host).limit(1)
|
||||
)
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
if row is None:
|
||||
row = HostCredential(**values)
|
||||
row.created_at = now
|
||||
row.updated_at = now
|
||||
session.add(row)
|
||||
else:
|
||||
for key, value in values.items():
|
||||
setattr(row, key, value)
|
||||
row.updated_at = now
|
||||
try:
|
||||
session.commit()
|
||||
except IntegrityError as exc:
|
||||
session.rollback()
|
||||
last_error = exc
|
||||
continue
|
||||
session.refresh(row)
|
||||
session.expunge(row)
|
||||
return row
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
|
||||
|
||||
def delete_credential(credential_id: str) -> bool:
|
||||
with create_session() as session:
|
||||
row = session.get(HostCredential, credential_id)
|
||||
if row is None:
|
||||
return False
|
||||
session.delete(row)
|
||||
session.commit()
|
||||
return True
|
||||
@ -1,612 +0,0 @@
|
||||
"""The per-download worker.
|
||||
|
||||
One :class:`DownloadJob` drives a single file from probe to verified, cataloged
|
||||
completion. It supports cooperative pause / resume / cancel, segmented
|
||||
multi-connection transfer with positioned writes, and a verification gate
|
||||
(size + structural + optional sha256) before the atomic rename into place.
|
||||
|
||||
Control is cooperative: external callers flip ``_control`` via
|
||||
:meth:`request_pause` / :meth:`request_cancel`; segment loops observe it between
|
||||
chunks and raise, which unwinds cleanly and persists resume offsets.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Optional
|
||||
|
||||
from comfy.cli_args import args
|
||||
from app.model_downloader.constants import DownloadStatus
|
||||
from app.model_downloader.database import queries
|
||||
from app.model_downloader.engine.planner import (
|
||||
effective_segment_count,
|
||||
plan_segments,
|
||||
)
|
||||
from app.model_downloader.engine.writer import FileWriter
|
||||
from app.model_downloader.net.http import open_validated, redact_url
|
||||
from app.model_downloader.net.probe import gated_error_message, probe
|
||||
from app.model_downloader.verify import checksum, dedup, structural
|
||||
|
||||
_RETRYABLE_STATUSES = {408, 429, 500, 502, 503, 504}
|
||||
_PERSIST_INTERVAL = 2.0 # seconds between throttled progress persists
|
||||
|
||||
|
||||
class Paused(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Cancelled(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class RemoteChanged(Exception):
|
||||
"""The remote file changed under a resume (got 200 where 206 expected)."""
|
||||
|
||||
|
||||
class RetryableError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FatalError(Exception):
|
||||
"""Non-retryable: 4xx, checksum mismatch, structural failure, gated, etc."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class SegmentRuntime:
|
||||
idx: int
|
||||
start: int
|
||||
end: int # inclusive; may be -1 for unknown-size single stream
|
||||
bytes_done: int = 0
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
return self.end - self.start + 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeState:
|
||||
download_id: str
|
||||
model_id: str
|
||||
url: str
|
||||
priority: int
|
||||
status: str
|
||||
total_bytes: Optional[int] = None
|
||||
bytes_done: int = 0
|
||||
error: Optional[str] = None
|
||||
segments: list[SegmentRuntime] = field(default_factory=list)
|
||||
started_at: float = field(default_factory=time.monotonic)
|
||||
_last_bytes: int = 0
|
||||
_last_time: float = field(default_factory=time.monotonic)
|
||||
speed_bps: float = 0.0
|
||||
|
||||
@property
|
||||
def progress(self) -> Optional[float]:
|
||||
if not self.total_bytes:
|
||||
return None
|
||||
return min(1.0, self.bytes_done / self.total_bytes)
|
||||
|
||||
@property
|
||||
def eta_seconds(self) -> Optional[float]:
|
||||
if not self.total_bytes or self.speed_bps <= 0:
|
||||
return None
|
||||
remaining = max(0, self.total_bytes - self.bytes_done)
|
||||
return remaining / self.speed_bps
|
||||
|
||||
|
||||
@dataclass
|
||||
class JobSpec:
|
||||
download_id: str
|
||||
url: str
|
||||
model_id: str
|
||||
dest_path: str
|
||||
temp_path: str
|
||||
priority: int = 0
|
||||
credential_id: Optional[str] = None
|
||||
expected_sha256: Optional[str] = None
|
||||
allow_any_extension: bool = False
|
||||
etag: Optional[str] = None
|
||||
attempts: int = 0
|
||||
|
||||
|
||||
class DownloadJob:
|
||||
def __init__(
|
||||
self, spec: JobSpec, notify_cb: Optional[Callable[[str], None]] = None
|
||||
) -> None:
|
||||
self.spec = spec
|
||||
self._notify = notify_cb
|
||||
self._control = "run" # run | pause | cancel
|
||||
self.state = RuntimeState(
|
||||
download_id=spec.download_id,
|
||||
model_id=spec.model_id,
|
||||
url=spec.url,
|
||||
priority=spec.priority,
|
||||
status=DownloadStatus.QUEUED,
|
||||
)
|
||||
self._writer: Optional[FileWriter] = None
|
||||
self._etag: Optional[str] = spec.etag
|
||||
self._last_persist = 0.0
|
||||
|
||||
# ----- external control -----
|
||||
|
||||
def request_pause(self) -> None:
|
||||
if self._control == "run":
|
||||
self._control = "pause"
|
||||
|
||||
def request_cancel(self) -> None:
|
||||
self._control = "cancel"
|
||||
|
||||
def _check_control(self) -> None:
|
||||
if self._control == "cancel":
|
||||
raise Cancelled()
|
||||
if self._control == "pause":
|
||||
raise Paused()
|
||||
|
||||
# ----- lifecycle -----
|
||||
|
||||
async def run(self) -> str:
|
||||
"""Run to a terminal/paused state; returns the final status string."""
|
||||
await self._set_status(DownloadStatus.ACTIVE, error=None)
|
||||
try:
|
||||
pr = await self._probe_and_plan()
|
||||
await self._transfer(pr)
|
||||
await self._finalize()
|
||||
await self._set_status(DownloadStatus.COMPLETED)
|
||||
except Paused:
|
||||
await self._persist_progress(force=True)
|
||||
await self._set_status(DownloadStatus.PAUSED)
|
||||
except Cancelled:
|
||||
await self._close_writer()
|
||||
self._remove_temp()
|
||||
await self._set_status(DownloadStatus.CANCELLED)
|
||||
except RemoteChanged:
|
||||
await self._reset_for_restart()
|
||||
await self._set_status(
|
||||
DownloadStatus.QUEUED, error="remote file changed; restarting"
|
||||
)
|
||||
except RetryableError as e:
|
||||
await self._persist_progress(force=True)
|
||||
await self._set_status(DownloadStatus.QUEUED, error=str(e))
|
||||
except FatalError as e:
|
||||
await self._close_writer()
|
||||
self._remove_temp()
|
||||
await self._set_status(DownloadStatus.FAILED, error=str(e))
|
||||
except Exception as e: # unexpected -> treat as retryable
|
||||
logging.warning(
|
||||
"[model_downloader] %s unexpected error: %s",
|
||||
self.spec.model_id, e, exc_info=True,
|
||||
)
|
||||
await self._persist_progress(force=True)
|
||||
await self._set_status(DownloadStatus.QUEUED, error=f"{type(e).__name__}: {e}")
|
||||
finally:
|
||||
await self._close_writer()
|
||||
return self.state.status
|
||||
|
||||
# ----- probe + plan -----
|
||||
|
||||
async def _probe_and_plan(self):
|
||||
pr = await probe(self.spec.url, credential_id=self.spec.credential_id)
|
||||
if not pr.ok:
|
||||
if pr.gated:
|
||||
raise FatalError(gated_error_message(self.spec.url, pr))
|
||||
if pr.status == 0 or pr.status in _RETRYABLE_STATUSES:
|
||||
raise RetryableError(pr.error or "probe failed")
|
||||
raise FatalError(pr.error or f"probe returned HTTP {pr.status}")
|
||||
|
||||
max_bytes = self._max_download_bytes()
|
||||
if max_bytes is not None and pr.total_bytes is not None and pr.total_bytes > max_bytes:
|
||||
raise FatalError(
|
||||
f"file size {pr.total_bytes} exceeds the maximum allowed "
|
||||
f"download size {max_bytes} (--download-max-bytes)"
|
||||
)
|
||||
|
||||
self._etag = pr.etag or self._etag
|
||||
self.state.total_bytes = pr.total_bytes
|
||||
await asyncio.to_thread(
|
||||
queries.update_download,
|
||||
self.spec.download_id,
|
||||
final_url=pr.final_url,
|
||||
total_bytes=pr.total_bytes,
|
||||
accept_ranges=pr.accept_ranges,
|
||||
etag=pr.etag,
|
||||
last_modified=pr.last_modified,
|
||||
)
|
||||
|
||||
seg_count = effective_segment_count(
|
||||
pr.total_bytes, pr.accept_ranges, max(1, args.download_segments)
|
||||
)
|
||||
existing = await asyncio.to_thread(queries.list_segments, self.spec.download_id)
|
||||
can_resume_segmented = (
|
||||
seg_count > 1
|
||||
and existing
|
||||
and pr.total_bytes is not None
|
||||
and existing[-1].end_offset == pr.total_bytes - 1
|
||||
)
|
||||
if can_resume_segmented and not self._segmented_part_valid(pr.total_bytes):
|
||||
# The persisted per-segment offsets describe bytes in a preallocated
|
||||
# .part that is now gone or the wrong size (e.g. the partial of a
|
||||
# failed download was swept on restart, or removed by a fatal
|
||||
# error). Trusting them would skip already-"complete" segments and
|
||||
# leave zero-filled holes. Discard the offsets and re-plan fresh.
|
||||
logging.info(
|
||||
"[model_downloader] %s discarding segmented resume offsets "
|
||||
"(preallocated .part missing or wrong size); restarting",
|
||||
self.spec.model_id,
|
||||
)
|
||||
self._remove_temp()
|
||||
await asyncio.to_thread(
|
||||
queries.replace_segments, self.spec.download_id, []
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
queries.update_download, self.spec.download_id, bytes_done=0
|
||||
)
|
||||
existing = []
|
||||
can_resume_segmented = False
|
||||
|
||||
if can_resume_segmented:
|
||||
# Resume an existing segmented plan.
|
||||
self.state.segments = [
|
||||
SegmentRuntime(s.idx, s.start_offset, s.end_offset, s.bytes_done)
|
||||
for s in existing
|
||||
]
|
||||
elif seg_count > 1 and pr.total_bytes is not None:
|
||||
plans = plan_segments(pr.total_bytes, seg_count)
|
||||
await asyncio.to_thread(
|
||||
queries.replace_segments,
|
||||
self.spec.download_id,
|
||||
[
|
||||
{"idx": p.idx, "start_offset": p.start, "end_offset": p.end, "bytes_done": 0}
|
||||
for p in plans
|
||||
],
|
||||
)
|
||||
self.state.segments = [SegmentRuntime(p.idx, p.start, p.end, 0) for p in plans]
|
||||
else:
|
||||
# Single-stream: one logical segment; bytes_done tracked on the row.
|
||||
row = await asyncio.to_thread(queries.get_download, self.spec.download_id)
|
||||
resume_from = row.bytes_done if row else 0
|
||||
end = (pr.total_bytes - 1) if pr.total_bytes else -1
|
||||
# ``row.bytes_done`` may be the SUM of per-segment offsets from a
|
||||
# prior segmented run (a preallocated, non-contiguous .part). A
|
||||
# single-stream resume writes a contiguous prefix, so the offset is
|
||||
# only trustworthy when the on-disk file is exactly that many
|
||||
# contiguous bytes. This guards the case where a download that ran
|
||||
# segmented now resolves to one segment (server dropped
|
||||
# Accept-Ranges, or --download-segments was lowered between runs):
|
||||
# resuming over non-contiguous data would corrupt the output.
|
||||
if resume_from > 0 and not self._contiguous_prefix_valid(resume_from):
|
||||
logging.info(
|
||||
"[model_downloader] %s discarding untrusted resume offset "
|
||||
"%d (on-disk .part not a contiguous prefix); restarting",
|
||||
self.spec.model_id, resume_from,
|
||||
)
|
||||
resume_from = 0
|
||||
self._remove_temp()
|
||||
if await asyncio.to_thread(queries.list_segments, self.spec.download_id):
|
||||
await asyncio.to_thread(
|
||||
queries.replace_segments, self.spec.download_id, []
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
queries.update_download, self.spec.download_id, bytes_done=0
|
||||
)
|
||||
self.state.segments = [SegmentRuntime(0, 0, end, resume_from)]
|
||||
self._recompute_bytes_done()
|
||||
return pr
|
||||
|
||||
# ----- transfer -----
|
||||
|
||||
async def _transfer(self, pr) -> None:
|
||||
self._writer = FileWriter(self.spec.temp_path)
|
||||
await self._writer.open()
|
||||
|
||||
segmented = len(self.state.segments) > 1
|
||||
if segmented and self.state.total_bytes:
|
||||
await self._writer.preallocate(self.state.total_bytes)
|
||||
await self._run_segmented()
|
||||
else:
|
||||
await self._run_single()
|
||||
|
||||
await self._writer.flush()
|
||||
|
||||
async def _run_segmented(self) -> None:
|
||||
pending = [
|
||||
asyncio.ensure_future(self._run_segment(seg))
|
||||
for seg in self.state.segments
|
||||
if seg.bytes_done < seg.length
|
||||
]
|
||||
if not pending:
|
||||
return
|
||||
done, not_done = await asyncio.wait(
|
||||
pending, return_when=asyncio.FIRST_EXCEPTION
|
||||
)
|
||||
first_exc: Optional[BaseException] = None
|
||||
for task in done:
|
||||
exc = task.exception()
|
||||
if exc is not None and first_exc is None:
|
||||
first_exc = exc
|
||||
if first_exc is not None:
|
||||
for task in not_done:
|
||||
task.cancel()
|
||||
await asyncio.gather(*not_done, return_exceptions=True)
|
||||
raise first_exc
|
||||
|
||||
async def _run_segment(self, seg: SegmentRuntime) -> None:
|
||||
offset = seg.start + seg.bytes_done
|
||||
headers = {
|
||||
"Range": f"bytes={offset}-{seg.end}",
|
||||
"Accept-Encoding": "identity",
|
||||
}
|
||||
if self._etag:
|
||||
headers["If-Range"] = self._etag
|
||||
async with open_validated(
|
||||
"GET", self.spec.url, credential_id=self.spec.credential_id, headers=headers
|
||||
) as (resp, _final):
|
||||
if resp.status == 200:
|
||||
# Server ignored the range -> remote changed / no resume support.
|
||||
raise RemoteChanged()
|
||||
if resp.status not in (206,):
|
||||
self._raise_for_status(resp.status)
|
||||
async for chunk in resp.content.iter_chunked(args.download_chunk_size):
|
||||
self._check_control()
|
||||
# Never write past this segment's planned range: a
|
||||
# non-conforming 206 that returns more than the requested
|
||||
# bytes would otherwise overrun adjacent segments and the
|
||||
# preallocated file. Cap the write and abort on overflow.
|
||||
remaining = seg.length - seg.bytes_done
|
||||
if remaining <= 0:
|
||||
raise FatalError(
|
||||
f"segment {seg.idx}: server returned more than the "
|
||||
f"requested {seg.length} bytes"
|
||||
)
|
||||
overflow = len(chunk) > remaining
|
||||
if overflow:
|
||||
chunk = chunk[:remaining]
|
||||
await self._writer.write_at(offset, chunk)
|
||||
offset += len(chunk)
|
||||
seg.bytes_done += len(chunk)
|
||||
self._recompute_bytes_done()
|
||||
await self._persist_progress()
|
||||
if overflow:
|
||||
raise FatalError(
|
||||
f"segment {seg.idx}: server returned more than the "
|
||||
f"requested {seg.length} bytes"
|
||||
)
|
||||
|
||||
async def _run_single(self) -> None:
|
||||
seg = self.state.segments[0]
|
||||
offset = seg.bytes_done # resume from here for single-stream
|
||||
headers = {"Accept-Encoding": "identity"}
|
||||
if offset > 0:
|
||||
headers["Range"] = f"bytes={offset}-"
|
||||
if self._etag:
|
||||
headers["If-Range"] = self._etag
|
||||
async with open_validated(
|
||||
"GET", self.spec.url, credential_id=self.spec.credential_id, headers=headers
|
||||
) as (resp, _final):
|
||||
if offset > 0 and resp.status == 200:
|
||||
# Resume not honoured -> start over from the beginning. Truncate
|
||||
# the existing partial so stale trailing bytes from the prior
|
||||
# attempt cannot survive past the new (possibly shorter) end.
|
||||
offset = 0
|
||||
seg.bytes_done = 0
|
||||
self.state.bytes_done = 0
|
||||
await self._writer.truncate(0)
|
||||
elif offset > 0 and resp.status != 206:
|
||||
self._raise_for_status(resp.status)
|
||||
elif offset == 0 and resp.status != 200:
|
||||
self._raise_for_status(resp.status)
|
||||
# Byte ceiling for this stream: the known total when the server
|
||||
# reported a size, otherwise the configured maximum download size.
|
||||
# Without a bound, a non-conforming response or an unknown-length
|
||||
# stream (end == -1) that never closes could fill the disk (DoS).
|
||||
limit = (seg.end + 1) if seg.end >= 0 else self._max_download_bytes()
|
||||
async for chunk in resp.content.iter_chunked(args.download_chunk_size):
|
||||
self._check_control()
|
||||
overflow = False
|
||||
if limit is not None:
|
||||
remaining = limit - offset
|
||||
if remaining <= 0:
|
||||
raise FatalError(
|
||||
f"download exceeded the maximum size {limit} bytes"
|
||||
)
|
||||
if len(chunk) > remaining:
|
||||
chunk = chunk[:remaining]
|
||||
overflow = True
|
||||
await self._writer.write_at(offset, chunk)
|
||||
offset += len(chunk)
|
||||
seg.bytes_done = offset
|
||||
self.state.bytes_done = offset
|
||||
await self._persist_progress()
|
||||
if overflow:
|
||||
raise FatalError(
|
||||
f"download exceeded the maximum size {limit} bytes"
|
||||
)
|
||||
|
||||
def _max_download_bytes(self) -> Optional[int]:
|
||||
"""Configured maximum download size in bytes, or ``None`` if disabled."""
|
||||
cap = getattr(args, "download_max_bytes", 0)
|
||||
return cap if cap and cap > 0 else None
|
||||
|
||||
def _raise_for_status(self, status: int) -> None:
|
||||
if status in (401, 403):
|
||||
raise FatalError(
|
||||
f"{redact_url(self.spec.url)} returned {status}; add/update an API key for "
|
||||
f"this host at /api/download/credentials."
|
||||
)
|
||||
if status in _RETRYABLE_STATUSES:
|
||||
raise RetryableError(f"HTTP {status}")
|
||||
raise FatalError(f"unexpected HTTP {status}")
|
||||
|
||||
# ----- finalize / verify (PRD section 8.4) -----
|
||||
|
||||
async def _finalize(self) -> None:
|
||||
self._check_control()
|
||||
await self._close_writer()
|
||||
await self._set_status(DownloadStatus.VERIFYING)
|
||||
|
||||
total = self.state.total_bytes
|
||||
segmented = len(self.state.segments) > 1
|
||||
if segmented:
|
||||
# The .part was preallocated to total_bytes, so its on-disk size is
|
||||
# not evidence of completeness: a segment that ends short (truncated
|
||||
# 206 / server closes mid-range) leaves a zero-filled hole while the
|
||||
# file size still equals total. Verify each segment wrote its full
|
||||
# planned range, and trust the byte counter (== sum of segments)
|
||||
# rather than os.path.getsize for the total check.
|
||||
for seg in self.state.segments:
|
||||
if seg.bytes_done != seg.length:
|
||||
raise FatalError(
|
||||
f"segment {seg.idx} incomplete: wrote {seg.bytes_done} "
|
||||
f"of {seg.length} bytes"
|
||||
)
|
||||
observed = self.state.bytes_done
|
||||
else:
|
||||
# Single-stream writes a contiguous prefix, so the on-disk size is
|
||||
# an independent witness of how much actually landed.
|
||||
observed = os.path.getsize(self.spec.temp_path)
|
||||
if total is not None and observed != total:
|
||||
raise FatalError(
|
||||
f"size mismatch: wrote {observed} of {total} bytes"
|
||||
)
|
||||
|
||||
# Structural gate (cheap, no full read) then optional sha256 (full read).
|
||||
# Both failures are non-retryable (a truncated/corrupt or mismatched file
|
||||
# will not heal on retry), so surface them as FatalError rather than
|
||||
# letting the plain Exceptions fall through to the retryable handler.
|
||||
# ``temp_path`` carries the ``.part`` suffix; pass ``dest_path`` so the
|
||||
# structural check detects the real file format instead of skipping it.
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
structural.validate, self.spec.temp_path, self.spec.dest_path
|
||||
)
|
||||
if self.spec.expected_sha256:
|
||||
await asyncio.to_thread(
|
||||
checksum.verify_sha256,
|
||||
self.spec.temp_path,
|
||||
self.spec.expected_sha256,
|
||||
)
|
||||
except (structural.StructuralError, checksum.ChecksumError) as e:
|
||||
raise FatalError(str(e)) from e
|
||||
|
||||
os.makedirs(os.path.dirname(self.spec.dest_path), exist_ok=True)
|
||||
os.replace(self.spec.temp_path, self.spec.dest_path)
|
||||
logging.info(
|
||||
"[model_downloader] completed %s (%d bytes)",
|
||||
self.spec.model_id, observed,
|
||||
)
|
||||
# Catalog into the assets system (blake3 dedup identity). Best-effort.
|
||||
await dedup.register_completed(self.spec.dest_path)
|
||||
|
||||
# ----- helpers -----
|
||||
|
||||
def _recompute_bytes_done(self) -> None:
|
||||
self.state.bytes_done = sum(s.bytes_done for s in self.state.segments)
|
||||
now = time.monotonic()
|
||||
dt = now - self.state._last_time
|
||||
if dt >= 0.5:
|
||||
self.state.speed_bps = (self.state.bytes_done - self.state._last_bytes) / dt
|
||||
self.state._last_bytes = self.state.bytes_done
|
||||
self.state._last_time = now
|
||||
|
||||
async def _persist_progress(self, force: bool = False) -> None:
|
||||
# Both the DB write and the websocket notify are gated by the same
|
||||
# throttle: persisting hits SQLite, and notifying broadcasts to every
|
||||
# client, so doing either per-chunk (small --download-chunk-size or
|
||||
# many concurrent segments) would overwhelm both. Skip entirely inside
|
||||
# the window; the next persist (or a forced one) ships the latest bytes.
|
||||
now = time.monotonic()
|
||||
if not force and now - self._last_persist < _PERSIST_INTERVAL:
|
||||
return
|
||||
self._last_persist = now
|
||||
# SQLite is blocking; run it off the event loop per the queries module
|
||||
# contract so progress persists don't stall the web server.
|
||||
await asyncio.to_thread(self._write_progress)
|
||||
if self._notify:
|
||||
self._notify(self.spec.download_id)
|
||||
|
||||
def _write_progress(self) -> None:
|
||||
queries.update_download(self.spec.download_id, bytes_done=self.state.bytes_done)
|
||||
for seg in self.state.segments:
|
||||
if seg.end >= seg.start: # skip unknown-size sentinel
|
||||
queries.update_segment_progress(
|
||||
self.spec.download_id, seg.idx, seg.bytes_done
|
||||
)
|
||||
|
||||
async def _reset_for_restart(self) -> None:
|
||||
await self._close_writer()
|
||||
self._remove_temp()
|
||||
for seg in self.state.segments:
|
||||
seg.bytes_done = 0
|
||||
self.state.bytes_done = 0
|
||||
await asyncio.to_thread(
|
||||
queries.update_download, self.spec.download_id, bytes_done=0
|
||||
)
|
||||
if await asyncio.to_thread(queries.list_segments, self.spec.download_id):
|
||||
await asyncio.to_thread(
|
||||
queries.replace_segments, self.spec.download_id, []
|
||||
)
|
||||
|
||||
async def _close_writer(self) -> None:
|
||||
if self._writer is not None:
|
||||
try:
|
||||
await self._writer.close()
|
||||
except Exception:
|
||||
logging.debug("[model_downloader] writer close error", exc_info=True)
|
||||
self._writer = None
|
||||
|
||||
def _segmented_part_valid(self, total_bytes: int) -> bool:
|
||||
"""True when the temp file is the preallocated segmented ``.part``.
|
||||
|
||||
A segmented transfer preallocates the .part to ``total_bytes`` up front
|
||||
and tracks how much of each range landed via per-segment offsets. Those
|
||||
offsets are only trustworthy when the file they describe is still on
|
||||
disk at its full preallocated size. A missing file (swept after a
|
||||
failure, removed on a fatal error, deleted by hand) or a wrong-sized one
|
||||
means the persisted offsets no longer correspond to real bytes and must
|
||||
not be resumed over. Doing so would skip "complete" segments and leave
|
||||
zero-filled holes that pass the size-only verification gate.
|
||||
"""
|
||||
try:
|
||||
return os.path.getsize(self.spec.temp_path) == total_bytes
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
def _contiguous_prefix_valid(self, prefix_len: int) -> bool:
|
||||
"""True when the temp file is exactly ``prefix_len`` contiguous bytes.
|
||||
|
||||
Single-stream resume appends sequentially, so a valid resume point
|
||||
implies the .part size equals the persisted offset. A larger file (e.g.
|
||||
one preallocated to ``total_bytes`` by a previous segmented run) or a
|
||||
missing/short file means the persisted offset is not a trustworthy
|
||||
contiguous prefix and must not be resumed over.
|
||||
"""
|
||||
try:
|
||||
return os.path.getsize(self.spec.temp_path) == prefix_len
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
def _remove_temp(self) -> None:
|
||||
try:
|
||||
os.remove(self.spec.temp_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except OSError as e:
|
||||
logging.warning(
|
||||
"[model_downloader] could not remove %s: %s", self.spec.temp_path, e
|
||||
)
|
||||
|
||||
async def _set_status(self, status: str, error: Optional[str] = None) -> None:
|
||||
# ``error`` is authoritative: passing None clears any prior failure
|
||||
# text so transitions out of a failure state (retry/success) don't
|
||||
# leave stale messages on RuntimeState or in the persisted row.
|
||||
self.state.status = status
|
||||
self.state.error = error
|
||||
fields = {"status": status, "bytes_done": self.state.bytes_done, "error": error}
|
||||
if status == DownloadStatus.QUEUED:
|
||||
fields["attempts"] = self.spec.attempts + 1
|
||||
self.spec.attempts += 1
|
||||
await asyncio.to_thread(queries.update_download, self.spec.download_id, **fields)
|
||||
if self._notify:
|
||||
self._notify(self.spec.download_id)
|
||||
@ -1,51 +0,0 @@
|
||||
"""Segment planning.
|
||||
|
||||
Split a known byte range into S roughly-equal segments, each fetched by its
|
||||
own coroutine with ``Range: bytes=start-end``. Falls back to a single segment
|
||||
when the server doesn't support ranges or the size is unknown/too small for
|
||||
segmentation to be worthwhile.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
# Below this size, the per-connection setup cost outweighs any parallelism.
|
||||
_MIN_SEGMENT_BYTES = 1 * 1024 * 1024
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SegmentPlan:
|
||||
idx: int
|
||||
start: int
|
||||
end: int # inclusive
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
return self.end - self.start + 1
|
||||
|
||||
|
||||
def effective_segment_count(
|
||||
total_bytes: int | None, accept_ranges: bool, configured: int
|
||||
) -> int:
|
||||
"""How many segments to actually use for this file."""
|
||||
if not accept_ranges or total_bytes is None or total_bytes <= 0:
|
||||
return 1
|
||||
by_size = max(1, total_bytes // _MIN_SEGMENT_BYTES)
|
||||
return max(1, min(configured, by_size))
|
||||
|
||||
|
||||
def plan_segments(total_bytes: int, num_segments: int) -> list[SegmentPlan]:
|
||||
"""Return ``num_segments`` contiguous, inclusive byte ranges covering [0, total)."""
|
||||
if total_bytes <= 0 or num_segments <= 1:
|
||||
return [SegmentPlan(idx=0, start=0, end=max(0, total_bytes - 1))]
|
||||
base = total_bytes // num_segments
|
||||
plans: list[SegmentPlan] = []
|
||||
start = 0
|
||||
for i in range(num_segments):
|
||||
# Last segment soaks up the remainder.
|
||||
length = base if i < num_segments - 1 else total_bytes - start
|
||||
end = start + length - 1
|
||||
plans.append(SegmentPlan(idx=i, start=start, end=end))
|
||||
start = end + 1
|
||||
return plans
|
||||
@ -1,110 +0,0 @@
|
||||
"""Positioned, off-loop file writes.
|
||||
|
||||
Network I/O stays on the event loop; every blocking disk op (preallocate,
|
||||
positioned write, fsync) is run in a bounded thread pool via
|
||||
``run_in_executor`` so downloads never stall inference or the web server.
|
||||
|
||||
A single file descriptor is opened for the whole download. Segments write to
|
||||
their own offsets with ``os.pwrite`` — which is offset-addressed and atomic
|
||||
per call, so concurrent segment writers need no extra locking. Per-chunk
|
||||
fsync is avoided; we fsync once at completion.
|
||||
|
||||
``os.pwrite`` is unavailable on Windows, so there we fall back to
|
||||
``os.lseek`` + ``os.write`` guarded by a per-writer lock (the seek/write pair
|
||||
is not atomic, so concurrent segment writers must be serialized).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
|
||||
# One shared, bounded pool for all download disk I/O.
|
||||
_EXECUTOR = ThreadPoolExecutor(max_workers=8, thread_name_prefix="dl-writer")
|
||||
|
||||
_HAS_PWRITE = hasattr(os, "pwrite")
|
||||
|
||||
# On Windows ``os.open`` defaults to text mode, which translates every ``\n``
|
||||
# byte into ``\r\n`` on write and corrupts binary payloads (the file grows by
|
||||
# one byte per 0x0A). ``O_BINARY`` disables that translation; it does not exist
|
||||
# on POSIX, where the default is already binary.
|
||||
_O_BINARY = getattr(os, "O_BINARY", 0)
|
||||
|
||||
|
||||
class FileWriter:
|
||||
"""Owns the ``.part`` file descriptor for one download."""
|
||||
|
||||
def __init__(self, path: str) -> None:
|
||||
self.path = path
|
||||
self._fd: Optional[int] = None
|
||||
# Serializes lseek+write on platforms without os.pwrite (Windows).
|
||||
self._seek_lock = threading.Lock()
|
||||
|
||||
def _open(self) -> None:
|
||||
os.makedirs(os.path.dirname(self.path), exist_ok=True)
|
||||
self._fd = os.open(self.path, os.O_RDWR | os.O_CREAT | _O_BINARY, 0o644)
|
||||
|
||||
async def open(self) -> None:
|
||||
await asyncio.get_running_loop().run_in_executor(_EXECUTOR, self._open)
|
||||
|
||||
async def preallocate(self, size: int) -> None:
|
||||
"""Grow the file to ``size`` so segments write to their offsets."""
|
||||
if self._fd is None or size <= 0:
|
||||
return
|
||||
await asyncio.get_running_loop().run_in_executor(
|
||||
_EXECUTOR, os.ftruncate, self._fd, size
|
||||
)
|
||||
|
||||
async def truncate(self, size: int = 0) -> None:
|
||||
"""Truncate the file to ``size`` bytes (default: empty it)."""
|
||||
if self._fd is None:
|
||||
return
|
||||
await asyncio.get_running_loop().run_in_executor(
|
||||
_EXECUTOR, os.ftruncate, self._fd, size
|
||||
)
|
||||
|
||||
def _pwrite_all(self, data: bytes, offset: int) -> None:
|
||||
"""A positioned write may write fewer bytes than requested (signal
|
||||
interruption, near-ENOSPC); loop until every byte lands so we never
|
||||
leave a gap while the caller advances by the full chunk length.
|
||||
|
||||
Uses ``os.pwrite`` where available (offset-addressed, atomic per call).
|
||||
On Windows it falls back to ``os.lseek`` + ``os.write`` under a lock,
|
||||
since that pair is not atomic across concurrent segment writers."""
|
||||
assert self._fd is not None, "writer not opened"
|
||||
view = memoryview(data)
|
||||
written = 0
|
||||
total = len(view)
|
||||
while written < total:
|
||||
if _HAS_PWRITE:
|
||||
n = os.pwrite(self._fd, view[written:], offset + written)
|
||||
else:
|
||||
with self._seek_lock:
|
||||
os.lseek(self._fd, offset + written, os.SEEK_SET)
|
||||
n = os.write(self._fd, view[written:])
|
||||
if n == 0:
|
||||
raise OSError(
|
||||
f"positioned write wrote 0 bytes at offset {offset + written} "
|
||||
f"({written}/{total} bytes written)"
|
||||
)
|
||||
written += n
|
||||
|
||||
async def write_at(self, offset: int, data: bytes) -> None:
|
||||
assert self._fd is not None, "writer not opened"
|
||||
await asyncio.get_running_loop().run_in_executor(
|
||||
_EXECUTOR, self._pwrite_all, data, offset
|
||||
)
|
||||
|
||||
async def flush(self) -> None:
|
||||
if self._fd is None:
|
||||
return
|
||||
await asyncio.get_running_loop().run_in_executor(_EXECUTOR, os.fsync, self._fd)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._fd is None:
|
||||
return
|
||||
fd, self._fd = self._fd, None
|
||||
await asyncio.get_running_loop().run_in_executor(_EXECUTOR, os.close, fd)
|
||||
@ -1,454 +0,0 @@
|
||||
"""Public facade for the download manager.
|
||||
|
||||
This is the only object the server imports. It validates requests, owns the
|
||||
:class:`Scheduler`, and exposes a small async API plus read models for status.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Callable, Optional
|
||||
|
||||
from app.model_downloader.constants import DownloadStatus
|
||||
from app.model_downloader.database import queries
|
||||
from app.model_downloader.net.probe import gated_error_message, probe
|
||||
from app.model_downloader.scheduler import SCHEDULER
|
||||
from app.model_downloader.security import paths
|
||||
from app.model_downloader.net.http import redact_url
|
||||
from app.model_downloader.security.allowlist import (
|
||||
ALLOWED_MODEL_EXTENSIONS,
|
||||
filename_extension,
|
||||
is_host_allowed_url,
|
||||
is_url_downloadable,
|
||||
url_path_extension,
|
||||
)
|
||||
from app.model_downloader.security.paths import InvalidModelId
|
||||
|
||||
# Non-terminal statuses: an existing row in one of these blocks a re-enqueue.
|
||||
_LIVE_STATUSES = (
|
||||
DownloadStatus.QUEUED,
|
||||
DownloadStatus.ACTIVE,
|
||||
DownloadStatus.PAUSED,
|
||||
DownloadStatus.VERIFYING,
|
||||
)
|
||||
|
||||
|
||||
class DownloadError(Exception):
|
||||
"""A user-facing error with a stable machine-readable code."""
|
||||
|
||||
def __init__(self, code: str, message: str, status: int = 400) -> None:
|
||||
super().__init__(message)
|
||||
self.code = code
|
||||
self.message = message
|
||||
self.http_status = status
|
||||
|
||||
|
||||
class DownloadManager:
|
||||
def __init__(self) -> None:
|
||||
self._scheduler = SCHEDULER
|
||||
self._notify_cb: Optional[Callable[[str], None]] = None
|
||||
# Serializes the "check for a live download, then write" critical section
|
||||
# per model_id. ``downloads`` has no uniqueness constraint on model_id
|
||||
# (history rows are kept), so without this two concurrent enqueue/resume
|
||||
# calls could both pass the live check and admit two jobs sharing one
|
||||
# temp/dest path. The manager is a process singleton over a local SQLite
|
||||
# DB, so an in-process lock is sufficient (and avoids a migration).
|
||||
self._model_locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
def set_notify(self, cb: Optional[Callable[[str], None]]) -> None:
|
||||
self._notify_cb = cb
|
||||
self._scheduler.set_notify(cb)
|
||||
|
||||
async def start(self) -> None:
|
||||
await self._scheduler.start()
|
||||
|
||||
# ----- enqueue -----
|
||||
|
||||
async def enqueue(
|
||||
self,
|
||||
url: str,
|
||||
model_id: str,
|
||||
*,
|
||||
priority: int = 0,
|
||||
expected_sha256: Optional[str] = None,
|
||||
allow_any_extension: bool = False,
|
||||
credential_id: Optional[str] = None,
|
||||
) -> str:
|
||||
# Coarse gate first: host/scheme must be allowlisted, and any extension
|
||||
# present in the URL path must be a known model type. A URL whose path
|
||||
# carries NO extension (e.g. Civitai's ``/api/download/models/<id>``) is
|
||||
# admitted here and its real extension is resolved from the network
|
||||
# below before the download is finally accepted.
|
||||
if allow_any_extension:
|
||||
if not is_host_allowed_url(url):
|
||||
raise DownloadError(
|
||||
"URL_NOT_ALLOWED",
|
||||
"URL is not on the download allowlist (host/scheme).",
|
||||
)
|
||||
elif not is_url_downloadable(url):
|
||||
raise DownloadError(
|
||||
"URL_NOT_ALLOWED",
|
||||
"URL is not on the download allowlist (host/scheme/extension).",
|
||||
)
|
||||
|
||||
# When the URL path has no extension, follow it to where it resolves and
|
||||
# adopt the real extension from the response, forcing the stored
|
||||
# filename to match. Skipped when the caller opted into any extension.
|
||||
if not allow_any_extension and url_path_extension(url) == "":
|
||||
resolved_ext = await self._resolve_extension(url, credential_id)
|
||||
model_id = paths.apply_extension(model_id, resolved_ext)
|
||||
|
||||
try:
|
||||
paths.parse_model_id(model_id, allow_any_extension)
|
||||
dest_path, temp_path = paths.resolve_destination(model_id, allow_any_extension)
|
||||
except InvalidModelId as e:
|
||||
raise DownloadError("INVALID_MODEL_ID", str(e))
|
||||
|
||||
if await asyncio.to_thread(
|
||||
paths.resolve_existing, model_id, allow_any_extension
|
||||
):
|
||||
raise DownloadError(
|
||||
"ALREADY_AVAILABLE",
|
||||
f"Model already exists on disk: {model_id}",
|
||||
status=409,
|
||||
)
|
||||
download_id = str(uuid.uuid4())
|
||||
# Hold the per-model lock across the live check and the insert so a
|
||||
# concurrent enqueue/resume for the same model_id cannot interleave
|
||||
# between them and create a second job against the same temp/dest path.
|
||||
async with self._model_lock(model_id):
|
||||
if await self._has_live_download(model_id):
|
||||
raise DownloadError(
|
||||
"ALREADY_DOWNLOADING",
|
||||
f"A download for {model_id} is already in progress.",
|
||||
status=409,
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
queries.insert_download,
|
||||
{
|
||||
"id": download_id,
|
||||
"url": url,
|
||||
"model_id": model_id,
|
||||
"dest_path": dest_path,
|
||||
"temp_path": temp_path,
|
||||
"status": DownloadStatus.QUEUED,
|
||||
"priority": priority,
|
||||
"expected_sha256": expected_sha256,
|
||||
"credential_id": credential_id,
|
||||
"allow_any_extension": allow_any_extension,
|
||||
},
|
||||
)
|
||||
logging.info("[model_downloader] enqueued %s -> %s", redact_url(url), model_id)
|
||||
await self._scheduler.pump()
|
||||
return download_id
|
||||
|
||||
async def _resolve_extension(
|
||||
self, url: str, credential_id: Optional[str]
|
||||
) -> str:
|
||||
"""Follow ``url`` to its final response and return the real extension.
|
||||
|
||||
Used for allowlisted URLs whose path has no extension (e.g. Civitai
|
||||
download endpoints): the filename lives in the ``Content-Disposition``
|
||||
header or the post-redirect URL. Raises :class:`DownloadError` when the
|
||||
URL can't be resolved, needs credentials, or resolves to something that
|
||||
is not a known model file — so we never persist a bogus destination.
|
||||
"""
|
||||
pr = await probe(url, credential_id=credential_id)
|
||||
if not pr.ok:
|
||||
if pr.gated:
|
||||
raise DownloadError(
|
||||
"GATED_REPO" if pr.is_gated_repo else "CREDENTIALS_REQUIRED",
|
||||
gated_error_message(url, pr),
|
||||
status=401,
|
||||
)
|
||||
raise DownloadError(
|
||||
"URL_RESOLVE_FAILED",
|
||||
f"Could not resolve {redact_url(url)}: {pr.error or 'unknown error'}",
|
||||
status=502,
|
||||
)
|
||||
ext = filename_extension(pr.filename) if pr.filename else ""
|
||||
if ext not in ALLOWED_MODEL_EXTENSIONS:
|
||||
raise DownloadError(
|
||||
"URL_NOT_ALLOWED",
|
||||
f"URL resolves to {pr.filename or '<unknown>'!r}, which is not a "
|
||||
f"known model file type {ALLOWED_MODEL_EXTENSIONS}.",
|
||||
)
|
||||
return ext
|
||||
|
||||
def _model_lock(self, model_id: str) -> asyncio.Lock:
|
||||
# Lazily create one lock per model_id. There is no ``await`` between the
|
||||
# lookup and the insert, so under the single asyncio thread this is
|
||||
# atomic and cannot hand out two different locks for the same model_id.
|
||||
lock = self._model_locks.get(model_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
self._model_locks[model_id] = lock
|
||||
return lock
|
||||
|
||||
async def _has_live_download(
|
||||
self, model_id: str, *, exclude_id: Optional[str] = None
|
||||
) -> bool:
|
||||
rows = await asyncio.to_thread(queries.list_downloads)
|
||||
return any(
|
||||
r.model_id == model_id
|
||||
and r.id != exclude_id
|
||||
and r.status in _LIVE_STATUSES
|
||||
for r in rows
|
||||
)
|
||||
|
||||
# ----- control -----
|
||||
|
||||
async def pause(self, download_id: str) -> None:
|
||||
job = self._scheduler.get_job(download_id)
|
||||
if job is not None:
|
||||
job.request_pause()
|
||||
return
|
||||
row = await asyncio.to_thread(queries.get_download, download_id)
|
||||
if row is None:
|
||||
raise DownloadError("NOT_FOUND", "No such download.", status=404)
|
||||
if row.status == DownloadStatus.QUEUED:
|
||||
await asyncio.to_thread(
|
||||
queries.update_download, download_id, status=DownloadStatus.PAUSED
|
||||
)
|
||||
|
||||
async def resume(self, download_id: str) -> None:
|
||||
row = await asyncio.to_thread(queries.get_download, download_id)
|
||||
if row is None:
|
||||
raise DownloadError("NOT_FOUND", "No such download.", status=404)
|
||||
if row.status not in (DownloadStatus.PAUSED, DownloadStatus.FAILED):
|
||||
return
|
||||
# Re-queueing a paused/failed row must respect the single-live-per-model
|
||||
# invariant: another download (e.g. a newer enqueue) may already be live
|
||||
# for this model_id and would share this row's temp/dest path. Hold the
|
||||
# per-model lock across the check and the status flip, and exclude this
|
||||
# row itself (a paused row is already a "live" status).
|
||||
async with self._model_lock(row.model_id):
|
||||
if await self._has_live_download(row.model_id, exclude_id=download_id):
|
||||
raise DownloadError(
|
||||
"ALREADY_DOWNLOADING",
|
||||
f"A download for {row.model_id} is already in progress.",
|
||||
status=409,
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
queries.update_download,
|
||||
download_id,
|
||||
status=DownloadStatus.QUEUED,
|
||||
error=None,
|
||||
)
|
||||
await self._scheduler.pump()
|
||||
|
||||
async def cancel(self, download_id: str) -> None:
|
||||
job = self._scheduler.get_job(download_id)
|
||||
if job is not None:
|
||||
job.request_cancel()
|
||||
return
|
||||
row = await asyncio.to_thread(queries.get_download, download_id)
|
||||
if row is None:
|
||||
raise DownloadError("NOT_FOUND", "No such download.", status=404)
|
||||
if row.status in _LIVE_STATUSES:
|
||||
import os
|
||||
|
||||
try:
|
||||
os.remove(row.temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
await asyncio.to_thread(
|
||||
queries.update_download, download_id, status=DownloadStatus.CANCELLED
|
||||
)
|
||||
|
||||
async def set_priority(self, download_id: str, priority: int) -> None:
|
||||
row = await asyncio.to_thread(queries.get_download, download_id)
|
||||
if row is None:
|
||||
raise DownloadError("NOT_FOUND", "No such download.", status=404)
|
||||
await asyncio.to_thread(
|
||||
queries.update_download, download_id, priority=priority
|
||||
)
|
||||
# Admission-order only; a higher priority is
|
||||
# picked up the next time a slot frees. Pump in case a slot is free now.
|
||||
await self._scheduler.pump()
|
||||
|
||||
async def delete(self, download_id: str) -> None:
|
||||
"""Delete a terminal download so it stays gone from history.
|
||||
|
||||
Refuses to delete a live download so a record is never removed out from
|
||||
under a running worker; cancel it first. Any leftover ``.part`` temp
|
||||
file (e.g. from a failed transfer) is removed, but the finished model
|
||||
file on disk is never touched.
|
||||
"""
|
||||
if self._scheduler.get_job(download_id) is not None:
|
||||
raise DownloadError(
|
||||
"DOWNLOAD_ACTIVE",
|
||||
"Cannot delete a download that is still in progress.",
|
||||
status=409,
|
||||
)
|
||||
row = await asyncio.to_thread(queries.get_download, download_id)
|
||||
if row is None:
|
||||
raise DownloadError("NOT_FOUND", "No such download.", status=404)
|
||||
if row.status in _LIVE_STATUSES:
|
||||
raise DownloadError(
|
||||
"DOWNLOAD_ACTIVE",
|
||||
"Cannot delete a download that is still in progress.",
|
||||
status=409,
|
||||
)
|
||||
|
||||
try:
|
||||
os.remove(row.temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
await asyncio.to_thread(queries.delete_download, download_id)
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Delete all terminal downloads from history in one transaction.
|
||||
|
||||
Skips anything still live (queued/active/paused/verifying, or a running
|
||||
job) so an in-flight download is never removed out from under a worker.
|
||||
Finished model files on disk are never touched; only leftover ``.part``
|
||||
temp files from failed/cancelled transfers are removed. Returns the
|
||||
number of history rows deleted.
|
||||
"""
|
||||
|
||||
rows = await asyncio.to_thread(queries.list_downloads)
|
||||
deletable = [
|
||||
r
|
||||
for r in rows
|
||||
if r.status not in _LIVE_STATUSES
|
||||
and self._scheduler.get_job(r.id) is None
|
||||
]
|
||||
if not deletable:
|
||||
return 0
|
||||
for r in deletable:
|
||||
try:
|
||||
os.remove(r.temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
return await asyncio.to_thread(
|
||||
queries.delete_downloads, [r.id for r in deletable]
|
||||
)
|
||||
|
||||
# ----- read models -----
|
||||
|
||||
def _view(self, row) -> dict:
|
||||
"""Combine the persisted row with live in-memory progress, if running."""
|
||||
job = self._scheduler.get_job(row.id)
|
||||
bytes_done = row.bytes_done
|
||||
total = row.total_bytes
|
||||
speed = None
|
||||
eta = None
|
||||
segments = None
|
||||
if job is not None:
|
||||
st = job.state
|
||||
bytes_done = st.bytes_done
|
||||
total = st.total_bytes if st.total_bytes is not None else total
|
||||
speed = st.speed_bps
|
||||
eta = st.eta_seconds
|
||||
segments = [
|
||||
{"idx": s.idx, "bytes_done": s.bytes_done, "length": s.length}
|
||||
for s in st.segments
|
||||
if s.end >= s.start
|
||||
]
|
||||
progress = (bytes_done / total) if total else None
|
||||
return {
|
||||
"download_id": row.id,
|
||||
"model_id": row.model_id,
|
||||
"url": redact_url(row.url),
|
||||
"status": row.status,
|
||||
"priority": row.priority,
|
||||
"total_bytes": total,
|
||||
"bytes_done": bytes_done,
|
||||
"progress": progress,
|
||||
"speed_bps": speed,
|
||||
"eta_seconds": eta,
|
||||
"segments": segments,
|
||||
"error": row.error,
|
||||
"created_at": row.created_at,
|
||||
"updated_at": row.updated_at,
|
||||
}
|
||||
|
||||
def _view_from_state(self, job) -> dict:
|
||||
"""Build a view purely from the live in-memory job state (no DB)."""
|
||||
st = job.state
|
||||
return {
|
||||
"download_id": st.download_id,
|
||||
"model_id": st.model_id,
|
||||
"url": redact_url(st.url),
|
||||
"status": st.status,
|
||||
"priority": st.priority,
|
||||
"total_bytes": st.total_bytes,
|
||||
"bytes_done": st.bytes_done,
|
||||
"progress": st.progress,
|
||||
"speed_bps": st.speed_bps,
|
||||
"eta_seconds": st.eta_seconds,
|
||||
"segments": [
|
||||
{"idx": s.idx, "bytes_done": s.bytes_done, "length": s.length}
|
||||
for s in st.segments
|
||||
if s.end >= s.start
|
||||
],
|
||||
"error": st.error,
|
||||
}
|
||||
|
||||
def status_sync(self, download_id: str) -> Optional[dict]:
|
||||
"""Synchronous status read for the websocket notify path.
|
||||
|
||||
Uses live in-memory state when the job is running (no DB round-trip on
|
||||
the hot path); falls back to a quick DB read otherwise.
|
||||
"""
|
||||
job = self._scheduler.get_job(download_id)
|
||||
if job is not None:
|
||||
return self._view_from_state(job)
|
||||
row = queries.get_download(download_id)
|
||||
return self._view(row) if row is not None else None
|
||||
|
||||
async def status(self, download_id: str) -> Optional[dict]:
|
||||
row = await asyncio.to_thread(queries.get_download, download_id)
|
||||
return self._view(row) if row is not None else None
|
||||
|
||||
async def list(self) -> list[dict]:
|
||||
rows = await asyncio.to_thread(queries.list_downloads)
|
||||
return [self._view(r) for r in rows]
|
||||
|
||||
async def availability(self, models: dict[str, str]) -> dict[str, dict]:
|
||||
"""Bulk per-id ``{state, progress, ...}`` for the frontend poll.
|
||||
|
||||
``state`` is ``available`` (on disk), ``downloading`` (live row), or
|
||||
``missing``. Cheap: a path lookup plus an in-memory/DB status check.
|
||||
"""
|
||||
rows = await asyncio.to_thread(queries.list_downloads)
|
||||
by_model: dict[str, object] = {}
|
||||
for r in rows:
|
||||
if r.status in _LIVE_STATUSES or r.model_id not in by_model:
|
||||
by_model[r.model_id] = r
|
||||
|
||||
# ``url_allowed`` mirrors the coarse enqueue gate (host/scheme + a
|
||||
# non-disallowed extension); URLs whose extension is only known after a
|
||||
# network resolve — e.g. Civitai download endpoints — report allowed.
|
||||
out: dict[str, dict] = {}
|
||||
for model_id, url in models.items():
|
||||
try:
|
||||
exists = await asyncio.to_thread(paths.resolve_existing, model_id)
|
||||
except InvalidModelId:
|
||||
out[model_id] = {"state": "missing", "url_allowed": is_url_downloadable(url)}
|
||||
continue
|
||||
if exists:
|
||||
out[model_id] = {"state": "available", "url_allowed": is_url_downloadable(url)}
|
||||
continue
|
||||
row = by_model.get(model_id)
|
||||
if row is not None and row.status in _LIVE_STATUSES:
|
||||
view = self._view(row)
|
||||
out[model_id] = {
|
||||
"state": "downloading",
|
||||
"url_allowed": is_url_downloadable(url),
|
||||
"download_id": view["download_id"],
|
||||
"progress": view["progress"],
|
||||
"bytes_done": view["bytes_done"],
|
||||
"total_bytes": view["total_bytes"],
|
||||
"speed_bps": view["speed_bps"],
|
||||
}
|
||||
else:
|
||||
out[model_id] = {"state": "missing", "url_allowed": is_url_downloadable(url)}
|
||||
return out
|
||||
|
||||
|
||||
DOWNLOAD_MANAGER = DownloadManager()
|
||||
@ -1,148 +0,0 @@
|
||||
"""Manual, validated redirect-following request opener.
|
||||
|
||||
Automatic redirects are disabled. We follow hops ourselves
|
||||
so that on *every* hop we (a) re-validate scheme + reject credentials-in-URL,
|
||||
(b) recompute which stored credential — if any — applies to that hop's host,
|
||||
and (c) let the connector's resolver screen the IP. This is the single place
|
||||
that attaches credentials, so a token can never ride a redirect to a CDN host.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncIterator, Optional
|
||||
from urllib.parse import unquote, urljoin, urlsplit, urlunsplit
|
||||
|
||||
import aiohttp
|
||||
|
||||
from app.model_downloader.credentials.resolver import resolve_auth_for_hop
|
||||
from app.model_downloader.net.session import get_session
|
||||
from app.model_downloader.security.ssrf import (
|
||||
MAX_REDIRECTS,
|
||||
SSRFError,
|
||||
check_redirect_hop,
|
||||
)
|
||||
|
||||
_REDIRECT_CODES = {301, 302, 303, 307, 308}
|
||||
DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=None, sock_connect=30, sock_read=120)
|
||||
|
||||
|
||||
def redact_url(url: str) -> str:
|
||||
"""Drop the query string so a query-scheme secret is never logged/stored."""
|
||||
try:
|
||||
parts = urlsplit(url)
|
||||
except ValueError:
|
||||
return "<unparseable-url>"
|
||||
return urlunsplit(parts._replace(query=""))
|
||||
|
||||
|
||||
_CD_FILENAME_STAR = re.compile(
|
||||
r"filename\*\s*=\s*[^']*'[^']*'([^;]+)", re.IGNORECASE
|
||||
)
|
||||
_CD_FILENAME_QUOTED = re.compile(r'filename\s*=\s*"([^"]+)"', re.IGNORECASE)
|
||||
_CD_FILENAME_BARE = re.compile(r"filename\s*=\s*([^;]+)", re.IGNORECASE)
|
||||
|
||||
|
||||
def filename_from_content_disposition(value: Optional[str]) -> Optional[str]:
|
||||
"""Extract the download filename from a ``Content-Disposition`` header.
|
||||
|
||||
Prefers the RFC 5987 ``filename*=`` form (percent-decoded) over the plain
|
||||
``filename=`` form. Any directory components in the value are stripped so a
|
||||
hostile header can only influence the *name*, never the target directory.
|
||||
Returns ``None`` when no filename is present.
|
||||
"""
|
||||
if not value:
|
||||
return None
|
||||
for pat, decode in (
|
||||
(_CD_FILENAME_STAR, True),
|
||||
(_CD_FILENAME_QUOTED, False),
|
||||
(_CD_FILENAME_BARE, False),
|
||||
):
|
||||
m = pat.search(value)
|
||||
if not m:
|
||||
continue
|
||||
raw = m.group(1).strip().strip('"')
|
||||
if decode:
|
||||
try:
|
||||
raw = unquote(raw)
|
||||
except Exception:
|
||||
pass
|
||||
name = raw.replace("\\", "/").rsplit("/", 1)[-1].strip()
|
||||
if name:
|
||||
return name
|
||||
return None
|
||||
|
||||
|
||||
async def _resolve_final_response(
|
||||
method: str,
|
||||
url: str,
|
||||
credential_id: Optional[str],
|
||||
base_headers: dict[str, str],
|
||||
timeout: aiohttp.ClientTimeout,
|
||||
) -> tuple[aiohttp.ClientResponse, str]:
|
||||
"""Follow redirects manually until a non-redirect response.
|
||||
|
||||
Each intermediate redirect response is released before the next hop.
|
||||
Returns the final ``(response, final_url)``; the caller owns releasing it.
|
||||
"""
|
||||
session = await get_session()
|
||||
current = url
|
||||
hops = 0
|
||||
while True:
|
||||
check_redirect_hop(current, is_initial_url=(hops == 0))
|
||||
parts = urlsplit(current)
|
||||
auth = await resolve_auth_for_hop(
|
||||
parts.hostname or "", parts.scheme, explicit_credential_id=credential_id
|
||||
)
|
||||
req_headers = dict(base_headers)
|
||||
req_url = current
|
||||
if auth is not None:
|
||||
req_headers.update(auth.headers)
|
||||
req_url = auth.apply_to_url(current)
|
||||
|
||||
resp = await session.request(
|
||||
method,
|
||||
req_url,
|
||||
allow_redirects=False,
|
||||
headers=req_headers,
|
||||
timeout=timeout,
|
||||
)
|
||||
if resp.status in _REDIRECT_CODES and resp.headers.get("Location"):
|
||||
next_url = urljoin(str(resp.url), resp.headers["Location"])
|
||||
await resp.release()
|
||||
hops += 1
|
||||
if hops > MAX_REDIRECTS:
|
||||
raise SSRFError(
|
||||
f"too many redirects (> {MAX_REDIRECTS}) for {redact_url(url)}"
|
||||
)
|
||||
current = next_url
|
||||
continue
|
||||
return resp, redact_url(str(resp.url))
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def open_validated(
|
||||
method: str,
|
||||
url: str,
|
||||
*,
|
||||
credential_id: Optional[str] = None,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
timeout: aiohttp.ClientTimeout = DEFAULT_TIMEOUT,
|
||||
) -> AsyncIterator[tuple[aiohttp.ClientResponse, str]]:
|
||||
"""Open ``method url`` following redirects manually and validated.
|
||||
|
||||
Yields ``(response, final_url)`` where ``final_url`` is redacted of any
|
||||
query string. The response is released automatically on exit.
|
||||
"""
|
||||
resp, final_url = await _resolve_final_response(
|
||||
method, url, credential_id, dict(headers or {}), timeout
|
||||
)
|
||||
try:
|
||||
yield resp, final_url
|
||||
finally:
|
||||
try:
|
||||
await resp.release()
|
||||
except Exception: # pragma: no cover - best-effort cleanup
|
||||
logging.debug("[model_downloader] response release error", exc_info=True)
|
||||
@ -1,157 +0,0 @@
|
||||
"""Pre-download probe.
|
||||
|
||||
Issues a tiny ranged GET (``Range: bytes=0-0``) — which doubles as a
|
||||
range-support test — to discover ``Content-Length``, ``Accept-Ranges``,
|
||||
``ETag``/``Last-Modified``, and the final post-redirect URL. For HuggingFace
|
||||
LFS files the true size also appears in the non-standard ``X-Linked-Size``
|
||||
header, which we read as a fallback.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse, urlsplit
|
||||
|
||||
import aiohttp
|
||||
|
||||
from app.model_downloader.net.http import (
|
||||
filename_from_content_disposition,
|
||||
open_validated,
|
||||
redact_url,
|
||||
)
|
||||
from app.model_downloader.net.session import parse_int_header
|
||||
|
||||
_PROBE_TIMEOUT = aiohttp.ClientTimeout(total=60, sock_connect=30, sock_read=30)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProbeResult:
|
||||
ok: bool
|
||||
status: int
|
||||
final_url: Optional[str] = None
|
||||
total_bytes: Optional[int] = None
|
||||
accept_ranges: bool = False
|
||||
etag: Optional[str] = None
|
||||
last_modified: Optional[str] = None
|
||||
gated: bool = False # 401/403 — needs (or has wrong) credentials
|
||||
error: Optional[str] = None
|
||||
# HuggingFace's ``X-Error-Code`` header (e.g. ``GatedRepo``,
|
||||
# ``RepoNotFound``) when the host reports one. Lets us tell "this repo is
|
||||
# gated — request access" apart from "you just need a token".
|
||||
error_code: Optional[str] = None
|
||||
# Filename the server intends this response to be saved as: the
|
||||
# ``Content-Disposition`` name if present, else the post-redirect URL's
|
||||
# basename. Used to resolve the real extension for URLs (e.g. Civitai's
|
||||
# ``/api/download`` endpoints) that carry no extension in their path.
|
||||
filename: Optional[str] = None
|
||||
|
||||
@property
|
||||
def is_gated_repo(self) -> bool:
|
||||
"""True when the host says the repo is gated (access must be granted).
|
||||
|
||||
Distinct from a plain missing/invalid token: even a valid credential
|
||||
won't help until the user accepts the model's terms on its page.
|
||||
"""
|
||||
return (self.error_code or "").lower() == "gatedrepo"
|
||||
|
||||
|
||||
def _total_from_content_range(value: Optional[str]) -> Optional[int]:
|
||||
# "bytes 0-0/12345" -> 12345 ; "bytes 0-0/*" -> None
|
||||
if not value or "/" not in value:
|
||||
return None
|
||||
total = value.rsplit("/", 1)[1].strip()
|
||||
return parse_int_header(total)
|
||||
|
||||
|
||||
def _filename_from_response(
|
||||
content_disposition: Optional[str], final_url: Optional[str]
|
||||
) -> Optional[str]:
|
||||
name = filename_from_content_disposition(content_disposition)
|
||||
if name:
|
||||
return name
|
||||
if final_url:
|
||||
base = urlsplit(final_url).path.rsplit("/", 1)[-1]
|
||||
if base:
|
||||
return base
|
||||
return None
|
||||
|
||||
|
||||
async def probe(url: str, *, credential_id: Optional[str] = None) -> ProbeResult:
|
||||
"""Probe ``url`` and return discovered metadata, failing soft."""
|
||||
try:
|
||||
async with open_validated(
|
||||
"GET",
|
||||
url,
|
||||
credential_id=credential_id,
|
||||
headers={"Range": "bytes=0-0", "Accept-Encoding": "identity"},
|
||||
timeout=_PROBE_TIMEOUT,
|
||||
) as (resp, final_url):
|
||||
if resp.status in (401, 403):
|
||||
error_code = resp.headers.get("X-Error-Code")
|
||||
error_message = resp.headers.get("X-Error-Message")
|
||||
return ProbeResult(
|
||||
ok=False, status=resp.status, final_url=final_url, gated=True,
|
||||
error_code=error_code,
|
||||
error=(
|
||||
error_message
|
||||
or f"host returned {resp.status} (authentication required)"
|
||||
),
|
||||
)
|
||||
if resp.status not in (200, 206):
|
||||
return ProbeResult(
|
||||
ok=False, status=resp.status, final_url=final_url,
|
||||
error=f"probe returned HTTP {resp.status}",
|
||||
)
|
||||
|
||||
headers = resp.headers
|
||||
accept_ranges = False
|
||||
total: Optional[int] = None
|
||||
if resp.status == 206:
|
||||
accept_ranges = True
|
||||
total = _total_from_content_range(headers.get("Content-Range"))
|
||||
else: # 200: server ignored the range
|
||||
accept_ranges = headers.get("Accept-Ranges", "").lower() == "bytes"
|
||||
total = parse_int_header(headers.get("Content-Length"))
|
||||
|
||||
if total is None:
|
||||
total = parse_int_header(headers.get("X-Linked-Size"))
|
||||
|
||||
return ProbeResult(
|
||||
ok=True,
|
||||
status=resp.status,
|
||||
final_url=final_url,
|
||||
total_bytes=total,
|
||||
accept_ranges=accept_ranges,
|
||||
etag=headers.get("ETag"),
|
||||
last_modified=headers.get("Last-Modified"),
|
||||
filename=_filename_from_response(
|
||||
headers.get("Content-Disposition"), final_url
|
||||
),
|
||||
)
|
||||
except Exception as e: # network / SSRF / timeout
|
||||
host = urlparse(url).netloc or "<unknown>"
|
||||
logging.debug("[model_downloader] probe failed for %s: %s", host, type(e).__name__)
|
||||
return ProbeResult(ok=False, status=0, error="probe failed: network error")
|
||||
|
||||
|
||||
def gated_error_message(url: str, pr: ProbeResult) -> str:
|
||||
"""Build a user-facing message for a gated/auth-required probe result.
|
||||
|
||||
Distinguishes a *gated* repo (access must be requested/granted on the model
|
||||
page — a token alone is not enough) from a plain missing/invalid credential.
|
||||
"""
|
||||
redacted = redact_url(url)
|
||||
if pr.is_gated_repo:
|
||||
detail = (pr.error or "access is restricted").rstrip()
|
||||
if detail and not detail.endswith((".", "!", "?")):
|
||||
detail += "."
|
||||
return (
|
||||
f"{redacted} is a gated model — {detail} Request access on the model's "
|
||||
f"page, add an API key for this host at /api/download/credentials, and retry."
|
||||
)
|
||||
return (
|
||||
f"{redacted} requires authentication. Add an API key for this host at "
|
||||
f"/api/download/credentials and retry."
|
||||
)
|
||||
@ -1,72 +0,0 @@
|
||||
"""Lazily-created shared :class:`aiohttp.ClientSession`.
|
||||
|
||||
A single session reuses TLS handshakes and TCP connections across the probe
|
||||
and the many segment GETs to the same host (HuggingFace is the dominant
|
||||
case), which is a large speedup on cold connections and exactly the
|
||||
connection-reuse strategy that lets us match aria2c.
|
||||
|
||||
The connector uses :class:`ValidatingResolver` so every connection — initial
|
||||
or post-redirect — is screened for private/special-use IPs at connect time.
|
||||
TLS is pinned to certifi's CA bundle because the OS trust store is not wired
|
||||
up on some Python installs (python.org macOS, slim containers).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import ssl
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
try:
|
||||
import certifi
|
||||
_CA_FILE = certifi.where()
|
||||
except Exception: # pragma: no cover - certifi is a transitive dep of aiohttp
|
||||
_CA_FILE = None
|
||||
|
||||
from comfy.cli_args import args
|
||||
from app.model_downloader.security.ssrf import ValidatingResolver
|
||||
|
||||
_session: Optional[aiohttp.ClientSession] = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def ssl_context() -> ssl.SSLContext:
|
||||
if _CA_FILE is not None:
|
||||
return ssl.create_default_context(cafile=_CA_FILE)
|
||||
return ssl.create_default_context()
|
||||
|
||||
|
||||
async def get_session() -> aiohttp.ClientSession:
|
||||
"""Return the shared session, creating it on first use."""
|
||||
global _session
|
||||
if _session is not None and not _session.closed:
|
||||
return _session
|
||||
async with _lock:
|
||||
if _session is None or _session.closed:
|
||||
connector = aiohttp.TCPConnector(
|
||||
limit_per_host=max(1, getattr(args, "download_max_connections_per_host", 16)),
|
||||
ssl=ssl_context(),
|
||||
resolver=ValidatingResolver(),
|
||||
)
|
||||
_session = aiohttp.ClientSession(connector=connector)
|
||||
return _session
|
||||
|
||||
|
||||
async def close_session() -> None:
|
||||
global _session
|
||||
if _session is not None and not _session.closed:
|
||||
await _session.close()
|
||||
_session = None
|
||||
|
||||
|
||||
def parse_int_header(value: Optional[str]) -> Optional[int]:
|
||||
"""Parse a non-negative integer header value, or None if bad/absent."""
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
n = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
return n if n >= 0 else None
|
||||
@ -1,177 +0,0 @@
|
||||
"""Priority scheduler + lifecycle.
|
||||
|
||||
Owns the set of running jobs and admits queued downloads up to a global
|
||||
concurrency limit (K), highest priority first, FIFO within a priority. Runs
|
||||
entirely on the existing ComfyUI asyncio loop; blocking work (disk, hashing,
|
||||
DB) is offloaded by the job/writer layers.
|
||||
|
||||
On startup it reconciles DB vs. disk: ``active``/``verifying`` rows left by a
|
||||
previous run are reset to ``queued`` and resumed from persisted offsets, and
|
||||
orphaned ``.part`` files with no live download row are swept.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
|
||||
from comfy.cli_args import args
|
||||
from app.model_downloader.constants import DownloadStatus
|
||||
from app.model_downloader.database import queries
|
||||
from app.model_downloader.engine.job import DownloadJob, JobSpec
|
||||
from app.model_downloader.security import paths
|
||||
|
||||
# Backoff for retryable failures
|
||||
_BACKOFF_BASE = 2.0
|
||||
_BACKOFF_CAP = 300.0
|
||||
_MAX_ATTEMPTS = 6
|
||||
|
||||
|
||||
class Scheduler:
|
||||
def __init__(self) -> None:
|
||||
self._jobs: dict[str, DownloadJob] = {}
|
||||
self._tasks: dict[str, asyncio.Task] = {}
|
||||
self._backoff_until: dict[str, float] = {}
|
||||
self._pump_lock = asyncio.Lock()
|
||||
self._notify_cb: Optional[Callable[[str], None]] = None
|
||||
self._started = False
|
||||
|
||||
@property
|
||||
def max_active(self) -> int:
|
||||
return max(1, getattr(args, "download_max_active", 3))
|
||||
|
||||
def set_notify(self, cb: Optional[Callable[[str], None]]) -> None:
|
||||
self._notify_cb = cb
|
||||
|
||||
def get_job(self, download_id: str) -> Optional[DownloadJob]:
|
||||
return self._jobs.get(download_id)
|
||||
|
||||
def is_active(self, download_id: str) -> bool:
|
||||
return download_id in self._tasks
|
||||
|
||||
# ----- startup -----
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._started:
|
||||
return
|
||||
self._started = True
|
||||
try:
|
||||
await asyncio.to_thread(queries.reconcile_live_downloads)
|
||||
await asyncio.to_thread(self._sweep_orphan_temp_files)
|
||||
except Exception as e:
|
||||
logging.warning("[model_downloader] startup reconcile failed: %s", e)
|
||||
await self.pump()
|
||||
|
||||
@staticmethod
|
||||
def _sweep_orphan_temp_files() -> None:
|
||||
"""Remove ``.part`` files not referenced by a resumable download row.
|
||||
|
||||
Resumable partials are preserved; only truly orphaned temp files from
|
||||
crashed runs are deleted. ``FAILED`` is included because
|
||||
:meth:`DownloadManager.resume` explicitly permits resuming a
|
||||
retry-exhausted failed row: deleting its partial here while the
|
||||
per-segment offsets survive in the DB would make the next resume
|
||||
preallocate a fresh sparse file, skip every "complete" segment, and
|
||||
leave zero-filled holes that pass the size-only verification gate.
|
||||
"""
|
||||
live = {
|
||||
row.temp_path
|
||||
for row in queries.list_downloads()
|
||||
if row.status
|
||||
in (
|
||||
DownloadStatus.QUEUED,
|
||||
DownloadStatus.PAUSED,
|
||||
DownloadStatus.FAILED,
|
||||
)
|
||||
}
|
||||
for path in paths.iter_all_tmp_paths():
|
||||
if path in live:
|
||||
continue
|
||||
try:
|
||||
os.remove(path)
|
||||
logging.info("[model_downloader] removed orphan temp file: %s", path)
|
||||
except OSError as e:
|
||||
logging.warning("[model_downloader] could not remove %s: %s", path, e)
|
||||
|
||||
# ----- admission -----
|
||||
|
||||
async def pump(self) -> None:
|
||||
async with self._pump_lock:
|
||||
slots = self.max_active - len(self._tasks)
|
||||
if slots <= 0:
|
||||
return
|
||||
now = time.monotonic()
|
||||
candidates = await asyncio.to_thread(queries.list_queued_downloads)
|
||||
for row in candidates:
|
||||
if slots <= 0:
|
||||
break
|
||||
if row.id in self._tasks:
|
||||
continue
|
||||
if self._backoff_until.get(row.id, 0.0) > now:
|
||||
continue
|
||||
self._admit(row)
|
||||
slots -= 1
|
||||
|
||||
def _admit(self, row) -> None:
|
||||
spec = JobSpec(
|
||||
download_id=row.id,
|
||||
url=row.url,
|
||||
model_id=row.model_id,
|
||||
dest_path=row.dest_path,
|
||||
temp_path=row.temp_path,
|
||||
priority=row.priority,
|
||||
credential_id=row.credential_id,
|
||||
expected_sha256=row.expected_sha256,
|
||||
allow_any_extension=row.allow_any_extension,
|
||||
etag=row.etag,
|
||||
attempts=row.attempts,
|
||||
)
|
||||
job = DownloadJob(spec, notify_cb=self._notify_cb)
|
||||
self._jobs[row.id] = job
|
||||
self._tasks[row.id] = asyncio.ensure_future(self._run_job(job))
|
||||
|
||||
async def _run_job(self, job: DownloadJob) -> None:
|
||||
download_id = job.spec.download_id
|
||||
status = DownloadStatus.FAILED
|
||||
try:
|
||||
status = await job.run()
|
||||
except Exception as e: # run() is defensive, but never let a task die silently
|
||||
logging.error("[model_downloader] job %s crashed: %s", download_id, e)
|
||||
queries.update_download(
|
||||
download_id,
|
||||
status=DownloadStatus.FAILED,
|
||||
error=f"internal error: {e}",
|
||||
)
|
||||
if self._notify_cb:
|
||||
self._notify_cb(download_id)
|
||||
finally:
|
||||
self._tasks.pop(download_id, None)
|
||||
self._jobs.pop(download_id, None)
|
||||
|
||||
if status == DownloadStatus.QUEUED:
|
||||
if job.spec.attempts >= _MAX_ATTEMPTS:
|
||||
queries.update_download(
|
||||
download_id,
|
||||
status=DownloadStatus.FAILED,
|
||||
error=f"giving up after {job.spec.attempts} attempts",
|
||||
)
|
||||
if self._notify_cb:
|
||||
self._notify_cb(download_id)
|
||||
else:
|
||||
delay = min(
|
||||
_BACKOFF_CAP, _BACKOFF_BASE ** job.spec.attempts
|
||||
) + random.uniform(0, 1.0)
|
||||
self._backoff_until[download_id] = time.monotonic() + delay
|
||||
asyncio.ensure_future(self._delayed_pump(delay))
|
||||
await self.pump()
|
||||
|
||||
async def _delayed_pump(self, delay: float) -> None:
|
||||
await asyncio.sleep(delay)
|
||||
await self.pump()
|
||||
|
||||
|
||||
SCHEDULER = Scheduler()
|
||||
@ -1,140 +0,0 @@
|
||||
"""URL allowlist for server-side model fetches.
|
||||
|
||||
Default-deny. A URL is downloadable only when its parsed host + scheme are
|
||||
allowlisted AND (unless explicitly relaxed) its final filename ends in a
|
||||
known model extension.
|
||||
|
||||
The built-in host defaults mirror the frontend's ``isModelDownloadable``
|
||||
allowlist so the two flows agree on what is eligible; ``--download-allowed-hosts``
|
||||
extends it for self-hosted mirrors. Matching is done on ``urlparse().hostname``
|
||||
(never a raw string prefix) so userinfo tricks like
|
||||
``http://127.0.0.1@169.254.169.254/x.safetensors`` — whose real host is the
|
||||
metadata IP — cannot slip past.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
# host -> set of allowed schemes. Frontend parity (HuggingFace / Civitai /
|
||||
# localhost). Extra hosts from --download-allowed-hosts are https-only.
|
||||
_DEFAULT_ALLOWED_HOSTS: dict[str, set[str]] = {
|
||||
"huggingface.co": {"https"},
|
||||
"civitai.com": {"https"},
|
||||
"localhost": {"http", "https"},
|
||||
"127.0.0.1": {"http", "https"},
|
||||
}
|
||||
|
||||
# Hosts for which loopback addresses are intentionally permitted (the localhost
|
||||
# "download a local model" feature). Every other host's loopback resolution is
|
||||
# rejected by the SSRF resolver.
|
||||
LOOPBACK_HOSTS = frozenset({"localhost", "127.0.0.1", "::1"})
|
||||
|
||||
# Known model file extensions (frontend parity). Checked on the final filename.
|
||||
ALLOWED_MODEL_EXTENSIONS = (
|
||||
".safetensors",
|
||||
".sft",
|
||||
".ckpt",
|
||||
".pth",
|
||||
".pt",
|
||||
".gguf",
|
||||
".bin",
|
||||
)
|
||||
|
||||
|
||||
def _allowed_hosts() -> dict[str, set[str]]:
|
||||
hosts = {h: set(s) for h, s in _DEFAULT_ALLOWED_HOSTS.items()}
|
||||
for extra in getattr(args, "download_allowed_hosts", []) or []:
|
||||
host = extra.strip().lower()
|
||||
if host:
|
||||
hosts.setdefault(host, set()).add("https")
|
||||
return hosts
|
||||
|
||||
|
||||
def is_host_allowed(host: str | None, scheme: str | None) -> bool:
|
||||
"""True iff ``host`` is allowlisted for ``scheme``.
|
||||
|
||||
Used both for the initial URL and re-checked on every redirect hop,
|
||||
so a whitelisted URL cannot 30x into an off-list host.
|
||||
"""
|
||||
if not host or not scheme:
|
||||
return False
|
||||
allowed = _allowed_hosts().get(host.lower())
|
||||
return allowed is not None and scheme.lower() in allowed
|
||||
|
||||
|
||||
def has_allowed_extension(path: str, allow_any_extension: bool = False) -> bool:
|
||||
if allow_any_extension:
|
||||
return True
|
||||
return path.lower().endswith(ALLOWED_MODEL_EXTENSIONS)
|
||||
|
||||
|
||||
def filename_extension(name: str) -> str:
|
||||
"""Lowercased extension (including the leading dot) of a bare filename.
|
||||
|
||||
Returns ``""`` when there is no extension. A leading-dot name
|
||||
(``.safetensors``) is treated as having no extension (all stem), matching
|
||||
``os.path.splitext`` semantics so dotfiles aren't mistaken for typed files.
|
||||
"""
|
||||
base = name.replace("\\", "/").rsplit("/", 1)[-1]
|
||||
dot = base.rfind(".")
|
||||
if dot <= 0:
|
||||
return ""
|
||||
return base[dot:].lower()
|
||||
|
||||
|
||||
def is_allowed_extension_name(name: str) -> bool:
|
||||
"""True iff ``name`` ends in one of the known model extensions."""
|
||||
return name.lower().endswith(ALLOWED_MODEL_EXTENSIONS)
|
||||
|
||||
|
||||
def is_host_allowed_url(url: str) -> bool:
|
||||
"""True iff ``url`` parses and its host+scheme are allowlisted."""
|
||||
if not isinstance(url, str) or not url:
|
||||
return False
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
except ValueError:
|
||||
return False
|
||||
return is_host_allowed(parsed.hostname, parsed.scheme)
|
||||
|
||||
|
||||
def url_path_extension(url: str) -> str:
|
||||
"""Extension of the URL *path* basename (query ignored), or ``""``."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
except ValueError:
|
||||
return ""
|
||||
return filename_extension(parsed.path)
|
||||
|
||||
|
||||
def is_url_downloadable(url: str) -> bool:
|
||||
"""Coarse enqueue gate: host/scheme allowed and extension not disallowed.
|
||||
|
||||
Unlike :func:`is_url_allowed` (which demands a known extension *in the URL*),
|
||||
this also admits URLs whose path carries no extension at all — e.g. a Civitai
|
||||
``/api/download/models/<id>`` endpoint whose real filename only shows up in
|
||||
the redirect target / ``Content-Disposition``. The true extension is then
|
||||
resolved from the network and re-validated before the download is admitted.
|
||||
A path bearing an explicit *non-model* extension (``.zip``, ``.html``, ...)
|
||||
is still rejected here.
|
||||
"""
|
||||
if not is_host_allowed_url(url):
|
||||
return False
|
||||
ext = url_path_extension(url)
|
||||
return ext == "" or ext in ALLOWED_MODEL_EXTENSIONS
|
||||
|
||||
|
||||
def is_url_allowed(url: str, allow_any_extension: bool = False) -> bool:
|
||||
"""Check whether ``url`` is permitted as a server-side download source."""
|
||||
if not isinstance(url, str) or not url:
|
||||
return False
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
except ValueError:
|
||||
return False
|
||||
if not is_host_allowed(parsed.hostname, parsed.scheme):
|
||||
return False
|
||||
return has_allowed_extension(parsed.path, allow_any_extension)
|
||||
@ -1,132 +0,0 @@
|
||||
"""Path resolution + traversal safety for downloads.
|
||||
|
||||
A ``model_id`` is a *relative destination path* of the form
|
||||
``<directory>/<filename>`` (e.g. ``loras/my_lora.safetensors``). This module
|
||||
turns one into an absolute on-disk path under one of ComfyUI's registered
|
||||
model folders, rejecting unknown folders, path traversal, and symlink escape.
|
||||
This is the only thing that composes destination paths, so the engine never
|
||||
touches user-supplied path strings directly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Iterator, Optional
|
||||
|
||||
import folder_paths
|
||||
|
||||
from app.model_downloader.constants import TMP_SUFFIX
|
||||
from app.model_downloader.security.allowlist import ALLOWED_MODEL_EXTENSIONS
|
||||
|
||||
# A model_id component is a single path segment of safe characters — no slashes,
|
||||
# no "..", no leading dots that could escape the target directory.
|
||||
_SEGMENT_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$")
|
||||
|
||||
|
||||
class InvalidModelId(ValueError):
|
||||
"""Raised when a model_id is malformed or names an unknown model folder."""
|
||||
|
||||
|
||||
def parse_model_id(model_id: str, allow_any_extension: bool = False) -> tuple[str, str]:
|
||||
"""Split ``<directory>/<filename>`` and validate both components.
|
||||
|
||||
Returns ``(directory, filename)``. Does not touch the filesystem.
|
||||
"""
|
||||
if not isinstance(model_id, str) or "/" not in model_id:
|
||||
raise InvalidModelId(
|
||||
f"model_id must be '<directory>/<filename>', got {model_id!r}"
|
||||
)
|
||||
directory, _, filename = model_id.partition("/")
|
||||
if "/" in filename or not directory or not filename:
|
||||
raise InvalidModelId(
|
||||
f"model_id must have exactly one '/' separator, got {model_id!r}"
|
||||
)
|
||||
if not _SEGMENT_RE.match(directory):
|
||||
raise InvalidModelId(f"invalid directory segment {directory!r}")
|
||||
if not _SEGMENT_RE.match(filename):
|
||||
raise InvalidModelId(f"invalid filename segment {filename!r}")
|
||||
if not allow_any_extension and not filename.lower().endswith(
|
||||
ALLOWED_MODEL_EXTENSIONS
|
||||
):
|
||||
raise InvalidModelId(
|
||||
f"filename must end with a known model extension "
|
||||
f"{ALLOWED_MODEL_EXTENSIONS}, got {filename!r}"
|
||||
)
|
||||
if directory not in folder_paths.folder_names_and_paths:
|
||||
raise InvalidModelId(f"unknown model folder {directory!r}")
|
||||
return directory, filename
|
||||
|
||||
|
||||
def apply_extension(model_id: str, ext: str) -> str:
|
||||
"""Return ``model_id`` with its filename forced to end in ``ext``.
|
||||
|
||||
``ext`` includes the leading dot (e.g. ``".safetensors"``). If the filename
|
||||
already ends in a *known model extension* it is replaced; otherwise ``ext``
|
||||
is appended (so ``loras/mymodel`` -> ``loras/mymodel.safetensors`` and
|
||||
``loras/mymodel.ckpt`` -> ``loras/mymodel.safetensors``). A filename with a
|
||||
non-model suffix (``my.model.v2``) is treated as an extensionless stem and
|
||||
``ext`` is appended. The directory part is left untouched; validation is
|
||||
still the caller's job via :func:`parse_model_id`.
|
||||
"""
|
||||
directory, sep, filename = model_id.partition("/")
|
||||
if not sep:
|
||||
return model_id # malformed; parse_model_id will reject it
|
||||
low = filename.lower()
|
||||
for known in ALLOWED_MODEL_EXTENSIONS:
|
||||
if low.endswith(known):
|
||||
filename = filename[: -len(known)]
|
||||
break
|
||||
return f"{directory}{sep}{filename}{ext}"
|
||||
|
||||
|
||||
def resolve_existing(model_id: str, allow_any_extension: bool = False) -> Optional[str]:
|
||||
"""Return the absolute path of an installed model, or None if missing.
|
||||
|
||||
Honours ``extra_model_paths.yaml`` transparently via ``get_full_path``.
|
||||
"""
|
||||
directory, filename = parse_model_id(model_id, allow_any_extension)
|
||||
return folder_paths.get_full_path(directory, filename)
|
||||
|
||||
|
||||
def resolve_destination(
|
||||
model_id: str, allow_any_extension: bool = False
|
||||
) -> tuple[str, str]:
|
||||
"""Return ``(final_path, temp_path)`` for a download.
|
||||
|
||||
Downloads land at the first registered path for the model's directory
|
||||
(the "primary" location). ``temp_path`` is a sibling ``.part`` file that
|
||||
is atomically renamed onto ``final_path`` on success. The result is
|
||||
asserted to stay within the registered root (defence in depth on top of
|
||||
the segment regex).
|
||||
"""
|
||||
directory, filename = parse_model_id(model_id, allow_any_extension)
|
||||
roots = folder_paths.get_folder_paths(directory)
|
||||
if not roots:
|
||||
raise InvalidModelId(f"no on-disk path registered for folder {directory!r}")
|
||||
root = os.path.realpath(roots[0])
|
||||
final_path = os.path.realpath(os.path.join(root, filename))
|
||||
if final_path != root and not final_path.startswith(root + os.sep):
|
||||
raise InvalidModelId(f"resolved path escapes model root: {model_id!r}")
|
||||
temp_path = f"{final_path}{TMP_SUFFIX}"
|
||||
return final_path, temp_path
|
||||
|
||||
|
||||
def iter_all_tmp_paths() -> Iterator[str]:
|
||||
"""Yield this subsystem's temp files under every registered model folder.
|
||||
|
||||
Matches only the distinctive ``TMP_SUFFIX`` so the startup orphan sweep
|
||||
can never delete temp files created by other tools.
|
||||
"""
|
||||
seen_roots: set[str] = set()
|
||||
for directory in list(folder_paths.folder_names_and_paths.keys()):
|
||||
for root in folder_paths.get_folder_paths(directory):
|
||||
if root in seen_roots or not os.path.isdir(root):
|
||||
continue
|
||||
seen_roots.add(root)
|
||||
try:
|
||||
for entry in os.scandir(root):
|
||||
if entry.is_file() and entry.name.endswith(TMP_SUFFIX):
|
||||
yield entry.path
|
||||
except OSError:
|
||||
continue
|
||||
@ -1,163 +0,0 @@
|
||||
"""SSRF / exfiltration defenses.
|
||||
|
||||
Two cooperating layers:
|
||||
|
||||
1. :class:`ValidatingResolver` is installed on the shared connector. Every
|
||||
connection — the initial probe and every segment GET, including ones made
|
||||
after a redirect — resolves its host through this resolver, which rejects
|
||||
any address that lands on a private / special-use IP range. Because the
|
||||
resolve and the connect happen together inside the connector, there is no
|
||||
check-then-connect window for DNS rebinding to exploit.
|
||||
|
||||
2. :func:`check_redirect_hop` re-validates every hop. The host allowlist gates
|
||||
only the *initial* user-supplied URL (anti-SSRF for arbitrary input);
|
||||
legitimate downloads from allowlisted origins redirect to presigned CDN
|
||||
hosts that are deliberately NOT on the allowlist (HF ->
|
||||
``cdn-lfs*.huggingface.co``, Civitai -> signed Cloudflare/S3), so hops are
|
||||
instead screened for scheme, embedded credentials, and — via the resolver
|
||||
above — private IPs. Credentials are only ever attached when a hop's host
|
||||
exactly matches a stored credential, so they are dropped on the CDN hop.
|
||||
Loopback (the "download a local model" feature) is exempt from IP filtering
|
||||
only for the initial URL: a *redirect* may never target a loopback host or
|
||||
a blocked IP-literal, which the resolver alone can't enforce (it exempts
|
||||
loopback literals and never sees IP literals through DNS).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from aiohttp.abc import AbstractResolver
|
||||
from aiohttp.resolver import DefaultResolver
|
||||
|
||||
from app.model_downloader.security.allowlist import LOOPBACK_HOSTS
|
||||
|
||||
# Cap the redirect chain length a hop may use.
|
||||
MAX_REDIRECTS = 5
|
||||
|
||||
|
||||
class SSRFError(Exception):
|
||||
"""A hop failed an SSRF / allowlist check."""
|
||||
|
||||
|
||||
def is_scheme_allowed(scheme: str | None, host: str | None) -> bool:
|
||||
"""True iff ``scheme`` is permitted for ``host`` on a download hop.
|
||||
|
||||
https is always allowed; plain http only for loopback/approved dev hosts.
|
||||
"""
|
||||
if not scheme:
|
||||
return False
|
||||
scheme = scheme.lower()
|
||||
if scheme == "https":
|
||||
return True
|
||||
if scheme == "http":
|
||||
return bool(host) and host.lower() in LOOPBACK_HOSTS
|
||||
return False
|
||||
|
||||
|
||||
def is_blocked_ip(ip_str: str) -> bool:
|
||||
"""True for any address we refuse to connect to.
|
||||
|
||||
Covers loopback, link-local (incl. 169.254.169.254 cloud metadata),
|
||||
RFC1918 private ranges, unique-local (ULA), unspecified (0.0.0.0/::),
|
||||
multicast and other reserved ranges.
|
||||
"""
|
||||
try:
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
except ValueError:
|
||||
return True # unparseable -> refuse
|
||||
# On CPython before the gh-113171 fix (backported to 3.12.4/3.11.9/
|
||||
# 3.10.14/3.9.19) the is_* properties don't see through IPv4-mapped IPv6
|
||||
# (e.g. ::ffff:169.254.169.254), so resolve and re-check the embedded IPv4
|
||||
# to keep mapped metadata/private addresses from slipping past the filter.
|
||||
mapped = getattr(ip, "ipv4_mapped", None)
|
||||
if mapped is not None:
|
||||
ip = mapped
|
||||
return (
|
||||
ip.is_private
|
||||
or ip.is_loopback
|
||||
or ip.is_link_local
|
||||
or ip.is_multicast
|
||||
or ip.is_reserved
|
||||
or ip.is_unspecified
|
||||
)
|
||||
|
||||
|
||||
class ValidatingResolver(AbstractResolver):
|
||||
"""Delegating resolver that drops blocked IPs from every resolution.
|
||||
|
||||
If a hostname resolves only to blocked addresses, the connection fails
|
||||
closed with an :class:`OSError`, which aiohttp surfaces as a connection
|
||||
error to the caller.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._inner = DefaultResolver()
|
||||
|
||||
async def resolve(self, host, port=0, family=socket.AF_INET):
|
||||
infos = await self._inner.resolve(host, port, family)
|
||||
# localhost/127.0.0.1 are an explicit, opt-in allowlist feature.
|
||||
if isinstance(host, str) and host.lower() in LOOPBACK_HOSTS:
|
||||
return infos
|
||||
safe = [info for info in infos if not is_blocked_ip(info["host"])]
|
||||
if not safe:
|
||||
raise OSError(
|
||||
f"refusing to connect to {host!r}: resolves only to "
|
||||
f"private/special-use addresses"
|
||||
)
|
||||
return safe
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._inner.close()
|
||||
|
||||
|
||||
def check_redirect_hop(url: str, *, is_initial_url: bool = False) -> str:
|
||||
"""Validate one hop's URL.
|
||||
|
||||
Returns the URL unchanged on success; raises :class:`SSRFError` otherwise.
|
||||
Requires https for external hosts (http only for loopback/approved dev
|
||||
hosts) and forbids credentials-in-URL. The host is NOT re-checked against
|
||||
the allowlist (CDN redirect targets are off-list by design); credential
|
||||
leakage is prevented by exact host matching at attach time, and the landing
|
||||
filename's extension is gated separately by the caller.
|
||||
|
||||
Loopback/blocked-IP screening: the connector's resolver filters resolvable
|
||||
hostnames but exempts literal loopback hosts (``localhost``/``127.0.0.1``/
|
||||
``::1``) and never sees IP literals through DNS. That loopback exemption is
|
||||
legitimate only for the *initial* user-supplied URL (``is_initial_url``);
|
||||
on a redirect hop we reject loopback hosts and any blocked IP-literal here,
|
||||
so a 30x can't steer a server-side GET at loopback/internal services.
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
except ValueError as e:
|
||||
raise SSRFError(f"unparseable redirect URL {url!r}: {e}") from e
|
||||
host = parsed.hostname
|
||||
if not host:
|
||||
raise SSRFError(f"redirect URL has no host: {url!r}")
|
||||
if not is_scheme_allowed(parsed.scheme, host):
|
||||
raise SSRFError(
|
||||
f"redirect to disallowed scheme {parsed.scheme!r} for host "
|
||||
f"{host!r} (https required for external hosts)"
|
||||
)
|
||||
if parsed.username or parsed.password:
|
||||
raise SSRFError("credentials-in-URL are not allowed")
|
||||
host_is_loopback = host.lower() in LOOPBACK_HOSTS
|
||||
if not is_initial_url and host_is_loopback:
|
||||
raise SSRFError(f"redirect to loopback host {host!r} is not allowed")
|
||||
# IP-literal targets never go through DNS, so the connector's resolver can't
|
||||
# screen them — check them directly. The only blocked IP allowed through is
|
||||
# a loopback literal on the initial URL (handled by the exemption above).
|
||||
try:
|
||||
ipaddress.ip_address(host)
|
||||
except ValueError:
|
||||
is_ip_literal = False
|
||||
else:
|
||||
is_ip_literal = True
|
||||
if is_ip_literal and is_blocked_ip(host) and not (
|
||||
is_initial_url and host_is_loopback
|
||||
):
|
||||
raise SSRFError(f"redirect to blocked internal address {host!r}")
|
||||
return url
|
||||
@ -1,49 +0,0 @@
|
||||
"""Hub-checksum verification = SHA256.
|
||||
|
||||
Only used to confirm a download matches a *provided* ``expected_sha256``. It
|
||||
is NOT the dedup key (that is blake3, owned by the assets system). The full
|
||||
sequential read happens at most once, here, only when a checksum was supplied.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from typing import Callable, Optional
|
||||
|
||||
_CHUNK = 8 * 1024 * 1024
|
||||
|
||||
InterruptCheck = Callable[[], bool]
|
||||
|
||||
|
||||
class ChecksumError(Exception):
|
||||
"""The computed SHA256 did not match the expected value."""
|
||||
|
||||
|
||||
def sha256_file(path: str, interrupt_check: Optional[InterruptCheck] = None) -> Optional[str]:
|
||||
"""Stream the file and return its lowercase hex SHA256.
|
||||
|
||||
Returns ``None`` if interrupted via ``interrupt_check``.
|
||||
"""
|
||||
h = hashlib.sha256()
|
||||
with open(path, "rb") as f:
|
||||
while True:
|
||||
if interrupt_check is not None and interrupt_check():
|
||||
return None
|
||||
chunk = f.read(_CHUNK)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
def verify_sha256(
|
||||
path: str, expected: str, interrupt_check: Optional[InterruptCheck] = None
|
||||
) -> None:
|
||||
"""Raise :class:`ChecksumError` unless the file's SHA256 matches ``expected``."""
|
||||
actual = sha256_file(path, interrupt_check)
|
||||
if actual is None:
|
||||
return # interrupted; caller will re-verify on resume
|
||||
if actual.lower() != expected.lower():
|
||||
raise ChecksumError(
|
||||
f"sha256 mismatch: expected {expected.lower()}, got {actual.lower()}"
|
||||
)
|
||||
@ -1,53 +0,0 @@
|
||||
"""Dedup + catalog handoff — reuse the assets system.
|
||||
|
||||
We do NOT build a parallel indexer. "Do I already have it?" is answered by
|
||||
``resolve_existing`` (path) at enqueue time and, where a hash is known, by the
|
||||
assets blake3 catalog. After a completed download we register the file
|
||||
through the assets ingest path so it is cataloged and (eventually) hashed by
|
||||
the existing enrichment worker.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def _register_sync(abs_path: str) -> Optional[str]:
|
||||
"""Register a finished file into the assets catalog. Returns asset hash."""
|
||||
try:
|
||||
from app.assets.services.ingest import register_file_in_place
|
||||
except Exception as e: # assets package import failure — non-fatal
|
||||
logging.debug("[model_downloader] assets ingest unavailable: %s", e)
|
||||
return None
|
||||
try:
|
||||
result = register_file_in_place(abs_path, name=os.path.basename(abs_path), tags=[])
|
||||
return result.asset.hash if result and result.asset else None
|
||||
except Exception as e:
|
||||
# The file is already safely on disk; cataloging is best-effort.
|
||||
logging.warning(
|
||||
"[model_downloader] could not register %s into assets catalog: %s",
|
||||
abs_path, e,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def register_completed(abs_path: str) -> Optional[str]:
|
||||
"""Catalog a completed download via the assets system (off the event loop)."""
|
||||
return await asyncio.to_thread(_register_sync, abs_path)
|
||||
|
||||
|
||||
def _find_by_hash_sync(blake3_hex: str) -> Optional[str]:
|
||||
try:
|
||||
from app.assets.services.asset_management import get_asset_by_hash
|
||||
except Exception:
|
||||
return None
|
||||
asset = get_asset_by_hash("blake3:" + blake3_hex)
|
||||
return asset.hash if asset is not None else None
|
||||
|
||||
|
||||
async def find_existing_by_hash(blake3_hex: str) -> Optional[str]:
|
||||
"""Pure DB lookup — never triggers hashing on the hot path."""
|
||||
return await asyncio.to_thread(_find_by_hash_sync, blake3_hex)
|
||||
@ -1,86 +0,0 @@
|
||||
"""Cheap structural validation, no full read.
|
||||
|
||||
For ``.safetensors``/``.sft`` we parse the header (first few KB): it carries
|
||||
the tensor table and the byte length of the data region. We assert
|
||||
``file_size == 8 + header_len + data_region_len``. This detects truncation
|
||||
and most corruption for free, before any crypto hashing. Other extensions
|
||||
have no cheap structural check and pass through.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
from typing import Optional
|
||||
|
||||
_SAFETENSORS_EXTS = (".safetensors", ".sft")
|
||||
# A sane upper bound so a corrupt header length can't make us read gigabytes.
|
||||
_MAX_HEADER_BYTES = 100 * 1024 * 1024
|
||||
|
||||
|
||||
class StructuralError(Exception):
|
||||
"""The file failed its structural integrity check."""
|
||||
|
||||
|
||||
def validate(path: str, name_hint: Optional[str] = None) -> None:
|
||||
"""Validate the file at ``path``. Raises :class:`StructuralError` on failure.
|
||||
|
||||
The file format is detected from ``name_hint`` when provided, otherwise from
|
||||
``path``. Callers that download into a temp file with an opaque suffix (e.g.
|
||||
``*.comfy-download.part``) must pass the final destination name as
|
||||
``name_hint`` so the format check is not silently skipped.
|
||||
"""
|
||||
lower = (name_hint or path).lower()
|
||||
if lower.endswith(_SAFETENSORS_EXTS):
|
||||
_validate_safetensors(path)
|
||||
# No structural check for other formats; the size + (optional) checksum
|
||||
# gates in the engine cover those.
|
||||
|
||||
|
||||
def _validate_safetensors(path: str) -> None:
|
||||
file_size = os.path.getsize(path)
|
||||
if file_size < 8:
|
||||
raise StructuralError(f"file too small to be safetensors ({file_size} bytes)")
|
||||
with open(path, "rb") as f:
|
||||
header_len = struct.unpack("<Q", f.read(8))[0]
|
||||
if header_len <= 0 or header_len > _MAX_HEADER_BYTES:
|
||||
raise StructuralError(f"implausible safetensors header length {header_len}")
|
||||
if 8 + header_len > file_size:
|
||||
raise StructuralError("safetensors header extends past end of file")
|
||||
try:
|
||||
header = json.loads(f.read(header_len).decode("utf-8"))
|
||||
except (UnicodeDecodeError, json.JSONDecodeError) as e:
|
||||
raise StructuralError(f"safetensors header is not valid JSON: {e}") from e
|
||||
|
||||
if not isinstance(header, dict):
|
||||
raise StructuralError("safetensors header is not a JSON object")
|
||||
|
||||
data_len = 0
|
||||
for name, entry in header.items():
|
||||
if name == "__metadata__":
|
||||
continue
|
||||
if not isinstance(entry, dict) or "data_offsets" not in entry:
|
||||
raise StructuralError(f"tensor {name!r} missing data_offsets")
|
||||
offsets = entry["data_offsets"]
|
||||
if not (isinstance(offsets, list) and len(offsets) == 2):
|
||||
raise StructuralError(f"tensor {name!r} has malformed data_offsets")
|
||||
begin, end = offsets
|
||||
# bool is an int subclass; reject it explicitly to avoid True/False offsets.
|
||||
if (
|
||||
not isinstance(begin, int)
|
||||
or not isinstance(end, int)
|
||||
or isinstance(begin, bool)
|
||||
or isinstance(end, bool)
|
||||
or begin < 0
|
||||
or end < begin
|
||||
):
|
||||
raise StructuralError(f"tensor {name!r} has malformed data_offsets")
|
||||
data_len = max(data_len, end)
|
||||
|
||||
expected = 8 + header_len + data_len
|
||||
if file_size != expected:
|
||||
raise StructuralError(
|
||||
f"size mismatch: file is {file_size} bytes, header implies {expected} "
|
||||
f"(8 + {header_len} header + {data_len} data)"
|
||||
)
|
||||
@ -50,21 +50,45 @@ class ModelFileManager:
|
||||
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
|
||||
async def get_model_preview(request):
|
||||
folder_name = request.match_info.get("folder", None)
|
||||
path_index = int(request.match_info.get("path_index", None))
|
||||
filename = request.match_info.get("filename", None)
|
||||
|
||||
if folder_name not in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
|
||||
# The "{filename:.*}" capture also matches the empty string, which
|
||||
# would resolve to the folder itself; reject it explicitly.
|
||||
if not filename:
|
||||
return web.Response(status=400)
|
||||
|
||||
try:
|
||||
path_index = int(request.match_info.get("path_index", None))
|
||||
except (TypeError, ValueError):
|
||||
return web.Response(status=400)
|
||||
|
||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||
if path_index < 0 or path_index >= len(folders[0]):
|
||||
return web.Response(status=404)
|
||||
folder = folders[0][path_index]
|
||||
full_filename = os.path.join(folder, filename)
|
||||
full_filename = os.path.normpath(os.path.join(folder, filename))
|
||||
|
||||
# Prevent path traversal: the requested file must stay within the
|
||||
# configured model folder. `filename` is an unrestricted ".*" capture,
|
||||
# so values like "../../../../etc/passwd" would otherwise escape it.
|
||||
if not folder_paths.is_within_directory(folder, full_filename):
|
||||
return web.Response(status=403)
|
||||
|
||||
previews = self.get_model_previews(full_filename)
|
||||
default_preview = previews[0] if len(previews) > 0 else None
|
||||
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
|
||||
return web.Response(status=404)
|
||||
|
||||
# The preview is selected by a glob inside get_model_previews, so a
|
||||
# companion file (e.g. "model.preview.png") could itself be a symlink
|
||||
# resolving outside the model folder. Re-validate the file actually
|
||||
# opened: is_within_directory realpaths it, catching symlink escape.
|
||||
if isinstance(default_preview, str) and not folder_paths.is_within_directory(folder, default_preview):
|
||||
return web.Response(status=403)
|
||||
|
||||
try:
|
||||
with Image.open(default_preview) as img:
|
||||
img_bytes = BytesIO()
|
||||
|
||||
@ -6,6 +6,7 @@ import glob
|
||||
import shutil
|
||||
import logging
|
||||
import tempfile
|
||||
import mimetypes
|
||||
from aiohttp import web
|
||||
from urllib import parse
|
||||
from comfy.cli_args import args
|
||||
@ -336,7 +337,20 @@ class UserManager():
|
||||
if not isinstance(path, str):
|
||||
return path
|
||||
|
||||
return web.FileResponse(path)
|
||||
# User data files are arbitrary user-supplied content and are never
|
||||
# meant to render inline. Disable MIME sniffing and force a download
|
||||
# so uploaded markup/scripts can't execute in the app origin (stored
|
||||
# XSS). Content-Disposition: attachment is the load-bearing guard;
|
||||
# the content-type override and nosniff are defence in depth.
|
||||
content_type = mimetypes.guess_type(path)[0] or 'application/octet-stream'
|
||||
if folder_paths.is_dangerous_content_type(content_type):
|
||||
content_type = 'application/octet-stream'
|
||||
|
||||
return web.FileResponse(path, headers={
|
||||
"Content-Type": content_type,
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"Content-Disposition": "attachment",
|
||||
})
|
||||
|
||||
@routes.post("/userdata/{file}")
|
||||
async def post_userdata(request):
|
||||
|
||||
@ -33,28 +33,6 @@ class EnumAction(argparse.Action):
|
||||
setattr(namespace, self.dest, value)
|
||||
|
||||
|
||||
def _positive_int(value: str) -> int:
|
||||
"""argparse type that rejects zero and negative integers."""
|
||||
try:
|
||||
ivalue = int(value)
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(f"{value!r} is not an integer")
|
||||
if ivalue <= 0:
|
||||
raise argparse.ArgumentTypeError(f"{value!r} must be a positive integer (> 0)")
|
||||
return ivalue
|
||||
|
||||
|
||||
def _non_negative_int(value: str) -> int:
|
||||
"""argparse type that rejects negatives but allows zero (a disable sentinel)."""
|
||||
try:
|
||||
ivalue = int(value)
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(f"{value!r} is not an integer")
|
||||
if ivalue < 0:
|
||||
raise argparse.ArgumentTypeError(f"{value!r} must be a non-negative integer (>= 0)")
|
||||
return ivalue
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
|
||||
@ -262,18 +240,10 @@ database_default_path = os.path.abspath(
|
||||
)
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
|
||||
parser.add_argument("--enable-asset-hashing", action="store_true", help="Compute blake3 content hashes when scanning assets. Hashing enables future asset-portability features (deduplication, cross-machine model resolution) but adds startup cost and per-output cost on large models directories. Off by default; enable to opt in.")
|
||||
parser.add_argument("--feature-flag", type=str, action='append', default=[], metavar="KEY[=VALUE]", help="Set a server feature flag. Use KEY=VALUE to set an explicit value, or bare KEY to set it to true. Can be specified multiple times. Boolean values (true/false) and numbers are auto-converted. Examples: --feature-flag show_signin_button=true or --feature-flag show_signin_button")
|
||||
parser.add_argument("--list-feature-flags", action="store_true", help="Print the registry of known CLI-settable feature flags as JSON and exit.")
|
||||
|
||||
# ----- Model download manager (PRD: docs/prd-download-manager.md) -----
|
||||
parser.add_argument("--download-segments", type=_positive_int, default=8, metavar="N", help="Number of parallel HTTP range segments per file for the model download manager (default: 8).")
|
||||
parser.add_argument("--download-max-active", type=_positive_int, default=3, metavar="N", help="Maximum number of model downloads running concurrently (default: 3).")
|
||||
parser.add_argument("--download-max-connections-per-host", type=_positive_int, default=16, metavar="N", help="Maximum simultaneous connections to a single host for the download manager (default: 16).")
|
||||
parser.add_argument("--download-chunk-size", type=_positive_int, default=4 * 1024 * 1024, metavar="BYTES", help="Read chunk size in bytes for the download manager (default: 4 MiB).")
|
||||
parser.add_argument("--download-max-bytes", type=_non_negative_int, default=1024 * 1024 * 1024 * 1024, metavar="BYTES", help="Maximum size in bytes of a single download; aborts transfers that exceed it (guards against malicious/non-conforming hosts filling the disk). Set to 0 to disable (default: 1 TiB).")
|
||||
parser.add_argument("--download-allowed-hosts", type=str, nargs="*", default=[], metavar="HOST", help="Additional hostnames to add to the download manager allowlist (https only). The built-in defaults always include huggingface.co and civitai.com.")
|
||||
parser.add_argument("--download-allow-any-extension", action="store_true", help="Allow the download manager to fetch files with any extension (default: only known model extensions like .safetensors).")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
else:
|
||||
|
||||
@ -167,7 +167,7 @@ class Qwen3VLTokenizer(sd1_clip.SD1Tokenizer):
|
||||
embed_count = 0
|
||||
for r in tokens[key_name]:
|
||||
for i in range(len(r)):
|
||||
if r[i][0] == 151655: # <|image_pad|>
|
||||
if isinstance(r[i][0], (int, float)) and r[i][0] == 151655: # <|image_pad|>
|
||||
if len(images) > embed_count:
|
||||
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
|
||||
embed_count += 1
|
||||
|
||||
@ -104,7 +104,6 @@ _CORE_FEATURE_FLAGS: dict[str, Any] = {
|
||||
"extension": {"manager": {"supports_v4": True}},
|
||||
"node_replacements": True,
|
||||
"assets": args.enable_assets,
|
||||
"server_side_model_downloads": True,
|
||||
}
|
||||
|
||||
# CLI-provided flags cannot overwrite core flags
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -316,3 +316,36 @@ VIDEO_TASKS_EXECUTION_TIME = {
|
||||
"1080p": 150,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class SeedAudioConfig(BaseModel):
|
||||
format: str = Field(default="mp3")
|
||||
sample_rate: int = Field(default=24000)
|
||||
speech_rate: int = Field(default=0)
|
||||
loudness_rate: int = Field(default=0)
|
||||
pitch_rate: int = Field(default=0)
|
||||
|
||||
|
||||
class SeedAudioReference(BaseModel):
|
||||
speaker: str | None = Field(default=None)
|
||||
audio_data: str | None = Field(default=None)
|
||||
audio_url: str | None = Field(default=None)
|
||||
image_data: str | None = Field(default=None)
|
||||
image_url: str | None = Field(default=None)
|
||||
|
||||
|
||||
class SeedAudioRequest(BaseModel):
|
||||
model: str = Field(default="seed-audio-1.0")
|
||||
text_prompt: str = Field(...)
|
||||
references: list[SeedAudioReference] | None = Field(default=None)
|
||||
audio_config: SeedAudioConfig = Field(default_factory=SeedAudioConfig)
|
||||
watermark: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class SeedAudioResponse(BaseModel):
|
||||
audio: str | None = Field(default=None)
|
||||
url: str | None = Field(default=None)
|
||||
duration: float | None = Field(default=None)
|
||||
original_duration: float | None = Field(default=None)
|
||||
code: int | None = Field(default=None)
|
||||
message: str | None = Field(default=None)
|
||||
|
||||
@ -121,6 +121,7 @@ class GeminiGenerationConfig(BaseModel):
|
||||
topK: int | None = Field(None, ge=1)
|
||||
topP: float | None = Field(None, ge=0.0, le=1.0)
|
||||
thinkingConfig: GeminiThinkingConfig | None = Field(None)
|
||||
responseModalities: list[str] | None = Field(None)
|
||||
|
||||
|
||||
class GeminiImageOutputOptions(BaseModel):
|
||||
|
||||
@ -33,53 +33,6 @@ class IdeogramColorPalette(
|
||||
)
|
||||
|
||||
|
||||
class ImageRequest(BaseModel):
|
||||
aspect_ratio: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional. The aspect ratio (e.g., 'ASPECT_16_9', 'ASPECT_1_1'). Cannot be used with resolution. Defaults to 'ASPECT_1_1' if unspecified.",
|
||||
)
|
||||
color_palette: Optional[Dict[str, Any]] = Field(
|
||||
None, description='Optional. Color palette object. Only for V_2, V_2_TURBO.'
|
||||
)
|
||||
magic_prompt_option: Optional[str] = Field(
|
||||
None, description="Optional. MagicPrompt usage ('AUTO', 'ON', 'OFF')."
|
||||
)
|
||||
model: str = Field(..., description="The model used (e.g., 'V_2', 'V_2A_TURBO')")
|
||||
negative_prompt: Optional[str] = Field(
|
||||
None,
|
||||
description='Optional. Description of what to exclude. Only for V_1, V_1_TURBO, V_2, V_2_TURBO.',
|
||||
)
|
||||
num_images: Optional[int] = Field(
|
||||
1,
|
||||
description='Optional. Number of images to generate (1-8). Defaults to 1.',
|
||||
ge=1,
|
||||
le=8,
|
||||
)
|
||||
prompt: str = Field(
|
||||
..., description='Required. The prompt to use to generate the image.'
|
||||
)
|
||||
resolution: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional. Resolution (e.g., 'RESOLUTION_1024_1024'). Only for model V_2. Cannot be used with aspect_ratio.",
|
||||
)
|
||||
seed: Optional[int] = Field(
|
||||
None,
|
||||
description='Optional. A number between 0 and 2147483647.',
|
||||
ge=0,
|
||||
le=2147483647,
|
||||
)
|
||||
style_type: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional. Style type ('AUTO', 'GENERAL', 'REALISTIC', 'DESIGN', 'RENDER_3D', 'ANIME'). Only for models V_2 and above.",
|
||||
)
|
||||
|
||||
|
||||
class IdeogramGenerateRequest(BaseModel):
|
||||
image_request: ImageRequest = Field(
|
||||
..., description='The image generation request parameters.'
|
||||
)
|
||||
|
||||
|
||||
class Datum(BaseModel):
|
||||
is_image_safe: Optional[bool] = Field(
|
||||
None, description='Indicates whether the image is considered safe.'
|
||||
@ -113,20 +66,6 @@ class StyleCode(RootModel[str]):
|
||||
root: str = Field(..., pattern='^[0-9A-Fa-f]{8}$')
|
||||
|
||||
|
||||
class Datum1(BaseModel):
|
||||
is_image_safe: Optional[bool] = None
|
||||
prompt: Optional[str] = None
|
||||
resolution: Optional[str] = None
|
||||
seed: Optional[int] = None
|
||||
style_type: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
|
||||
|
||||
class IdeogramV3IdeogramResponse(BaseModel):
|
||||
created: Optional[datetime] = None
|
||||
data: Optional[List[Datum1]] = None
|
||||
|
||||
|
||||
class RenderingSpeed1(str, Enum):
|
||||
TURBO = 'TURBO'
|
||||
DEFAULT = 'DEFAULT'
|
||||
|
||||
@ -1,147 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, confloat
|
||||
|
||||
|
||||
class StabilityFormat(str, Enum):
|
||||
png = 'png'
|
||||
jpeg = 'jpeg'
|
||||
webp = 'webp'
|
||||
|
||||
|
||||
class StabilityAspectRatio(str, Enum):
|
||||
ratio_1_1 = "1:1"
|
||||
ratio_16_9 = "16:9"
|
||||
ratio_9_16 = "9:16"
|
||||
ratio_3_2 = "3:2"
|
||||
ratio_2_3 = "2:3"
|
||||
ratio_5_4 = "5:4"
|
||||
ratio_4_5 = "4:5"
|
||||
ratio_21_9 = "21:9"
|
||||
ratio_9_21 = "9:21"
|
||||
|
||||
|
||||
def get_stability_style_presets(include_none=True):
|
||||
presets = []
|
||||
if include_none:
|
||||
presets.append("None")
|
||||
return presets + [x.value for x in StabilityStylePreset]
|
||||
|
||||
|
||||
class StabilityStylePreset(str, Enum):
|
||||
_3d_model = "3d-model"
|
||||
analog_film = "analog-film"
|
||||
anime = "anime"
|
||||
cinematic = "cinematic"
|
||||
comic_book = "comic-book"
|
||||
digital_art = "digital-art"
|
||||
enhance = "enhance"
|
||||
fantasy_art = "fantasy-art"
|
||||
isometric = "isometric"
|
||||
line_art = "line-art"
|
||||
low_poly = "low-poly"
|
||||
modeling_compound = "modeling-compound"
|
||||
neon_punk = "neon-punk"
|
||||
origami = "origami"
|
||||
photographic = "photographic"
|
||||
pixel_art = "pixel-art"
|
||||
tile_texture = "tile-texture"
|
||||
|
||||
|
||||
class Stability_SD3_5_Model(str, Enum):
|
||||
sd3_5_large = "sd3.5-large"
|
||||
# sd3_5_large_turbo = "sd3.5-large-turbo"
|
||||
sd3_5_medium = "sd3.5-medium"
|
||||
|
||||
|
||||
class Stability_SD3_5_GenerationMode(str, Enum):
|
||||
text_to_image = "text-to-image"
|
||||
image_to_image = "image-to-image"
|
||||
|
||||
|
||||
class StabilityStable3_5Request(BaseModel):
|
||||
model: str = Field(...)
|
||||
mode: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
aspect_ratio: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||
image: Optional[str] = Field(None)
|
||||
style_preset: Optional[str] = Field(None)
|
||||
cfg_scale: float = Field(...)
|
||||
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None)
|
||||
|
||||
|
||||
class StabilityUpscaleConservativeRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||
image: Optional[str] = Field(None)
|
||||
creativity: Optional[confloat(ge=0.2, le=0.5)] = Field(None)
|
||||
|
||||
|
||||
class StabilityUpscaleCreativeRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||
image: Optional[str] = Field(None)
|
||||
creativity: Optional[confloat(ge=0.1, le=0.5)] = Field(None)
|
||||
style_preset: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class StabilityStableUltraRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
aspect_ratio: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||
image: Optional[str] = Field(None)
|
||||
style_preset: Optional[str] = Field(None)
|
||||
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None)
|
||||
|
||||
|
||||
class StabilityStableUltraResponse(BaseModel):
|
||||
image: Optional[str] = Field(None)
|
||||
finish_reason: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class StabilityResultsGetResponse(BaseModel):
|
||||
image: Optional[str] = Field(None)
|
||||
finish_reason: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
id: Optional[str] = Field(None)
|
||||
name: Optional[str] = Field(None)
|
||||
errors: Optional[list[str]] = Field(None)
|
||||
status: Optional[str] = Field(None)
|
||||
result: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class StabilityAsyncResponse(BaseModel):
|
||||
id: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class StabilityTextToAudioRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
duration: int = Field(190, ge=1, le=190)
|
||||
seed: int = Field(0, ge=0, le=4294967294)
|
||||
steps: int = Field(8, ge=4, le=8)
|
||||
output_format: str = Field("wav")
|
||||
|
||||
|
||||
class StabilityAudioToAudioRequest(StabilityTextToAudioRequest):
|
||||
strength: float = Field(0.01, ge=0.01, le=1.0)
|
||||
|
||||
|
||||
class StabilityAudioInpaintRequest(StabilityTextToAudioRequest):
|
||||
mask_start: int = Field(30, ge=0, le=190)
|
||||
mask_end: int = Field(190, ge=0, le=190)
|
||||
|
||||
|
||||
class StabilityAudioResponse(BaseModel):
|
||||
audio: Optional[str] = Field(None)
|
||||
@ -1,3 +1,4 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import math
|
||||
@ -20,6 +21,10 @@ from comfy_api_nodes.apis.bytedance import (
|
||||
GetAssetResponse,
|
||||
Image2VideoTaskCreationRequest,
|
||||
ImageTaskCreationResponse,
|
||||
SeedAudioConfig,
|
||||
SeedAudioReference,
|
||||
SeedAudioRequest,
|
||||
SeedAudioResponse,
|
||||
Seedance2TaskCreationRequest,
|
||||
SeedanceCreateAssetRequest,
|
||||
SeedanceCreateAssetResponse,
|
||||
@ -43,6 +48,8 @@ from comfy_api_nodes.apis.bytedance import (
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
audio_bytes_to_audio_input,
|
||||
audio_input_to_mp3,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
downscale_image_tensor_by_max_side,
|
||||
@ -51,11 +58,14 @@ from comfy_api_nodes.util import (
|
||||
image_tensor_pair_to_batch,
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
upload_audio_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
upscale_image_tensor_to_min_pixels,
|
||||
upscale_video_to_min_pixels,
|
||||
validate_audio_duration,
|
||||
validate_image_aspect_ratio,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
@ -2474,6 +2484,311 @@ class ByteDanceCreateVideoAsset(IO.ComfyNode):
|
||||
return IO.NodeOutput(asset_id, resolved_group)
|
||||
|
||||
|
||||
MODE_TEXT = "text only"
|
||||
MODE_AUDIO = "audio reference"
|
||||
MODE_IMAGE = "image reference"
|
||||
MODE_SPEAKER = "preset voice"
|
||||
|
||||
# (speaker_id, display_label) for built-in TTS 2.0 voices; resolvable ids are account-scoped.
|
||||
SEED_AUDIO_PRESET_VOICES: list[tuple[str, str]] = [
|
||||
("zh_female_vv_uranus_bigtts", "Vivi (Female, multilingual)"),
|
||||
("zh_female_xiaohe_uranus_bigtts", "Mindy (Female, multilingual)"),
|
||||
("en_female_stokie_uranus_bigtts", "Stokie (Female, English)"),
|
||||
("en_female_dacey_uranus_bigtts", "Dacey (Female, English)"),
|
||||
("en_male_tim_uranus_bigtts", "Tim (Male, English)"),
|
||||
("zh_male_m191_uranus_bigtts", "Kian (Male, multilingual)"),
|
||||
("zh_male_taocheng_uranus_bigtts", "Cedric (Male, multilingual)"),
|
||||
("zh_male_sophie_uranus_bigtts", "Sophie (Female, multilingual)"),
|
||||
("zh_female_yingyujiaoxue_uranus_bigtts", "Jean (Female, multilingual)"),
|
||||
("zh_male_dayi_uranus_bigtts", "Magnus (Male, multilingual)"),
|
||||
("zh_female_mizai_uranus_bigtts", "Mabel (Female, multilingual)"),
|
||||
("zh_female_jitangnv_uranus_bigtts", "Nadia (Female, multilingual)"),
|
||||
("zh_female_meilinvyou_uranus_bigtts", "Opal (Female, multilingual)"),
|
||||
("zh_female_liuchangnv_uranus_bigtts", "Pearl (Female, multilingual)"),
|
||||
("zh_male_ruyayichen_uranus_bigtts", "Quentin (Male, multilingual)"),
|
||||
("zh_female_vivo_uranus_bigtts", "Vienna (Female, multilingual)"),
|
||||
("zh_female_xiaoai_uranus_bigtts", "Alina (Female, multilingual)"),
|
||||
("zh_female_cancan_uranus_bigtts", "Corinne (Female, multilingual)"),
|
||||
("zh_female_tianmeixiaoyuan_uranus_bigtts", "Esther (Female, multilingual)"),
|
||||
("zh_female_tianmeitaozi_uranus_bigtts", "Freya (Female, multilingual)"),
|
||||
("zh_female_shuangkuaisisi_uranus_bigtts", "Gigi (Female, multilingual)"),
|
||||
("zh_female_peiqi_uranus_bigtts", "Holly (Female, multilingual)"),
|
||||
("zh_female_xiaoxue_uranus_bigtts", "Lyla (Female, multilingual)"),
|
||||
("zh_female_yuanqi_uranus_bigtts", "Daisy (Female, multilingual)"),
|
||||
("zh_female_kefunvsheng_uranus_bigtts", "Tracy (Female, multilingual)"),
|
||||
("zh_male_shaonianzixin_uranus_bigtts", "Jess (Male, multilingual)"),
|
||||
("zh_female_linjianvhai_uranus_bigtts", "Pinky (Female, multilingual)"),
|
||||
("zh_female_kiwi_uranus_bigtts", "Sweety (Female, multilingual)"),
|
||||
("zh_female_sajiaoxuemei_uranus_bigtts", "Sandy (Female, multilingual)"),
|
||||
("de_male_seven_uranus_bigtts", "Sven (Male, German)"),
|
||||
("jp_female_minimi_uranus_bigtts", "Minimi (Female, Japanese)"),
|
||||
("fr_male_usseau_uranus_bigtts", "Usseau (Male, French)"),
|
||||
("es_male_felipe_uranus_bigtts", "Felipe (Male, Spanish)"),
|
||||
("id_male_han_uranus_bigtts", "Han (Male, Indonesian)"),
|
||||
("pt_male_martins_uranus_bigtts", "Martins (Male, Portuguese)"),
|
||||
("it_male_enzo_uranus_bigtts", "Enzo (Male, Italian)"),
|
||||
("kr_male_shane_uranus_bigtts", "Shane (Male, Korean)"),
|
||||
("zh_male_liufei_uranus_bigtts", "Felix (Male, Chinese)"),
|
||||
("zh_female_qingxinnvsheng_uranus_bigtts", "Celeste (Female, Chinese)"),
|
||||
("zh_male_sunwukong_uranus_bigtts", "Monkey King (Male, Chinese)"),
|
||||
]
|
||||
SEED_AUDIO_VOICE_OPTIONS = [label for _, label in SEED_AUDIO_PRESET_VOICES]
|
||||
SEED_AUDIO_VOICE_MAP = {label: speaker_id for speaker_id, label in SEED_AUDIO_PRESET_VOICES}
|
||||
|
||||
_AUDIO_TAG_RE = re.compile(r"@Audio(\d+)", re.IGNORECASE)
|
||||
|
||||
|
||||
def max_audio_tag(prompt: str) -> int:
|
||||
"""Highest N referenced as @AudioN in the prompt (0 if none)."""
|
||||
nums = [int(m) for m in _AUDIO_TAG_RE.findall(prompt or "")]
|
||||
return max(nums) if nums else 0
|
||||
|
||||
|
||||
def connected_audio_indices(reference_mode: dict) -> list[int]:
|
||||
"""Indices (1-based) of connected reference_audio sockets, in order."""
|
||||
return [
|
||||
i
|
||||
for i in range(1, 3 + 1)
|
||||
if reference_mode.get(f"reference_audio_{i}") is not None
|
||||
]
|
||||
|
||||
|
||||
def validate_seed_audio_inputs(
|
||||
text_prompt: str,
|
||||
mode: str,
|
||||
audio_indices: list[int],
|
||||
has_image: bool,
|
||||
preset_voice: str | None = None,
|
||||
) -> None:
|
||||
validate_string(text_prompt, field_name="text_prompt", min_length=1, max_length=3000)
|
||||
max_tag = max_audio_tag(text_prompt)
|
||||
|
||||
if mode == MODE_TEXT:
|
||||
if max_tag:
|
||||
raise ValueError(
|
||||
f"The prompt references @Audio{max_tag}, but reference mode is '{MODE_TEXT}'. "
|
||||
f"Switch to '{MODE_AUDIO}' and connect the reference clip(s)."
|
||||
)
|
||||
elif mode == MODE_AUDIO:
|
||||
if not audio_indices:
|
||||
raise ValueError(
|
||||
f"Reference mode '{MODE_AUDIO}' requires at least one reference_audio input "
|
||||
f"(or switch to '{MODE_TEXT}')."
|
||||
)
|
||||
if audio_indices != list(range(1, len(audio_indices) + 1)):
|
||||
raise ValueError(
|
||||
"Connect reference_audio inputs in order without gaps: reference_audio_1, then _2, then _3."
|
||||
)
|
||||
if max_tag > len(audio_indices):
|
||||
raise ValueError(
|
||||
f"The prompt references @Audio{max_tag}, but only {len(audio_indices)} "
|
||||
f"reference audio(s) are connected."
|
||||
)
|
||||
elif mode == MODE_IMAGE:
|
||||
if not has_image:
|
||||
raise ValueError(f"Reference mode '{MODE_IMAGE}' requires a reference_image input.")
|
||||
if max_tag:
|
||||
raise ValueError(
|
||||
f"@AudioN tags are not used in '{MODE_IMAGE}' mode; the prompt should contain "
|
||||
f"only the text to synthesize."
|
||||
)
|
||||
elif mode == MODE_SPEAKER:
|
||||
if not preset_voice or preset_voice not in SEED_AUDIO_VOICE_MAP:
|
||||
raise ValueError(f"Reference mode '{MODE_SPEAKER}' requires selecting a preset voice.")
|
||||
if max_tag > 1:
|
||||
raise ValueError(
|
||||
f"'{MODE_SPEAKER}' mode uses a single voice, so @Audio{max_tag} is out of range. "
|
||||
f"Remove the @AudioN tags — the whole prompt is read in the selected voice."
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown reference mode: {mode!r}")
|
||||
|
||||
|
||||
class ByteDanceSeedAudioNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceSeedAudio",
|
||||
display_name="ByteDance Seed Audio 1.0",
|
||||
category="partner/audio/ByteDance",
|
||||
description=(
|
||||
"Generate speech, music, sound effects and multi-speaker dialogue from a single prompt "
|
||||
"with ByteDance Seed Audio 1.0. Describe the voice(s), emotion, ambience, background music "
|
||||
"and sound effects in the prompt, and include the lines to speak. Optionally pick a built-in "
|
||||
"preset voice, clone voices from up to 3 reference clips (tagged @Audio1-3 in the prompt), "
|
||||
"or derive a voice from a character image. Up to 2 minutes of audio per run."
|
||||
),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"text_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip=(
|
||||
"Describe the voice(s), emotion, pacing, ambience, background music and sound "
|
||||
"effects, and include the lines to speak (name characters inline for dialogue). "
|
||||
"In 'audio reference' mode, refer to connected clips by order as @Audio1, @Audio2, "
|
||||
"@Audio3. Maximum 3000 characters."
|
||||
),
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"reference_mode",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(MODE_TEXT, []),
|
||||
IO.DynamicCombo.Option(
|
||||
MODE_AUDIO,
|
||||
[
|
||||
IO.Audio.Input(
|
||||
"reference_audio_1",
|
||||
optional=True,
|
||||
tooltip="Reference clip for voice cloning, tagged @Audio1 in the prompt. "
|
||||
"Up to 30s.",
|
||||
),
|
||||
IO.Audio.Input(
|
||||
"reference_audio_2",
|
||||
optional=True,
|
||||
tooltip="Reference clip tagged @Audio2 in the prompt. Up to 30s.",
|
||||
),
|
||||
IO.Audio.Input(
|
||||
"reference_audio_3",
|
||||
optional=True,
|
||||
tooltip="Reference clip tagged @Audio3 in the prompt. Up to 30s.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
MODE_IMAGE,
|
||||
[
|
||||
IO.Image.Input(
|
||||
"reference_image",
|
||||
optional=True,
|
||||
tooltip="A single character image; the model derives a voice from it. "
|
||||
"Cannot be combined with reference audio.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
MODE_SPEAKER,
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"preset_voice",
|
||||
options=SEED_AUDIO_VOICE_OPTIONS,
|
||||
default=SEED_AUDIO_VOICE_OPTIONS[0],
|
||||
tooltip="A built-in TTS 2.0 voice that reads the prompt. No reference "
|
||||
"clip needed, and @AudioN tags are not used in this mode.",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip=(
|
||||
"How to condition the voice: 'text only' (describe everything in the prompt), "
|
||||
"'audio reference' (clone up to 3 voices, tagged @Audio1-3), 'image reference' "
|
||||
"(derive a voice from one character image), or 'preset voice' (pick a built-in "
|
||||
"named voice that reads the prompt)."
|
||||
),
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"sample_rate",
|
||||
options=["8000", "16000", "24000", "32000", "44100", "48000"],
|
||||
default="24000",
|
||||
tooltip="Output sample rate in Hz.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"speech_rate",
|
||||
default=0,
|
||||
min=-50,
|
||||
max=100,
|
||||
tooltip="Speaking speed. 0 = normal, 100 = 2.0x, -50 = 0.5x.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"loudness_rate",
|
||||
default=0,
|
||||
min=-50,
|
||||
max=100,
|
||||
tooltip="Loudness. 0 = normal, 100 = 2.0x, -50 = 0.5x.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"pitch_rate",
|
||||
default=0,
|
||||
min=-12,
|
||||
max=12,
|
||||
tooltip="Pitch shift in semitones (-12 to 12).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=42,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd": 0.2145, "format":{"suffix":"/minute","approximate":true}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
text_prompt: str,
|
||||
reference_mode: dict,
|
||||
sample_rate: str,
|
||||
speech_rate: int,
|
||||
loudness_rate: int,
|
||||
pitch_rate: int,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
mode = reference_mode["reference_mode"]
|
||||
audio_indices = connected_audio_indices(reference_mode)
|
||||
image = reference_mode.get("reference_image")
|
||||
preset_voice = reference_mode.get("preset_voice")
|
||||
validate_seed_audio_inputs(text_prompt, mode, audio_indices, image is not None, preset_voice)
|
||||
|
||||
references: list[SeedAudioReference] | None = None
|
||||
if mode == MODE_AUDIO:
|
||||
references = []
|
||||
for i in audio_indices:
|
||||
clip = reference_mode[f"reference_audio_{i}"]
|
||||
validate_audio_duration(clip, max_duration=30.0)
|
||||
mp3_bytes = audio_input_to_mp3(clip).getvalue()
|
||||
references.append(SeedAudioReference(audio_data=base64.b64encode(mp3_bytes).decode("utf-8")))
|
||||
elif mode == MODE_IMAGE:
|
||||
image = upscale_image_tensor_to_min_pixels(image, 160_000)
|
||||
references = [SeedAudioReference(image_data=tensor_to_base64_string(image, mime_type="image/png"))]
|
||||
elif mode == MODE_SPEAKER:
|
||||
references = [SeedAudioReference(speaker=SEED_AUDIO_VOICE_MAP[preset_voice])]
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/byteplus/api/v3/tts/create", method="POST"),
|
||||
response_model=SeedAudioResponse,
|
||||
data=SeedAudioRequest(
|
||||
text_prompt=text_prompt,
|
||||
references=references,
|
||||
audio_config=SeedAudioConfig(
|
||||
sample_rate=int(sample_rate),
|
||||
speech_rate=speech_rate,
|
||||
loudness_rate=loudness_rate,
|
||||
pitch_rate=pitch_rate,
|
||||
),
|
||||
),
|
||||
)
|
||||
if not response.audio:
|
||||
raise Exception(
|
||||
f"Seed Audio returned no audio (code={response.code}): {response.message}"
|
||||
)
|
||||
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response.audio)))
|
||||
|
||||
|
||||
class ByteDanceExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -2490,6 +2805,7 @@ class ByteDanceExtension(ComfyExtension):
|
||||
ByteDance2ReferenceNode,
|
||||
ByteDanceCreateImageAsset,
|
||||
ByteDanceCreateVideoAsset,
|
||||
ByteDanceSeedAudioNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import folder_paths
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl, Types
|
||||
from comfy_api_nodes.apis.gemini import (
|
||||
GeminiContent,
|
||||
GeminiFileData,
|
||||
@ -37,6 +37,7 @@ from comfy_api_nodes.util import (
|
||||
audio_to_base64_string,
|
||||
bytesio_to_image_tensor,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
get_number_of_images,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
@ -45,6 +46,7 @@ from comfy_api_nodes.util import (
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_string,
|
||||
validate_video_duration,
|
||||
video_to_base64_string,
|
||||
)
|
||||
|
||||
@ -229,10 +231,29 @@ async def get_image_from_response(response: GeminiGenerateContentResponse, thoug
|
||||
return torch.cat(image_tensors, dim=0)
|
||||
|
||||
|
||||
async def get_video_from_response(
|
||||
response: GeminiGenerateContentResponse, cls: type[IO.ComfyNode] | None = None
|
||||
) -> InputImpl.VideoFromFile:
|
||||
parts = get_parts_by_type(response, "video/*")
|
||||
for part in parts:
|
||||
if part.inlineData and part.inlineData.data:
|
||||
return InputImpl.VideoFromFile(BytesIO(base64.b64decode(part.inlineData.data)))
|
||||
if part.fileData and part.fileData.fileUri:
|
||||
return await download_url_to_video_output(part.fileData.fileUri, cls=cls)
|
||||
model_message = get_text_from_response(response).strip()
|
||||
if model_message:
|
||||
raise ValueError(f"Gemini did not generate a video. Model response: {model_message}")
|
||||
raise ValueError(
|
||||
"Gemini did not generate a video. Try rephrasing your prompt, "
|
||||
"shortening the requested duration, or reducing the number of input images/videos."
|
||||
)
|
||||
|
||||
|
||||
def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | None:
|
||||
if not response.modelVersion:
|
||||
return None
|
||||
# Define prices (Cost per 1,000,000 tokens), see https://cloud.google.com/vertex-ai/generative-ai/pricing
|
||||
output_video_tokens_price = 0.0
|
||||
if response.modelVersion == "gemini-2.5-pro":
|
||||
input_tokens_price = 1.25
|
||||
output_text_tokens_price = 10.0
|
||||
@ -265,6 +286,11 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
|
||||
input_tokens_price = 0.25
|
||||
output_text_tokens_price = 1.50
|
||||
output_image_tokens_price = 30.0
|
||||
elif response.modelVersion == "gemini-omni-flash-preview":
|
||||
input_tokens_price = 2.145
|
||||
output_text_tokens_price = 12.87
|
||||
output_image_tokens_price = 0.0
|
||||
output_video_tokens_price = 25.025
|
||||
else:
|
||||
return None
|
||||
final_price = response.usageMetadata.promptTokenCount * input_tokens_price
|
||||
@ -272,6 +298,8 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
|
||||
for i in response.usageMetadata.candidatesTokensDetails:
|
||||
if i.modality == Modality.IMAGE:
|
||||
final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models
|
||||
elif i.modality == Modality.VIDEO:
|
||||
final_price += output_video_tokens_price * i.tokenCount # for Omni Flash
|
||||
else:
|
||||
final_price += output_text_tokens_price * i.tokenCount
|
||||
if response.usageMetadata.thoughtsTokenCount:
|
||||
@ -1531,6 +1559,149 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
OMNI_MAX_IMAGES = 14
|
||||
OMNI_MAX_VIDEOS = 3
|
||||
|
||||
OMNI_MODELS: dict[str, str] = {
|
||||
"Omni Flash": "gemini-omni-flash-preview",
|
||||
}
|
||||
|
||||
|
||||
def _omni_flash_inputs() -> list[Input]:
|
||||
"""Per-model inputs for the Omni video DynamicCombo (prompt + reference media + sampling)."""
|
||||
return [
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Describe the video to generate. Specify the length and aspect ratio directly in the "
|
||||
'prompt, e.g. "a 6-second clip in 16:9". Length may be 3-10 seconds; the aspect ratio must be '
|
||||
"16:9 (landscape) or 9:16 (portrait). The output is 720p, 24 FPS, with audio.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("image"),
|
||||
names=[f"image_{i}" for i in range(1, OMNI_MAX_IMAGES + 1)],
|
||||
min=0,
|
||||
),
|
||||
tooltip=f"Optional reference image(s) to guide or animate the video. Up to {OMNI_MAX_IMAGES} images.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"videos",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Video.Input("video"),
|
||||
names=[f"video_{i}" for i in range(1, OMNI_MAX_VIDEOS + 1)],
|
||||
min=0,
|
||||
),
|
||||
tooltip=f"Optional reference video(s) to guide or edit. Up to {OMNI_MAX_VIDEOS} videos, "
|
||||
f"each up to 10 seconds long.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"temperature",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=2.0,
|
||||
step=0.01,
|
||||
tooltip="Controls randomness. Lower is more focused/deterministic, higher is more varied.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"top_p",
|
||||
default=0.95,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip="Nucleus sampling: sample from the smallest token set whose cumulative probability reaches top_p.",
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class GeminiVideoOmni(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GeminiVideoOmni",
|
||||
display_name="Google Gemini Omni (Video)",
|
||||
category="partner/video/Gemini",
|
||||
essentials_category="Video Generation",
|
||||
description="Generate a video with audio from a text prompt using Google's Gemini Omni Flash model. "
|
||||
"Optionally provide reference images and/or videos to guide or edit the result. Describe the desired "
|
||||
"length (3-10s) and aspect ratio (16:9 or 9:16) directly in the prompt.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Omni Flash", _omni_flash_inputs()),
|
||||
],
|
||||
tooltip="The Gemini video model used to generate the video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=42,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
IO.String.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr='{"type":"usd","usd":0.146,"format":{"suffix":"/second","approximate":true}}'
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(cls, model: dict, seed: int) -> IO.NodeOutput:
|
||||
prompt = model.get("prompt") or ""
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
model_id = OMNI_MODELS[model["model"]]
|
||||
|
||||
images = [t for t in (model.get("images") or {}).values() if t is not None]
|
||||
videos = [v for v in (model.get("videos") or {}).values() if v is not None]
|
||||
if sum(get_number_of_images(t) for t in images) > OMNI_MAX_IMAGES:
|
||||
raise ValueError(f"The current maximum number of supported images is {OMNI_MAX_IMAGES}.")
|
||||
if len(videos) > OMNI_MAX_VIDEOS:
|
||||
raise ValueError(f"The current maximum number of supported videos is {OMNI_MAX_VIDEOS}.")
|
||||
for video in videos:
|
||||
validate_video_duration(video, max_duration=10)
|
||||
|
||||
parts: list[GeminiPart] = []
|
||||
if images or videos:
|
||||
parts.extend(await build_gemini_media_parts(cls, images, [], videos))
|
||||
parts.append(GeminiPart(text=prompt))
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model_id}", method="POST"),
|
||||
data=GeminiGenerateContentRequest(
|
||||
contents=[GeminiContent(role=GeminiRole.user, parts=parts)],
|
||||
generationConfig=GeminiGenerationConfig(
|
||||
responseModalities=["TEXT", "VIDEO"],
|
||||
temperature=model.get("temperature", 1.0),
|
||||
topP=model.get("top_p", 0.95),
|
||||
),
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
await get_video_from_response(response, cls=cls),
|
||||
get_text_from_response(response),
|
||||
)
|
||||
|
||||
|
||||
class GeminiExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -1541,6 +1712,7 @@ class GeminiExtension(ComfyExtension):
|
||||
GeminiImage2,
|
||||
GeminiNanoBanana2,
|
||||
GeminiNanoBanana2V2,
|
||||
GeminiVideoOmni,
|
||||
GeminiInputFiles,
|
||||
]
|
||||
|
||||
|
||||
@ -5,9 +5,7 @@ from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
from comfy_api_nodes.apis.ideogram import (
|
||||
IdeogramGenerateRequest,
|
||||
IdeogramGenerateResponse,
|
||||
ImageRequest,
|
||||
IdeogramV3Request,
|
||||
IdeogramV3EditRequest,
|
||||
IdeogramV4Request,
|
||||
@ -21,101 +19,6 @@ from comfy_api_nodes.util import (
|
||||
validate_string,
|
||||
)
|
||||
|
||||
V1_V1_RES_MAP = {
|
||||
"Auto":"AUTO",
|
||||
"512 x 1536":"RESOLUTION_512_1536",
|
||||
"576 x 1408":"RESOLUTION_576_1408",
|
||||
"576 x 1472":"RESOLUTION_576_1472",
|
||||
"576 x 1536":"RESOLUTION_576_1536",
|
||||
"640 x 1024":"RESOLUTION_640_1024",
|
||||
"640 x 1344":"RESOLUTION_640_1344",
|
||||
"640 x 1408":"RESOLUTION_640_1408",
|
||||
"640 x 1472":"RESOLUTION_640_1472",
|
||||
"640 x 1536":"RESOLUTION_640_1536",
|
||||
"704 x 1152":"RESOLUTION_704_1152",
|
||||
"704 x 1216":"RESOLUTION_704_1216",
|
||||
"704 x 1280":"RESOLUTION_704_1280",
|
||||
"704 x 1344":"RESOLUTION_704_1344",
|
||||
"704 x 1408":"RESOLUTION_704_1408",
|
||||
"704 x 1472":"RESOLUTION_704_1472",
|
||||
"720 x 1280":"RESOLUTION_720_1280",
|
||||
"736 x 1312":"RESOLUTION_736_1312",
|
||||
"768 x 1024":"RESOLUTION_768_1024",
|
||||
"768 x 1088":"RESOLUTION_768_1088",
|
||||
"768 x 1152":"RESOLUTION_768_1152",
|
||||
"768 x 1216":"RESOLUTION_768_1216",
|
||||
"768 x 1232":"RESOLUTION_768_1232",
|
||||
"768 x 1280":"RESOLUTION_768_1280",
|
||||
"768 x 1344":"RESOLUTION_768_1344",
|
||||
"832 x 960":"RESOLUTION_832_960",
|
||||
"832 x 1024":"RESOLUTION_832_1024",
|
||||
"832 x 1088":"RESOLUTION_832_1088",
|
||||
"832 x 1152":"RESOLUTION_832_1152",
|
||||
"832 x 1216":"RESOLUTION_832_1216",
|
||||
"832 x 1248":"RESOLUTION_832_1248",
|
||||
"864 x 1152":"RESOLUTION_864_1152",
|
||||
"896 x 960":"RESOLUTION_896_960",
|
||||
"896 x 1024":"RESOLUTION_896_1024",
|
||||
"896 x 1088":"RESOLUTION_896_1088",
|
||||
"896 x 1120":"RESOLUTION_896_1120",
|
||||
"896 x 1152":"RESOLUTION_896_1152",
|
||||
"960 x 832":"RESOLUTION_960_832",
|
||||
"960 x 896":"RESOLUTION_960_896",
|
||||
"960 x 1024":"RESOLUTION_960_1024",
|
||||
"960 x 1088":"RESOLUTION_960_1088",
|
||||
"1024 x 640":"RESOLUTION_1024_640",
|
||||
"1024 x 768":"RESOLUTION_1024_768",
|
||||
"1024 x 832":"RESOLUTION_1024_832",
|
||||
"1024 x 896":"RESOLUTION_1024_896",
|
||||
"1024 x 960":"RESOLUTION_1024_960",
|
||||
"1024 x 1024":"RESOLUTION_1024_1024",
|
||||
"1088 x 768":"RESOLUTION_1088_768",
|
||||
"1088 x 832":"RESOLUTION_1088_832",
|
||||
"1088 x 896":"RESOLUTION_1088_896",
|
||||
"1088 x 960":"RESOLUTION_1088_960",
|
||||
"1120 x 896":"RESOLUTION_1120_896",
|
||||
"1152 x 704":"RESOLUTION_1152_704",
|
||||
"1152 x 768":"RESOLUTION_1152_768",
|
||||
"1152 x 832":"RESOLUTION_1152_832",
|
||||
"1152 x 864":"RESOLUTION_1152_864",
|
||||
"1152 x 896":"RESOLUTION_1152_896",
|
||||
"1216 x 704":"RESOLUTION_1216_704",
|
||||
"1216 x 768":"RESOLUTION_1216_768",
|
||||
"1216 x 832":"RESOLUTION_1216_832",
|
||||
"1232 x 768":"RESOLUTION_1232_768",
|
||||
"1248 x 832":"RESOLUTION_1248_832",
|
||||
"1280 x 704":"RESOLUTION_1280_704",
|
||||
"1280 x 720":"RESOLUTION_1280_720",
|
||||
"1280 x 768":"RESOLUTION_1280_768",
|
||||
"1280 x 800":"RESOLUTION_1280_800",
|
||||
"1312 x 736":"RESOLUTION_1312_736",
|
||||
"1344 x 640":"RESOLUTION_1344_640",
|
||||
"1344 x 704":"RESOLUTION_1344_704",
|
||||
"1344 x 768":"RESOLUTION_1344_768",
|
||||
"1408 x 576":"RESOLUTION_1408_576",
|
||||
"1408 x 640":"RESOLUTION_1408_640",
|
||||
"1408 x 704":"RESOLUTION_1408_704",
|
||||
"1472 x 576":"RESOLUTION_1472_576",
|
||||
"1472 x 640":"RESOLUTION_1472_640",
|
||||
"1472 x 704":"RESOLUTION_1472_704",
|
||||
"1536 x 512":"RESOLUTION_1536_512",
|
||||
"1536 x 576":"RESOLUTION_1536_576",
|
||||
"1536 x 640":"RESOLUTION_1536_640",
|
||||
}
|
||||
|
||||
V1_V2_RATIO_MAP = {
|
||||
"1:1":"ASPECT_1_1",
|
||||
"4:3":"ASPECT_4_3",
|
||||
"3:4":"ASPECT_3_4",
|
||||
"16:9":"ASPECT_16_9",
|
||||
"9:16":"ASPECT_9_16",
|
||||
"2:1":"ASPECT_2_1",
|
||||
"1:2":"ASPECT_1_2",
|
||||
"3:2":"ASPECT_3_2",
|
||||
"2:3":"ASPECT_2_3",
|
||||
"4:5":"ASPECT_4_5",
|
||||
"5:4":"ASPECT_5_4",
|
||||
}
|
||||
|
||||
V3_RATIO_MAP = {
|
||||
"1:3":"1x3",
|
||||
@ -229,298 +132,6 @@ async def download_and_process_images(image_urls):
|
||||
return stacked_tensors
|
||||
|
||||
|
||||
class IdeogramV1(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="IdeogramV1",
|
||||
display_name="Ideogram V1",
|
||||
category="partner/image/Ideogram",
|
||||
description="Generates images using the Ideogram V1 model.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"turbo",
|
||||
default=False,
|
||||
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=list(V1_V2_RATIO_MAP.keys()),
|
||||
default="1:1",
|
||||
tooltip="The aspect ratio for image generation.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"magic_prompt_option",
|
||||
options=["AUTO", "ON", "OFF"],
|
||||
default="AUTO",
|
||||
tooltip="Determine if MagicPrompt should be used in generation",
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Description of what to exclude from the image",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"num_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]),
|
||||
expr="""
|
||||
(
|
||||
$n := widgets.num_images;
|
||||
$base := (widgets.turbo = true) ? 0.0286 : 0.0858;
|
||||
{"type":"usd","usd": $round($base * $n, 2)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt,
|
||||
turbo=False,
|
||||
aspect_ratio="1:1",
|
||||
magic_prompt_option="AUTO",
|
||||
seed=0,
|
||||
negative_prompt="",
|
||||
num_images=1,
|
||||
):
|
||||
# Determine the model based on turbo setting
|
||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||
model = "V_1_TURBO" if turbo else "V_1"
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
|
||||
response_model=IdeogramGenerateResponse,
|
||||
data=IdeogramGenerateRequest(
|
||||
image_request=ImageRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
num_images=num_images,
|
||||
seed=seed,
|
||||
aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
|
||||
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
)
|
||||
),
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
|
||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
return IO.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
|
||||
class IdeogramV2(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="IdeogramV2",
|
||||
display_name="Ideogram V2",
|
||||
category="partner/image/Ideogram",
|
||||
description="Generates images using the Ideogram V2 model.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"turbo",
|
||||
default=False,
|
||||
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=list(V1_V2_RATIO_MAP.keys()),
|
||||
default="1:1",
|
||||
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=list(V1_V1_RES_MAP.keys()),
|
||||
default="Auto",
|
||||
tooltip="The resolution for image generation. "
|
||||
"If not set to AUTO, this overrides the aspect_ratio setting.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"magic_prompt_option",
|
||||
options=["AUTO", "ON", "OFF"],
|
||||
default="AUTO",
|
||||
tooltip="Determine if MagicPrompt should be used in generation",
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"style_type",
|
||||
options=["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
|
||||
default="NONE",
|
||||
tooltip="Style type for generation (V2 only)",
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Description of what to exclude from the image",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"num_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
#"color_palette": (
|
||||
# IO.STRING,
|
||||
# {
|
||||
# "multiline": False,
|
||||
# "default": "",
|
||||
# "tooltip": "Color palette preset name or hex colors with weights",
|
||||
# },
|
||||
#),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]),
|
||||
expr="""
|
||||
(
|
||||
$n := widgets.num_images;
|
||||
$base := (widgets.turbo = true) ? 0.0715 : 0.1144;
|
||||
{"type":"usd","usd": $round($base * $n, 2)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt,
|
||||
turbo=False,
|
||||
aspect_ratio="1:1",
|
||||
resolution="Auto",
|
||||
magic_prompt_option="AUTO",
|
||||
seed=0,
|
||||
style_type="NONE",
|
||||
negative_prompt="",
|
||||
num_images=1,
|
||||
color_palette="",
|
||||
):
|
||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||
resolution = V1_V1_RES_MAP.get(resolution, None)
|
||||
# Determine the model based on turbo setting
|
||||
model = "V_2_TURBO" if turbo else "V_2"
|
||||
|
||||
# Handle resolution vs aspect_ratio logic
|
||||
# If resolution is not AUTO, it overrides aspect_ratio
|
||||
final_resolution = None
|
||||
final_aspect_ratio = None
|
||||
|
||||
if resolution != "AUTO":
|
||||
final_resolution = resolution
|
||||
else:
|
||||
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
|
||||
response_model=IdeogramGenerateResponse,
|
||||
data=IdeogramGenerateRequest(
|
||||
image_request=ImageRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
num_images=num_images,
|
||||
seed=seed,
|
||||
aspect_ratio=final_aspect_ratio,
|
||||
resolution=final_resolution,
|
||||
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
|
||||
style_type=style_type if style_type != "NONE" else None,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
color_palette=color_palette if color_palette else None,
|
||||
)
|
||||
),
|
||||
max_retries=1,
|
||||
)
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
|
||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
return IO.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
|
||||
class IdeogramV3(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@ -917,8 +528,6 @@ class IdeogramExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
IdeogramV1,
|
||||
IdeogramV2,
|
||||
IdeogramV3,
|
||||
IdeogramV4,
|
||||
]
|
||||
|
||||
@ -1,932 +0,0 @@
|
||||
from inspect import cleandoc
|
||||
from typing import Optional
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, Input, IO
|
||||
from comfy_api_nodes.apis.stability import (
|
||||
StabilityUpscaleConservativeRequest,
|
||||
StabilityUpscaleCreativeRequest,
|
||||
StabilityAsyncResponse,
|
||||
StabilityResultsGetResponse,
|
||||
StabilityStable3_5Request,
|
||||
StabilityStableUltraRequest,
|
||||
StabilityStableUltraResponse,
|
||||
StabilityAspectRatio,
|
||||
Stability_SD3_5_Model,
|
||||
Stability_SD3_5_GenerationMode,
|
||||
get_stability_style_presets,
|
||||
StabilityTextToAudioRequest,
|
||||
StabilityAudioToAudioRequest,
|
||||
StabilityAudioInpaintRequest,
|
||||
StabilityAudioResponse,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
validate_audio_duration,
|
||||
validate_string,
|
||||
audio_input_to_mp3,
|
||||
bytesio_to_image_tensor,
|
||||
tensor_to_bytesio,
|
||||
audio_bytes_to_audio_input,
|
||||
sync_op,
|
||||
poll_op,
|
||||
ApiEndpoint,
|
||||
)
|
||||
|
||||
import torch
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class StabilityPollStatus(str, Enum):
|
||||
finished = "finished"
|
||||
in_progress = "in_progress"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
def get_async_dummy_status(x: StabilityResultsGetResponse):
|
||||
if x.name is not None or x.errors is not None:
|
||||
return StabilityPollStatus.failed
|
||||
elif x.finish_reason is not None:
|
||||
return StabilityPollStatus.finished
|
||||
return StabilityPollStatus.in_progress
|
||||
|
||||
|
||||
class StabilityStableImageUltraNode(IO.ComfyNode):
|
||||
"""
|
||||
Generates images synchronously based on prompt and resolution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityStableImageUltraNode",
|
||||
display_name="Stability AI Stable Image Ultra",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
|
||||
"elements, colors, and subjects will lead to better results. " +
|
||||
"To control the weight of a given word use the format `(word:weight)`," +
|
||||
"where `word` is the word you'd like to control the weight of and `weight`" +
|
||||
"is a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`" +
|
||||
"would convey a sky that was blue and green, but more green than blue.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=StabilityAspectRatio,
|
||||
default=StabilityAspectRatio.ratio_1_1,
|
||||
tooltip="Aspect ratio of generated image.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"style_preset",
|
||||
options=get_stability_style_presets(),
|
||||
tooltip="Optional desired style of generated image.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
tooltip="A blurb of text describing what you do not wish to see in the output image. This is an advanced feature.",
|
||||
force_input=True,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"image_denoise",
|
||||
default=0.5,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.08}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
style_preset: str,
|
||||
seed: int,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
negative_prompt: str = "",
|
||||
image_denoise: Optional[float] = 0.5,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
# prepare image binary if image present
|
||||
image_binary = None
|
||||
if image is not None:
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
|
||||
else:
|
||||
image_denoise = None
|
||||
|
||||
if not negative_prompt:
|
||||
negative_prompt = None
|
||||
if style_preset == "None":
|
||||
style_preset = None
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"),
|
||||
response_model=StabilityStableUltraResponse,
|
||||
data=StabilityStableUltraRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
seed=seed,
|
||||
strength=image_denoise,
|
||||
style_preset=style_preset,
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return IO.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityStableImageSD_3_5Node(IO.ComfyNode):
|
||||
"""
|
||||
Generates images synchronously based on prompt and resolution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityStableImageSD_3_5Node",
|
||||
display_name="Stability AI Stable Diffusion 3.5 Image",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=Stability_SD3_5_Model,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=StabilityAspectRatio,
|
||||
default=StabilityAspectRatio.ratio_1_1,
|
||||
tooltip="Aspect ratio of generated image.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"style_preset",
|
||||
options=get_stability_style_presets(),
|
||||
tooltip="Optional desired style of generated image.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"cfg_scale",
|
||||
default=4.0,
|
||||
min=1.0,
|
||||
max=10.0,
|
||||
step=0.1,
|
||||
tooltip="How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
|
||||
force_input=True,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"image_denoise",
|
||||
default=0.5,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||
expr="""
|
||||
(
|
||||
$contains(widgets.model,"large")
|
||||
? {"type":"usd","usd":0.065}
|
||||
: {"type":"usd","usd":0.035}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
style_preset: str,
|
||||
seed: int,
|
||||
cfg_scale: float,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
negative_prompt: str = "",
|
||||
image_denoise: Optional[float] = 0.5,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
# prepare image binary if image present
|
||||
image_binary = None
|
||||
mode = Stability_SD3_5_GenerationMode.text_to_image
|
||||
if image is not None:
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
|
||||
mode = Stability_SD3_5_GenerationMode.image_to_image
|
||||
aspect_ratio = None
|
||||
else:
|
||||
image_denoise = None
|
||||
|
||||
if not negative_prompt:
|
||||
negative_prompt = None
|
||||
if style_preset == "None":
|
||||
style_preset = None
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"),
|
||||
response_model=StabilityStableUltraResponse,
|
||||
data=StabilityStable3_5Request(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
seed=seed,
|
||||
strength=image_denoise,
|
||||
style_preset=style_preset,
|
||||
cfg_scale=cfg_scale,
|
||||
model=model,
|
||||
mode=mode,
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return IO.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityUpscaleConservativeNode(IO.ComfyNode):
|
||||
"""
|
||||
Upscale image with minimal alterations to 4K resolution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityUpscaleConservativeNode",
|
||||
display_name="Stability AI Upscale Conservative",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"creativity",
|
||||
default=0.35,
|
||||
min=0.2,
|
||||
max=0.5,
|
||||
step=0.01,
|
||||
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
|
||||
force_input=True,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.4}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
creativity: float,
|
||||
seed: int,
|
||||
negative_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||
|
||||
if not negative_prompt:
|
||||
negative_prompt = None
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"),
|
||||
response_model=StabilityStableUltraResponse,
|
||||
data=StabilityUpscaleConservativeRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
creativity=round(creativity,2),
|
||||
seed=seed,
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return IO.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityUpscaleCreativeNode(IO.ComfyNode):
|
||||
"""
|
||||
Upscale image with minimal alterations to 4K resolution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityUpscaleCreativeNode",
|
||||
display_name="Stability AI Upscale Creative",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"creativity",
|
||||
default=0.3,
|
||||
min=0.1,
|
||||
max=0.5,
|
||||
step=0.01,
|
||||
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"style_preset",
|
||||
options=get_stability_style_presets(),
|
||||
tooltip="Optional desired style of generated image.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
default="",
|
||||
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
|
||||
force_input=True,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.6}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
creativity: float,
|
||||
style_preset: str,
|
||||
seed: int,
|
||||
negative_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||
|
||||
if not negative_prompt:
|
||||
negative_prompt = None
|
||||
if style_preset == "None":
|
||||
style_preset = None
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"),
|
||||
response_model=StabilityAsyncResponse,
|
||||
data=StabilityUpscaleCreativeRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
creativity=round(creativity,2),
|
||||
style_preset=style_preset,
|
||||
seed=seed,
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"),
|
||||
response_model=StabilityResultsGetResponse,
|
||||
poll_interval=3,
|
||||
status_extractor=lambda x: get_async_dummy_status(x),
|
||||
)
|
||||
|
||||
if response_poll.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_poll.result)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return IO.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityUpscaleFastNode(IO.ComfyNode):
|
||||
"""
|
||||
Quickly upscales an image via Stability API call to 4x its original size; intended for upscaling low-quality/compressed images.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityUpscaleFastNode",
|
||||
display_name="Stability AI Upscale Fast",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.02}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(cls, image: torch.Tensor) -> IO.NodeOutput:
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"),
|
||||
response_model=StabilityStableUltraResponse,
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return IO.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityTextToAudio(IO.ComfyNode):
|
||||
"""Generates high-quality music and sound effects from text descriptions."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityTextToAudio",
|
||||
display_name="Stability AI Text To Audio",
|
||||
category="partner/audio/Stability AI",
|
||||
essentials_category="Audio",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["stable-audio-2.5"],
|
||||
),
|
||||
IO.String.Input("prompt", multiline=True, default=""),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=190,
|
||||
min=1,
|
||||
max=190,
|
||||
step=1,
|
||||
tooltip="Controls the duration in seconds of the generated audio.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for generation.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=8,
|
||||
min=4,
|
||||
max=8,
|
||||
step=1,
|
||||
tooltip="Controls the number of sampling steps.",
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Audio.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.2}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput:
|
||||
validate_string(prompt, max_length=10000)
|
||||
payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"),
|
||||
response_model=StabilityAudioResponse,
|
||||
data=payload,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
if not response_api.audio:
|
||||
raise ValueError("No audio file was received in response.")
|
||||
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||
|
||||
|
||||
class StabilityAudioToAudio(IO.ComfyNode):
|
||||
"""Transforms existing audio samples into new high-quality compositions using text instructions."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityAudioToAudio",
|
||||
display_name="Stability AI Audio To Audio",
|
||||
category="partner/audio/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["stable-audio-2.5"],
|
||||
),
|
||||
IO.String.Input("prompt", multiline=True, default=""),
|
||||
IO.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=190,
|
||||
min=1,
|
||||
max=190,
|
||||
step=1,
|
||||
tooltip="Controls the duration in seconds of the generated audio.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for generation.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=8,
|
||||
min=4,
|
||||
max=8,
|
||||
step=1,
|
||||
tooltip="Controls the number of sampling steps.",
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"strength",
|
||||
default=1,
|
||||
min=0.01,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Parameter controls how much influence the audio parameter has on the generated audio.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Audio.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.2}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls, model: str, prompt: str, audio: Input.Audio, duration: int, seed: int, steps: int, strength: float
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, max_length=10000)
|
||||
validate_audio_duration(audio, 6, 190)
|
||||
payload = StabilityAudioToAudioRequest(
|
||||
prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
|
||||
)
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"),
|
||||
response_model=StabilityAudioResponse,
|
||||
data=payload,
|
||||
content_type="multipart/form-data",
|
||||
files={"audio": audio_input_to_mp3(audio)},
|
||||
)
|
||||
if not response_api.audio:
|
||||
raise ValueError("No audio file was received in response.")
|
||||
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||
|
||||
|
||||
class StabilityAudioInpaint(IO.ComfyNode):
|
||||
"""Transforms part of existing audio sample using text instructions."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="StabilityAudioInpaint",
|
||||
display_name="Stability AI Audio Inpaint",
|
||||
category="partner/audio/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["stable-audio-2.5"],
|
||||
),
|
||||
IO.String.Input("prompt", multiline=True, default=""),
|
||||
IO.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=190,
|
||||
min=1,
|
||||
max=190,
|
||||
step=1,
|
||||
tooltip="Controls the duration in seconds of the generated audio.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for generation.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=8,
|
||||
min=4,
|
||||
max=8,
|
||||
step=1,
|
||||
tooltip="Controls the number of sampling steps.",
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"mask_start",
|
||||
default=30,
|
||||
min=0,
|
||||
max=190,
|
||||
step=1,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"mask_end",
|
||||
default=190,
|
||||
min=0,
|
||||
max=190,
|
||||
step=1,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Audio.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.2}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
audio: Input.Audio,
|
||||
duration: int,
|
||||
seed: int,
|
||||
steps: int,
|
||||
mask_start: int,
|
||||
mask_end: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, max_length=10000)
|
||||
if mask_end <= mask_start:
|
||||
raise ValueError(f"Value of mask_end({mask_end}) should be greater then mask_start({mask_start})")
|
||||
validate_audio_duration(audio, 6, 190)
|
||||
|
||||
payload = StabilityAudioInpaintRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
steps=steps,
|
||||
mask_start=mask_start,
|
||||
mask_end=mask_end,
|
||||
)
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"),
|
||||
response_model=StabilityAudioResponse,
|
||||
data=payload,
|
||||
content_type="multipart/form-data",
|
||||
files={"audio": audio_input_to_mp3(audio)},
|
||||
)
|
||||
if not response_api.audio:
|
||||
raise ValueError("No audio file was received in response.")
|
||||
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||
|
||||
|
||||
class StabilityExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
StabilityStableImageUltraNode,
|
||||
StabilityStableImageSD_3_5Node,
|
||||
StabilityUpscaleConservativeNode,
|
||||
StabilityUpscaleCreativeNode,
|
||||
StabilityUpscaleFastNode,
|
||||
StabilityTextToAudio,
|
||||
StabilityAudioToAudio,
|
||||
StabilityAudioInpaint,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> StabilityExtension:
|
||||
return StabilityExtension()
|
||||
@ -26,6 +26,7 @@ from .conversions import (
|
||||
text_filepath_to_base64_string,
|
||||
text_filepath_to_data_uri,
|
||||
trim_video,
|
||||
upscale_image_tensor_to_min_pixels,
|
||||
upscale_video_to_min_pixels,
|
||||
video_to_base64_string,
|
||||
)
|
||||
@ -99,6 +100,7 @@ __all__ = [
|
||||
"text_filepath_to_base64_string",
|
||||
"text_filepath_to_data_uri",
|
||||
"trim_video",
|
||||
"upscale_image_tensor_to_min_pixels",
|
||||
"upscale_video_to_min_pixels",
|
||||
"video_to_base64_string",
|
||||
# Validation utilities
|
||||
|
||||
@ -448,6 +448,15 @@ def _compute_upscale_dims(src_w: int, src_h: int, total_pixels: int) -> tuple[in
|
||||
return new_w, new_h
|
||||
|
||||
|
||||
def upscale_image_tensor_to_min_pixels(image: torch.Tensor, total_pixels: int) -> torch.Tensor:
|
||||
samples = image.movedim(-1, 1)
|
||||
dims = _compute_upscale_dims(samples.shape[3], samples.shape[2], int(total_pixels))
|
||||
if dims is None:
|
||||
return image
|
||||
new_w, new_h = dims
|
||||
return common_upscale(samples, new_w, new_h, "lanczos", "disabled").movedim(1, -1)
|
||||
|
||||
|
||||
def upscale_video_to_min_pixels(video: Input.Video, min_pixels: int) -> Input.Video:
|
||||
"""Upscale a video to meet at least ``min_pixels`` (w * h), preserving aspect ratio.
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import bisect
|
||||
import gc
|
||||
import itertools
|
||||
import psutil
|
||||
import time
|
||||
@ -528,6 +529,38 @@ class RAMPressureCache(LRUCache):
|
||||
if psutil.virtual_memory().available >= target:
|
||||
return
|
||||
|
||||
def remove_cache_key(key):
|
||||
del self.cache[key]
|
||||
self.used_generation.pop(key, None)
|
||||
self.timestamps.pop(key, None)
|
||||
self.children.pop(key, None)
|
||||
|
||||
def has_old_model_patcher(outputs):
|
||||
if outputs is None:
|
||||
return False
|
||||
for output in outputs:
|
||||
if isinstance(output, (list, tuple)):
|
||||
if has_old_model_patcher(output):
|
||||
return True
|
||||
elif isinstance(output, ModelPatcher):
|
||||
return True
|
||||
return False
|
||||
|
||||
old_modelpatcher_keys = []
|
||||
for key, cache_entry in self.cache.items():
|
||||
if self.used_generation[key] == self.generation:
|
||||
continue
|
||||
if has_old_model_patcher(cache_entry.outputs):
|
||||
old_modelpatcher_keys.append(key)
|
||||
|
||||
for key in old_modelpatcher_keys:
|
||||
remove_cache_key(key)
|
||||
|
||||
if old_modelpatcher_keys:
|
||||
gc.collect()
|
||||
if psutil.virtual_memory().available >= target:
|
||||
return
|
||||
|
||||
clean_list = []
|
||||
|
||||
for key, cache_entry in self.cache.items():
|
||||
@ -545,19 +578,17 @@ class RAMPressureCache(LRUCache):
|
||||
scan_list_for_ram_usage(output)
|
||||
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
|
||||
ram_usage += output.numel() * output.element_size()
|
||||
elif isinstance(output, ModelPatcher) and self.used_generation[key] != self.generation:
|
||||
#old ModelPatchers are the first to go
|
||||
ram_usage = 1e30
|
||||
scan_list_for_ram_usage(cache_entry.outputs)
|
||||
|
||||
oom_score *= ram_usage
|
||||
#In the case where we have no information on the node ram usage at all,
|
||||
#break OOM score ties on the last touch timestamp (pure LRU)
|
||||
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
|
||||
bisect.insort(clean_list, (oom_score, self.timestamps[key], ram_usage, key))
|
||||
|
||||
while psutil.virtual_memory().available < target and clean_list:
|
||||
_, _, key = clean_list.pop()
|
||||
del self.cache[key]
|
||||
self.used_generation.pop(key, None)
|
||||
self.timestamps.pop(key, None)
|
||||
self.children.pop(key, None)
|
||||
to_free = target - psutil.virtual_memory().available
|
||||
while to_free > 0 and clean_list:
|
||||
_, _, ram_usage, key = clean_list.pop()
|
||||
remove_cache_key(key)
|
||||
to_free -= ram_usage
|
||||
|
||||
gc.collect()
|
||||
|
||||
@ -8,7 +8,8 @@ class CLIPTextEncodeControlnet(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodeControlnet",
|
||||
category="experimental/conditioning",
|
||||
display_name="CLIP Text Encode (Controlnet)",
|
||||
category="model/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.Conditioning.Input("conditioning"),
|
||||
@ -35,11 +36,12 @@ class T5TokenizerOptions(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="T5TokenizerOptions",
|
||||
category="experimental/conditioning",
|
||||
display_name="T5 Tokenizer Options",
|
||||
category="model/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True),
|
||||
io.Int.Input("min_length", default=0, min=0, max=10000, step=1, advanced=True),
|
||||
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1),
|
||||
io.Int.Input("min_length", default=0, min=0, max=10000, step=1),
|
||||
],
|
||||
outputs=[io.Clip.Output()],
|
||||
is_experimental=True,
|
||||
|
||||
@ -1070,7 +1070,7 @@ class AddNoise(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="AddNoise",
|
||||
category="experimental/custom_sampling/noise",
|
||||
category="model/sampling/noise",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
@ -1120,7 +1120,7 @@ class ManualSigmas(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ManualSigmas",
|
||||
search_aliases=["custom noise schedule", "define sigmas"],
|
||||
category="experimental/custom_sampling",
|
||||
category="model/sampling/sigmas",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.String.Input("sigmas", default="1, 0.5", multiline=False)
|
||||
|
||||
@ -123,7 +123,8 @@ class PhotoMakerLoader(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PhotoMakerLoader",
|
||||
category="experimental/photomaker",
|
||||
display_name="Load PhotoMaker Model",
|
||||
category="model/loaders",
|
||||
inputs=[
|
||||
io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")),
|
||||
],
|
||||
@ -149,7 +150,8 @@ class PhotoMakerEncode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PhotoMakerEncode",
|
||||
category="experimental/photomaker",
|
||||
display_name="PhotoMaker Encode",
|
||||
category="model/conditioning/photomaker",
|
||||
inputs=[
|
||||
io.Photomaker.Input("photomaker"),
|
||||
io.Image.Input("image"),
|
||||
|
||||
@ -119,7 +119,7 @@ class StableCascade_SuperResolutionControlnet(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StableCascade_SuperResolutionControlnet",
|
||||
category="experimental/stable_cascade",
|
||||
category="experimental/stable cascade",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
|
||||
@ -143,7 +143,7 @@ class VAEDecodeTripoSplat(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="VAEDecodeTripoSplat",
|
||||
display_name="TripoSplat Decode",
|
||||
category="3d/latent",
|
||||
category="model/latent/triposplat",
|
||||
description="Decode the sampled TripoSplat latent into a 3D gaussian splat. "
|
||||
"Modify the number of gaussians to vary the density.",
|
||||
inputs=[
|
||||
@ -188,7 +188,7 @@ class TripoSplatSamplingPreview(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoSplatSamplingPreview",
|
||||
display_name="TripoSplat Sampling Preview",
|
||||
category="3d/latent",
|
||||
category="model/latent/triposplat",
|
||||
description="Patch the TripoSplat model for the standard Ksampler node to show a live decoded "
|
||||
"gaussian splat preview at each step.",
|
||||
inputs=[
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.26.0"
|
||||
__version__ = "0.27.0"
|
||||
|
||||
@ -264,6 +264,59 @@ def annotated_filepath(name: str) -> tuple[str, str | None]:
|
||||
return name, base_dir
|
||||
|
||||
|
||||
# Content types a browser may execute or render inline. File endpoints that
|
||||
# serve user-controlled content must force these to download (and ideally set
|
||||
# Content-Disposition: attachment) to avoid stored XSS. Centralised here so the
|
||||
# /view and /userdata handlers can't drift apart. mimetypes.guess_type may
|
||||
# return either the text/* or application/* spelling depending on platform, so
|
||||
# both are listed.
|
||||
DANGEROUS_CONTENT_TYPES = {
|
||||
'text/html', 'text/html-sandboxed', 'application/xhtml+xml',
|
||||
'text/javascript', 'application/javascript', 'application/x-javascript',
|
||||
'application/ecmascript', 'text/css',
|
||||
'image/svg+xml', 'application/xml', 'text/xml',
|
||||
# message/rfc822 (.mht/.mhtml) can carry script in some browsers.
|
||||
'message/rfc822',
|
||||
}
|
||||
|
||||
|
||||
def is_dangerous_content_type(content_type: str | None) -> bool:
|
||||
"""Return True if a browser may execute or render `content_type` inline.
|
||||
|
||||
Normalises before matching so the check can't be slipped past with a
|
||||
charset/boundary parameter (``text/html; charset=utf-8``) or casing
|
||||
(``TEXT/HTML``). Any XML dialect (``*+xml`` or ``*/xml``) is treated as
|
||||
dangerous because XML can carry inline script via stylesheet/entity tricks,
|
||||
which also covers the ``application/{xslt,rss,atom,rdf}+xml`` family without
|
||||
enumerating each one. Endpoints serving user-controlled content should route
|
||||
a dangerous type to ``application/octet-stream`` + ``Content-Disposition:
|
||||
attachment`` + ``X-Content-Type-Options: nosniff``.
|
||||
"""
|
||||
if not content_type:
|
||||
return False
|
||||
normalized = content_type.split(';', 1)[0].strip().lower()
|
||||
if normalized in DANGEROUS_CONTENT_TYPES:
|
||||
return True
|
||||
return normalized.endswith('+xml') or normalized.endswith('/xml')
|
||||
|
||||
|
||||
def is_within_directory(directory: str, target: str) -> bool:
|
||||
"""Return True if `target` resolves to a path inside `directory`.
|
||||
|
||||
Uses realpath on both operands so that a symlink placed inside `directory`
|
||||
that points elsewhere cannot escape the containment check at open time.
|
||||
"""
|
||||
try:
|
||||
directory = os.path.realpath(directory)
|
||||
target = os.path.realpath(target)
|
||||
return os.path.commonpath((directory, target)) == directory
|
||||
except ValueError:
|
||||
# ValueError is raised by realpath() on a path with an embedded null
|
||||
# byte, and by commonpath() on Windows when the paths are on different
|
||||
# drives. In either case the target is not safely within the directory.
|
||||
return False
|
||||
|
||||
|
||||
def get_annotated_filepath(name: str, default_dir: str | None=None) -> str:
|
||||
name, base_dir = annotated_filepath(name)
|
||||
|
||||
@ -273,7 +326,12 @@ def get_annotated_filepath(name: str, default_dir: str | None=None) -> str:
|
||||
else:
|
||||
base_dir = get_input_directory() # fallback path
|
||||
|
||||
return os.path.join(base_dir, name)
|
||||
filepath = os.path.abspath(os.path.join(base_dir, name))
|
||||
# Prevent path traversal: the resolved path must stay within base_dir.
|
||||
# repr() the name in the message so a crafted value can't inject log lines.
|
||||
if not is_within_directory(base_dir, filepath):
|
||||
raise ValueError("Invalid file path: {!r}".format(name))
|
||||
return filepath
|
||||
|
||||
|
||||
def exists_annotated_filepath(name) -> bool:
|
||||
@ -282,7 +340,10 @@ def exists_annotated_filepath(name) -> bool:
|
||||
if base_dir is None:
|
||||
base_dir = get_input_directory() # fallback path
|
||||
|
||||
filepath = os.path.join(base_dir, name)
|
||||
filepath = os.path.abspath(os.path.join(base_dir, name))
|
||||
# Treat traversal attempts as non-existent rather than probing the filesystem.
|
||||
if not is_within_directory(base_dir, filepath):
|
||||
return False
|
||||
return os.path.exists(filepath)
|
||||
|
||||
|
||||
|
||||
6
main.py
6
main.py
@ -314,7 +314,7 @@ def prompt_worker(q, server_instance):
|
||||
cache_ram = 0
|
||||
cache_ram_inactive = 0
|
||||
if not args.cache_classic and not args.cache_none and args.cache_lru <= 0:
|
||||
cache_ram = min(10.0, max(2.0, comfy.model_management.total_ram * 0.10 / 1024.0))
|
||||
cache_ram = min(10.0, max(1.5, comfy.model_management.total_ram * 0.05 / 1024.0))
|
||||
cache_ram_inactive = min(96.0, comfy.model_management.total_ram / 1024.0)
|
||||
if len(args.cache_ram) > 0:
|
||||
cache_ram = args.cache_ram[0]
|
||||
@ -403,7 +403,7 @@ def prompt_worker(q, server_instance):
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
if not asset_seeder.is_disabled():
|
||||
asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||
asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=args.enable_asset_hashing)
|
||||
asset_seeder.resume()
|
||||
|
||||
|
||||
@ -458,7 +458,7 @@ def setup_database():
|
||||
if dependencies_available():
|
||||
init_db()
|
||||
if args.enable_assets:
|
||||
if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=True):
|
||||
if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=args.enable_asset_hashing):
|
||||
logging.info("Background asset scan initiated for models, input, output")
|
||||
except Exception as e:
|
||||
if "database is locked" in str(e):
|
||||
|
||||
11
nodes.py
11
nodes.py
@ -349,7 +349,7 @@ class VAEDecodeTiled:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "decode"
|
||||
|
||||
CATEGORY = "experimental"
|
||||
CATEGORY = "model/latent"
|
||||
|
||||
def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
|
||||
if tile_size < overlap * 4:
|
||||
@ -396,7 +396,7 @@ class VAEEncodeTiled:
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "experimental"
|
||||
CATEGORY = "model/latent"
|
||||
|
||||
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
|
||||
t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
|
||||
@ -514,7 +514,7 @@ class SaveLatent:
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "experimental"
|
||||
CATEGORY = "model/latent"
|
||||
|
||||
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
@ -559,7 +559,7 @@ class LoadLatent:
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
|
||||
return {"required": {"latent": [sorted(files), ]}, }
|
||||
|
||||
CATEGORY = "experimental"
|
||||
CATEGORY = "model/latent"
|
||||
|
||||
RETURN_TYPES = ("LATENT", )
|
||||
FUNCTION = "load"
|
||||
@ -2155,6 +2155,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"GLIGENTextBoxApply": "Apply GLIGEN Text Box",
|
||||
"ConditioningZeroOut": "Conditioning Zero Out",
|
||||
# Latent
|
||||
"LoadLatent": "Load Latent",
|
||||
"SaveLatent": "Save Latent",
|
||||
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
|
||||
"SetLatentNoiseMask": "Set Latent Noise Mask",
|
||||
"VAEDecode": "VAE Decode",
|
||||
@ -2189,7 +2191,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ImageSharpen": "Sharpen Image",
|
||||
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
|
||||
"GetImageSize": "Get Image Size",
|
||||
# experimental
|
||||
"VAEDecodeTiled": "VAE Decode (Tiled)",
|
||||
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
||||
}
|
||||
|
||||
546
openapi.yaml
546
openapi.yaml
@ -230,93 +230,6 @@ components:
|
||||
- base_version
|
||||
- workflow_json
|
||||
type: object
|
||||
DownloadEnqueueRequest:
|
||||
description: Request body for enqueuing a server-side model download.
|
||||
properties:
|
||||
allow_any_extension:
|
||||
default: false
|
||||
description: Permit a non-model file extension (default only allows known model extensions).
|
||||
type: boolean
|
||||
credential_id:
|
||||
description: Explicit per-host credential to use; otherwise auto-resolved by host. Still subject to the per-hop host match.
|
||||
nullable: true
|
||||
type: string
|
||||
expected_sha256:
|
||||
description: Optional hub-provided SHA256 to verify the completed file against (fail-closed).
|
||||
nullable: true
|
||||
type: string
|
||||
model_id:
|
||||
description: Destination as "<directory>/<filename>", resolving to a registered model folder (e.g. "loras/my_lora.safetensors").
|
||||
type: string
|
||||
priority:
|
||||
default: 0
|
||||
description: Scheduling priority; higher is admitted first.
|
||||
type: integer
|
||||
url:
|
||||
description: Source URL; must be on the allowlist (host + scheme + extension).
|
||||
type: string
|
||||
required:
|
||||
- url
|
||||
- model_id
|
||||
type: object
|
||||
DownloadStatus:
|
||||
description: Current state and live progress of a single download.
|
||||
properties:
|
||||
bytes_done:
|
||||
type: integer
|
||||
created_at:
|
||||
type: integer
|
||||
download_id:
|
||||
format: uuid
|
||||
type: string
|
||||
error:
|
||||
nullable: true
|
||||
type: string
|
||||
eta_seconds:
|
||||
nullable: true
|
||||
type: number
|
||||
model_id:
|
||||
type: string
|
||||
priority:
|
||||
type: integer
|
||||
progress:
|
||||
description: Fraction in [0,1]; null until total size is known.
|
||||
nullable: true
|
||||
type: number
|
||||
segments:
|
||||
description: Per-segment progress (segmented downloads only).
|
||||
items:
|
||||
properties:
|
||||
bytes_done:
|
||||
type: integer
|
||||
idx:
|
||||
type: integer
|
||||
length:
|
||||
type: integer
|
||||
type: object
|
||||
nullable: true
|
||||
type: array
|
||||
speed_bps:
|
||||
nullable: true
|
||||
type: number
|
||||
status:
|
||||
enum:
|
||||
- queued
|
||||
- active
|
||||
- paused
|
||||
- verifying
|
||||
- completed
|
||||
- failed
|
||||
- cancelled
|
||||
type: string
|
||||
total_bytes:
|
||||
nullable: true
|
||||
type: integer
|
||||
updated_at:
|
||||
type: integer
|
||||
url:
|
||||
type: string
|
||||
type: object
|
||||
ErrorResponse:
|
||||
description: Standard error response with a machine-readable code and human-readable message.
|
||||
properties:
|
||||
@ -598,78 +511,6 @@ components:
|
||||
required:
|
||||
- history
|
||||
type: object
|
||||
HostCredentialUpsert:
|
||||
description: Request body for upserting a per-host credential. The secret is write-only.
|
||||
properties:
|
||||
auth_scheme:
|
||||
default: bearer
|
||||
description: How the secret is attached to requests.
|
||||
enum:
|
||||
- bearer
|
||||
- header
|
||||
- query
|
||||
type: string
|
||||
enabled:
|
||||
default: true
|
||||
type: boolean
|
||||
header_name:
|
||||
description: Header name when auth_scheme=header (defaults to Authorization).
|
||||
nullable: true
|
||||
type: string
|
||||
host:
|
||||
description: Normalized hostname the key applies to (e.g. "civitai.com").
|
||||
type: string
|
||||
label:
|
||||
description: User-friendly name for display.
|
||||
nullable: true
|
||||
type: string
|
||||
match_subdomains:
|
||||
default: false
|
||||
description: Also match label-boundary subdomains of host (off by default; unsafe for hub CDNs).
|
||||
type: boolean
|
||||
query_param:
|
||||
description: Query parameter name when auth_scheme=query.
|
||||
nullable: true
|
||||
type: string
|
||||
secret:
|
||||
description: The API key. Write-only — never returned by any endpoint.
|
||||
type: string
|
||||
required:
|
||||
- host
|
||||
- secret
|
||||
type: object
|
||||
HostCredentialView:
|
||||
description: Masked, API-safe view of a stored credential. Never includes the secret.
|
||||
properties:
|
||||
auth_scheme:
|
||||
type: string
|
||||
created_at:
|
||||
type: integer
|
||||
enabled:
|
||||
type: boolean
|
||||
header_name:
|
||||
nullable: true
|
||||
type: string
|
||||
host:
|
||||
type: string
|
||||
id:
|
||||
format: uuid
|
||||
type: string
|
||||
label:
|
||||
nullable: true
|
||||
type: string
|
||||
match_subdomains:
|
||||
type: boolean
|
||||
query_param:
|
||||
nullable: true
|
||||
type: string
|
||||
secret_last4:
|
||||
description: Last 4 characters of the secret, for masked display only.
|
||||
nullable: true
|
||||
type: string
|
||||
updated_at:
|
||||
type: integer
|
||||
type: object
|
||||
JobCancelResponse:
|
||||
description: Response for POST /api/jobs/{job_id}/cancel. Returned on both fresh cancels and idempotent no-ops.
|
||||
properties:
|
||||
@ -2509,391 +2350,6 @@ paths:
|
||||
summary: Get tag histogram for filtered assets
|
||||
tags:
|
||||
- file
|
||||
/api/download:
|
||||
get:
|
||||
description: List all known downloads (queued, active, paused, and terminal) with live progress.
|
||||
operationId: listDownloads
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
downloads:
|
||||
items:
|
||||
$ref: '#/components/schemas/DownloadStatus'
|
||||
type: array
|
||||
type: object
|
||||
description: List of downloads
|
||||
summary: List downloads
|
||||
tags:
|
||||
- download
|
||||
/api/download/availability:
|
||||
post:
|
||||
description: |
|
||||
Bulk per-id availability for a set of model_ids declared in a workflow.
|
||||
Returns whether each model is available on disk, currently downloading
|
||||
(with progress), or missing, plus whether its URL is on the allowlist.
|
||||
operationId: getModelsAvailability
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
models:
|
||||
additionalProperties:
|
||||
type: string
|
||||
description: Map of "<directory>/<filename>" model_id to its declared source URL.
|
||||
type: object
|
||||
type: object
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
models:
|
||||
additionalProperties: true
|
||||
type: object
|
||||
type: object
|
||||
description: Per-id availability map
|
||||
summary: Bulk model availability + status
|
||||
tags:
|
||||
- download
|
||||
/api/download/clear:
|
||||
post:
|
||||
description: |
|
||||
Delete all terminal downloads (completed, failed, cancelled) from history
|
||||
in one transaction, so the cleared history persists across reloads. Live
|
||||
downloads (queued, active, paused, verifying) are skipped. Finished model
|
||||
files on disk are never removed; only leftover .part temp files are cleaned up.
|
||||
operationId: clearDownloads
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
deleted:
|
||||
description: Number of history rows removed.
|
||||
type: integer
|
||||
type: object
|
||||
description: History cleared
|
||||
summary: Clear terminal downloads from history
|
||||
tags:
|
||||
- download
|
||||
/api/download/credentials:
|
||||
get:
|
||||
description: List stored per-host credentials. Secrets are never returned; only masked metadata (last 4 chars, scheme, label).
|
||||
operationId: listDownloadCredentials
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
credentials:
|
||||
items:
|
||||
$ref: '#/components/schemas/HostCredentialView'
|
||||
type: array
|
||||
type: object
|
||||
description: Masked credential list
|
||||
summary: List host credentials (masked)
|
||||
tags:
|
||||
- download
|
||||
post:
|
||||
description: |
|
||||
Upsert (by host) a per-host API key used to authenticate downloads.
|
||||
The secret is write-only: it is stored once here and never returned by any endpoint.
|
||||
operationId: upsertDownloadCredential
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HostCredentialUpsert'
|
||||
responses:
|
||||
"201":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HostCredentialView'
|
||||
description: Credential stored (masked view returned)
|
||||
"400":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Invalid credential
|
||||
summary: Upsert a host credential
|
||||
tags:
|
||||
- download
|
||||
/api/download/credentials/{id}:
|
||||
delete:
|
||||
description: Delete a stored host credential.
|
||||
operationId: deleteDownloadCredential
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
deleted:
|
||||
type: boolean
|
||||
type: object
|
||||
description: Deleted
|
||||
"404":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: No such credential
|
||||
summary: Delete a host credential
|
||||
tags:
|
||||
- download
|
||||
get:
|
||||
description: Get a single host credential (masked; never includes the secret).
|
||||
operationId: getDownloadCredential
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HostCredentialView'
|
||||
description: Masked credential
|
||||
"404":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: No such credential
|
||||
summary: Get a host credential (masked)
|
||||
tags:
|
||||
- download
|
||||
/api/download/enqueue:
|
||||
post:
|
||||
description: |
|
||||
Enqueue a server-side model download. The URL must be on the allowlist
|
||||
(host + scheme + extension) and the model_id must be "<directory>/<filename>"
|
||||
resolving to a registered model folder. Returns immediately; track progress
|
||||
via GET /api/download/{id} or the "download_progress" websocket event.
|
||||
operationId: enqueueDownload
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/DownloadEnqueueRequest'
|
||||
responses:
|
||||
"202":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
accepted:
|
||||
type: boolean
|
||||
download_id:
|
||||
format: uuid
|
||||
type: string
|
||||
type: object
|
||||
description: Download accepted and queued
|
||||
"400":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Invalid request (bad URL, model_id, or not allowlisted)
|
||||
"409":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Already on disk or already downloading
|
||||
summary: Enqueue a model download
|
||||
tags:
|
||||
- download
|
||||
/api/download/{id}:
|
||||
delete:
|
||||
description: |
|
||||
Delete a single terminal download from history so it stays gone across
|
||||
reloads. Refuses (409) to delete a live download (queued, active, paused,
|
||||
verifying) — cancel it first. The finished model file on disk is never
|
||||
removed; only a leftover .part temp file is cleaned up.
|
||||
operationId: deleteDownload
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
format: uuid
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
deleted:
|
||||
type: boolean
|
||||
type: object
|
||||
description: Deleted
|
||||
"404":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: No such download
|
||||
"409":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Download is still in progress
|
||||
summary: Delete a download from history
|
||||
tags:
|
||||
- download
|
||||
get:
|
||||
description: Get the current status + progress of a single download.
|
||||
operationId: getDownloadStatus
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
format: uuid
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/DownloadStatus'
|
||||
description: Download status
|
||||
"404":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: No such download
|
||||
summary: Get download status
|
||||
tags:
|
||||
- download
|
||||
/api/download/{id}/cancel:
|
||||
post:
|
||||
description: Cancel a download. The partial file is removed.
|
||||
operationId: cancelDownload
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
format: uuid
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
ok:
|
||||
type: boolean
|
||||
type: object
|
||||
description: Cancelled
|
||||
summary: Cancel a download
|
||||
tags:
|
||||
- download
|
||||
/api/download/{id}/pause:
|
||||
post:
|
||||
description: Pause a download. The partial file and per-segment offsets are retained for resume.
|
||||
operationId: pauseDownload
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
format: uuid
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
ok:
|
||||
type: boolean
|
||||
type: object
|
||||
description: Paused
|
||||
summary: Pause a download
|
||||
tags:
|
||||
- download
|
||||
/api/download/{id}/priority:
|
||||
post:
|
||||
description: Set a download's scheduling priority. Higher priority is admitted first when a slot frees.
|
||||
operationId: setDownloadPriority
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
format: uuid
|
||||
type: string
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
priority:
|
||||
type: integer
|
||||
required:
|
||||
- priority
|
||||
type: object
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
ok:
|
||||
type: boolean
|
||||
type: object
|
||||
description: Priority updated
|
||||
summary: Set download priority
|
||||
tags:
|
||||
- download
|
||||
/api/download/{id}/resume:
|
||||
post:
|
||||
description: Resume a paused (or failed) download from its persisted offsets.
|
||||
operationId: resumeDownload
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
format: uuid
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
ok:
|
||||
type: boolean
|
||||
type: object
|
||||
description: Resumed
|
||||
summary: Resume a download
|
||||
tags:
|
||||
- download
|
||||
/api/embeddings:
|
||||
get:
|
||||
description: Returns the list of text-encoder embeddings available on disk.
|
||||
@ -5647,5 +5103,3 @@ tags:
|
||||
name: queue
|
||||
- description: Job lifecycle queries
|
||||
name: job
|
||||
- description: Model download management
|
||||
name: download
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.26.0"
|
||||
version = "0.27.0"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.45.20
|
||||
comfyui-workflow-templates==0.10.7
|
||||
comfyui-workflow-templates==0.11.2
|
||||
comfyui-embedded-docs==0.5.6
|
||||
torch
|
||||
torchsde
|
||||
@ -22,7 +22,7 @@ alembic
|
||||
SQLAlchemy>=2.0.0
|
||||
filelock
|
||||
av>=16.0.0
|
||||
comfy-kitchen==0.2.15
|
||||
comfy-kitchen==0.2.16
|
||||
comfy-aimdo==0.4.10
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
|
||||
52
server.py
52
server.py
@ -45,8 +45,6 @@ from app.frontend_management import FrontendManager, parse_version
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
from app.assets.seeder import asset_seeder
|
||||
from app.assets.api.routes import register_assets_routes
|
||||
from app.model_downloader.api.routes import register_routes as register_model_downloader_routes
|
||||
from app.model_downloader.manager import DOWNLOAD_MANAGER
|
||||
from app.assets.services.ingest import register_file_in_place
|
||||
from app.assets.services.asset_management import resolve_hash_to_path
|
||||
|
||||
@ -129,6 +127,7 @@ def create_cors_middleware(allowed_origin: str):
|
||||
|
||||
return cors_middleware
|
||||
|
||||
|
||||
def is_loopback(host):
|
||||
if host is None:
|
||||
return False
|
||||
@ -258,7 +257,6 @@ class PromptServer():
|
||||
else:
|
||||
register_assets_routes(self.app)
|
||||
asset_seeder.disable()
|
||||
register_model_downloader_routes(self.app)
|
||||
routes = web.RouteTableDef()
|
||||
self.routes = routes
|
||||
self.last_node_id = None
|
||||
@ -619,15 +617,30 @@ class PromptServer():
|
||||
or 'application/octet-stream'
|
||||
)
|
||||
|
||||
# For security, force certain mimetypes to download instead of display
|
||||
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
|
||||
content_type = 'application/octet-stream' # Forces download
|
||||
# For security, force renderable/active types (HTML, JS,
|
||||
# CSS, SVG, XML — anything that can carry inline <script>
|
||||
# and execute in the page origin) to download instead of
|
||||
# displaying inline, preventing stored XSS. The
|
||||
# attachment disposition is the load-bearing guard: a
|
||||
# bare filename= hint does not force a download per
|
||||
# RFC 6266, so we only attach it on the dangerous branch
|
||||
# to avoid breaking inline display of legitimate images.
|
||||
# Escape backslash/quote per RFC 6266 quoted-string so a
|
||||
# filename containing a double quote (which passes the
|
||||
# ".."/leading-slash filter above) can't break out of the
|
||||
# header's quoted-string and malform the disposition.
|
||||
safe_filename = filename.replace("\\", "\\\\").replace('"', '\\"')
|
||||
disposition = f"filename=\"{safe_filename}\""
|
||||
if folder_paths.is_dangerous_content_type(content_type):
|
||||
content_type = 'application/octet-stream'
|
||||
disposition = f"attachment; filename=\"{safe_filename}\""
|
||||
|
||||
return web.FileResponse(
|
||||
file,
|
||||
headers={
|
||||
"Content-Disposition": f"filename=\"{filename}\"",
|
||||
"Content-Type": content_type
|
||||
"Content-Disposition": disposition,
|
||||
"Content-Type": content_type,
|
||||
"X-Content-Type-Options": "nosniff"
|
||||
}
|
||||
)
|
||||
|
||||
@ -1185,29 +1198,6 @@ class PromptServer():
|
||||
async def setup(self):
|
||||
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
||||
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
||||
await self._setup_model_downloader()
|
||||
|
||||
async def _setup_model_downloader(self):
|
||||
"""Start the download manager: push progress over the websocket and
|
||||
resume any downloads interrupted by a previous run."""
|
||||
def _notify(download_id: str) -> None:
|
||||
try:
|
||||
view = DOWNLOAD_MANAGER.status_sync(download_id)
|
||||
if view is not None:
|
||||
# Drop the url field before broadcasting: the redacted URL
|
||||
# (scheme + host + path) should not leak to every connected
|
||||
# websocket client. download_id / model_id are sufficient to
|
||||
# correlate progress on the frontend.
|
||||
broadcast = {k: v for k, v in view.items() if k != "url"}
|
||||
self.send_sync("download_progress", broadcast)
|
||||
except Exception:
|
||||
logging.debug("download progress notify failed", exc_info=True)
|
||||
|
||||
DOWNLOAD_MANAGER.set_notify(_notify)
|
||||
try:
|
||||
await DOWNLOAD_MANAGER.start()
|
||||
except Exception as e:
|
||||
logging.warning("Failed to start model download manager: %s", e)
|
||||
|
||||
def add_routes(self):
|
||||
self.user_manager.add_routes(self.routes)
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import contextlib
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
@ -9,6 +11,40 @@ import requests
|
||||
from helpers import get_asset_filename, trigger_sync_seed_assets
|
||||
|
||||
|
||||
def test_download_svg_forced_to_attachment(http: requests.Session, api_base: str):
|
||||
"""GHSA-779p-m5rp-r4h4 CISA-5 (sibling route): an uploaded SVG must never be
|
||||
served inline from GET /api/assets/{id}/content, or an inline <script> runs
|
||||
in the app origin (stored XSS). Even with disposition=inline requested, a
|
||||
dangerous content type must be forced to application/octet-stream +
|
||||
Content-Disposition: attachment + nosniff. Regression guard for the stale
|
||||
inline blocklist that previously omitted image/svg+xml and ignored the
|
||||
centralized folder_paths.is_dangerous_content_type check.
|
||||
"""
|
||||
svg = b'<svg xmlns="http://www.w3.org/2000/svg"><script>alert(1)</script></svg>'
|
||||
files = {"file": ("evil.svg", svg, "image/svg+xml")}
|
||||
form_data = {
|
||||
"tags": json.dumps(["models", "checkpoints", "unit-tests", "svgxss"]),
|
||||
"name": "evil.svg",
|
||||
}
|
||||
up = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
|
||||
body = up.json()
|
||||
assert up.status_code in (200, 201), body
|
||||
aid = body["id"]
|
||||
try:
|
||||
r = http.get(f"{api_base}/api/assets/{aid}/content?disposition=inline", timeout=120)
|
||||
r.content
|
||||
assert r.status_code == 200
|
||||
ct = r.headers.get("Content-Type", "").lower()
|
||||
cd = r.headers.get("Content-Disposition", "").lower()
|
||||
assert "svg" not in ct, f"SVG served with a renderable content type: {ct!r}"
|
||||
assert ct.startswith("application/octet-stream"), f"expected octet-stream, got {ct!r}"
|
||||
assert "attachment" in cd, f"inline disposition not overridden to attachment: {cd!r}"
|
||||
assert r.headers.get("X-Content-Type-Options", "").lower() == "nosniff"
|
||||
finally:
|
||||
with contextlib.suppress(Exception):
|
||||
http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
|
||||
|
||||
|
||||
def test_download_attachment_and_inline(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
|
||||
|
||||
@ -53,8 +53,11 @@ def test_annotated_filepath():
|
||||
|
||||
def test_get_annotated_filepath():
|
||||
default_dir = "/default/dir"
|
||||
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt")
|
||||
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt")
|
||||
# get_annotated_filepath now normalizes with os.path.abspath (part of the
|
||||
# GHSA-779p traversal hardening), so compare against the normalized form —
|
||||
# on Windows abspath also prepends the current drive letter.
|
||||
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.abspath(os.path.join(default_dir, "test.txt"))
|
||||
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.abspath(os.path.join(folder_paths.get_output_directory(), "test.txt"))
|
||||
|
||||
def test_add_model_folder_path_append(clear_folder_paths):
|
||||
folder_paths.add_model_folder_path("test_folder", "/default/path", is_default=True)
|
||||
|
||||
@ -1,90 +0,0 @@
|
||||
"""Shared fixtures for the model download manager tests.
|
||||
|
||||
These run in-process (no ComfyUI subprocess): a file-backed SQLite DB is
|
||||
initialized once, a temp model folder is registered with ``folder_paths``, and
|
||||
the shared aiohttp session is reset between tests so each async test gets a
|
||||
session bound to its own event loop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _drain_scheduler_tasks(scheduler) -> None:
|
||||
"""Cancel and await live scheduler tasks so none outlive the test.
|
||||
|
||||
Uses the actual task handles rather than only clearing ``_tasks``: each
|
||||
per-test event loop is created by ``asyncio.run``, so a task left behind by
|
||||
a crashed/aborted test would otherwise keep its coroutine alive. We cancel
|
||||
every live task and, when its loop is still usable, run it to completion to
|
||||
let the cancellation propagate before dropping the reference.
|
||||
"""
|
||||
for task in list(scheduler._tasks.values()):
|
||||
if task is None:
|
||||
continue
|
||||
loop = task.get_loop()
|
||||
if task.done() or loop.is_closed():
|
||||
continue
|
||||
task.cancel()
|
||||
if not loop.is_running():
|
||||
try:
|
||||
loop.run_until_complete(asyncio.gather(task, return_exceptions=True))
|
||||
except Exception:
|
||||
pass
|
||||
scheduler._tasks.clear()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def _init_db():
|
||||
import app.database.db as db
|
||||
from comfy.cli_args import args
|
||||
|
||||
fd, db_path = tempfile.mkstemp(suffix="-dlmgr-test.sqlite3")
|
||||
os.close(fd)
|
||||
args.database_url = f"sqlite:///{db_path}"
|
||||
db.init_db()
|
||||
yield
|
||||
try:
|
||||
os.remove(db_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_runtime():
|
||||
"""Reset module singletons that hold event-loop-bound or cross-test state."""
|
||||
import app.model_downloader.net.session as ns
|
||||
from app.model_downloader.scheduler import SCHEDULER
|
||||
|
||||
ns._session = None
|
||||
_drain_scheduler_tasks(SCHEDULER)
|
||||
SCHEDULER._jobs.clear()
|
||||
SCHEDULER._backoff_until.clear()
|
||||
SCHEDULER._started = False
|
||||
yield
|
||||
_drain_scheduler_tasks(SCHEDULER)
|
||||
ns._session = None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_root(tmp_path):
|
||||
"""Register a temp 'loras' model folder and return its absolute path."""
|
||||
import folder_paths
|
||||
|
||||
root = tmp_path / "loras"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
saved = folder_paths.folder_names_and_paths.get("loras")
|
||||
folder_paths.folder_names_and_paths["loras"] = (
|
||||
[str(root)],
|
||||
{".safetensors", ".sft", ".ckpt", ".pt", ".pth"},
|
||||
)
|
||||
yield str(root)
|
||||
if saved is not None:
|
||||
folder_paths.folder_names_and_paths["loras"] = saved
|
||||
else:
|
||||
folder_paths.folder_names_and_paths.pop("loras", None)
|
||||
@ -1,166 +0,0 @@
|
||||
"""Unit tests for the credential store and the per-hop credential resolver.
|
||||
|
||||
Covers the critical rule: a secret is only ever attached when the current
|
||||
hop's host matches a stored credential, and never over a non-https hop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from app.model_downloader.credentials import resolver
|
||||
from app.model_downloader.credentials.store import (
|
||||
CREDENTIAL_STORE,
|
||||
CredentialValidationError,
|
||||
normalize_host,
|
||||
)
|
||||
from app.model_downloader.database.models import HostCredential
|
||||
|
||||
|
||||
# ----- pure host normalization + matching -----
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"raw,expected",
|
||||
[
|
||||
("Civitai.com", "civitai.com"),
|
||||
("HuggingFace.co:443", "huggingface.co"),
|
||||
(" Example.COM ", "example.com"),
|
||||
],
|
||||
)
|
||||
def test_normalize_host(raw, expected):
|
||||
assert normalize_host(raw) == expected
|
||||
|
||||
|
||||
def _cred(**kw) -> HostCredential:
|
||||
base = dict(
|
||||
id="x", host="civitai.com", match_subdomains=False, auth_scheme="bearer",
|
||||
secret="SECRET", enabled=True,
|
||||
)
|
||||
base.update(kw)
|
||||
return HostCredential(**base)
|
||||
|
||||
|
||||
def test_matches_exact_only_by_default():
|
||||
c = _cred(host="civitai.com")
|
||||
assert resolver._matches(c, "civitai.com") is True
|
||||
assert resolver._matches(c, "api.civitai.com") is False
|
||||
assert resolver._matches(c, "evil-civitai.com") is False
|
||||
|
||||
|
||||
def test_matches_subdomain_label_boundary():
|
||||
c = _cred(host="example.com", match_subdomains=True)
|
||||
assert resolver._matches(c, "api.example.com") is True
|
||||
assert resolver._matches(c, "example.com") is True
|
||||
# not a label boundary -> no match
|
||||
assert resolver._matches(c, "evil-example.com") is False
|
||||
|
||||
|
||||
def test_build_auth_shapes():
|
||||
assert resolver._build_auth(_cred(auth_scheme="bearer")).headers == {
|
||||
"Authorization": "Bearer SECRET"
|
||||
}
|
||||
assert resolver._build_auth(
|
||||
_cred(auth_scheme="header", header_name="X-Api-Key")
|
||||
).headers == {"X-Api-Key": "SECRET"}
|
||||
q = resolver._build_auth(_cred(auth_scheme="query", query_param="token"))
|
||||
assert q.query == {"token": "SECRET"}
|
||||
assert q.apply_to_url("https://civitai.com/x") == "https://civitai.com/x?token=SECRET"
|
||||
|
||||
|
||||
# ----- DB-backed store + resolver -----
|
||||
|
||||
|
||||
def test_store_upsert_is_write_only_and_masked():
|
||||
async def _run():
|
||||
view = await CREDENTIAL_STORE.upsert("civitai.com", "abcd1234", label="my key")
|
||||
# The view never carries the secret, only the last 4.
|
||||
assert not hasattr(view, "secret")
|
||||
assert view.secret_last4 == "1234"
|
||||
assert view.host == "civitai.com"
|
||||
listed = await CREDENTIAL_STORE.list()
|
||||
assert any(v.host == "civitai.com" for v in listed)
|
||||
await CREDENTIAL_STORE.delete(view.id)
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_query_scheme_requires_param():
|
||||
async def _run():
|
||||
with pytest.raises(CredentialValidationError):
|
||||
await CREDENTIAL_STORE.upsert("civitai.com", "k", auth_scheme="query")
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_resolver_never_crosses_host_boundary():
|
||||
async def _run():
|
||||
view = await CREDENTIAL_STORE.upsert("huggingface.co", "hf_secret_key")
|
||||
try:
|
||||
# matching host over https -> attached
|
||||
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
|
||||
assert auth is not None
|
||||
assert auth.headers["Authorization"] == "Bearer hf_secret_key"
|
||||
# CDN redirect host -> dropped
|
||||
assert await resolver.resolve_auth_for_hop("cdn-lfs.huggingface.co", "https") is None
|
||||
# non-https hop -> never attached
|
||||
assert await resolver.resolve_auth_for_hop("huggingface.co", "http") is None
|
||||
finally:
|
||||
await CREDENTIAL_STORE.delete(view.id)
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
# ----- env-based HF token fallback -----
|
||||
|
||||
|
||||
def test_env_token_fallback_attaches_when_no_db_credential(monkeypatch):
|
||||
monkeypatch.setenv("HF_TOKEN", "env_hf_token")
|
||||
|
||||
async def _run():
|
||||
# exact host over https -> env token attached
|
||||
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
|
||||
assert auth is not None
|
||||
assert auth.headers["Authorization"] == "Bearer env_hf_token"
|
||||
# non-https hop -> never attached
|
||||
assert await resolver.resolve_auth_for_hop("huggingface.co", "http") is None
|
||||
# CDN redirect host -> dropped (exact-host only)
|
||||
assert await resolver.resolve_auth_for_hop("cdn-lfs.huggingface.co", "https") is None
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_env_token_secondary_var_is_honored(monkeypatch):
|
||||
monkeypatch.delenv("HF_TOKEN", raising=False)
|
||||
monkeypatch.setenv("HUGGING_FACE_HUB_TOKEN", "env_hub_token")
|
||||
|
||||
async def _run():
|
||||
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
|
||||
assert auth is not None
|
||||
assert auth.headers["Authorization"] == "Bearer env_hub_token"
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_db_credential_takes_precedence_over_env(monkeypatch):
|
||||
monkeypatch.setenv("HF_TOKEN", "env_hf_token")
|
||||
|
||||
async def _run():
|
||||
view = await CREDENTIAL_STORE.upsert("huggingface.co", "db_secret_key")
|
||||
try:
|
||||
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
|
||||
assert auth is not None
|
||||
assert auth.headers["Authorization"] == "Bearer db_secret_key"
|
||||
finally:
|
||||
await CREDENTIAL_STORE.delete(view.id)
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_env_token_does_not_leak_into_explicit_path(monkeypatch):
|
||||
monkeypatch.setenv("HF_TOKEN", "env_hf_token")
|
||||
|
||||
async def _run():
|
||||
# An explicit credential id that doesn't resolve must stay None; the env
|
||||
# fallback only applies to the auto-resolve branch.
|
||||
auth = await resolver.resolve_auth_for_hop(
|
||||
"huggingface.co", "https", explicit_credential_id="does-not-exist"
|
||||
)
|
||||
assert auth is None
|
||||
asyncio.run(_run())
|
||||
@ -1,136 +0,0 @@
|
||||
"""Unit tests for ``DownloadManager.delete`` and ``DownloadManager.clear``.
|
||||
|
||||
Deleting a terminal row must remove it from history for good (so it does not
|
||||
reappear on the next ``list``), leave live rows untouched, and clean up any
|
||||
leftover ``.part`` temp file without touching the finished model file.
|
||||
|
||||
``clear()`` is the bulk variant: it removes all terminal rows atomically, skips
|
||||
live ones, and returns the count of rows deleted.
|
||||
|
||||
Async methods are driven via ``asyncio.run`` so no pytest-asyncio plugin is
|
||||
required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from app.model_downloader.constants import DownloadStatus
|
||||
from app.model_downloader.database import queries
|
||||
from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError
|
||||
|
||||
|
||||
def _insert(download_id: str, status: str, *, temp_path: str = "/tmp/none.part") -> None:
|
||||
queries.insert_download(
|
||||
{
|
||||
"id": download_id,
|
||||
"url": "https://huggingface.co/org/model.safetensors",
|
||||
"model_id": "loras/model.safetensors",
|
||||
"dest_path": "/tmp/model.safetensors",
|
||||
"temp_path": temp_path,
|
||||
"status": status,
|
||||
"priority": 0,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_delete_removes_terminal_row_from_history():
|
||||
_insert("done", DownloadStatus.COMPLETED)
|
||||
|
||||
asyncio.run(DOWNLOAD_MANAGER.delete("done"))
|
||||
|
||||
assert queries.get_download("done") is None
|
||||
|
||||
|
||||
def test_delete_refuses_live_row():
|
||||
_insert("live", DownloadStatus.QUEUED)
|
||||
|
||||
with pytest.raises(DownloadError) as excinfo:
|
||||
asyncio.run(DOWNLOAD_MANAGER.delete("live"))
|
||||
|
||||
assert excinfo.value.code == "DOWNLOAD_ACTIVE"
|
||||
assert queries.get_download("live") is not None
|
||||
|
||||
|
||||
def test_delete_missing_row_raises_not_found():
|
||||
with pytest.raises(DownloadError) as excinfo:
|
||||
asyncio.run(DOWNLOAD_MANAGER.delete("nope"))
|
||||
|
||||
assert excinfo.value.code == "NOT_FOUND"
|
||||
|
||||
|
||||
def test_delete_removes_leftover_temp_file(tmp_path):
|
||||
partial = tmp_path / "model.safetensors.part"
|
||||
partial.write_bytes(b"partial")
|
||||
_insert("failed", DownloadStatus.FAILED, temp_path=str(partial))
|
||||
|
||||
asyncio.run(DOWNLOAD_MANAGER.delete("failed"))
|
||||
|
||||
assert not os.path.exists(partial)
|
||||
assert queries.get_download("failed") is None
|
||||
|
||||
|
||||
# ----- clear -----
|
||||
|
||||
|
||||
def test_clear_removes_all_terminal_rows():
|
||||
_insert("c-done", DownloadStatus.COMPLETED)
|
||||
_insert("c-fail", DownloadStatus.FAILED)
|
||||
_insert("c-canc", DownloadStatus.CANCELLED)
|
||||
|
||||
deleted = asyncio.run(DOWNLOAD_MANAGER.clear())
|
||||
|
||||
assert deleted == 3
|
||||
assert queries.get_download("c-done") is None
|
||||
assert queries.get_download("c-fail") is None
|
||||
assert queries.get_download("c-canc") is None
|
||||
|
||||
|
||||
def test_clear_skips_live_rows():
|
||||
_insert("cl-queued", DownloadStatus.QUEUED)
|
||||
_insert("cl-paused", DownloadStatus.PAUSED)
|
||||
_insert("cl-done", DownloadStatus.COMPLETED)
|
||||
|
||||
deleted = asyncio.run(DOWNLOAD_MANAGER.clear())
|
||||
|
||||
assert deleted == 1
|
||||
assert queries.get_download("cl-queued") is not None
|
||||
assert queries.get_download("cl-paused") is not None
|
||||
assert queries.get_download("cl-done") is None
|
||||
|
||||
|
||||
def test_clear_returns_zero_when_nothing_to_delete():
|
||||
_insert("cl-only-live", DownloadStatus.QUEUED)
|
||||
|
||||
deleted = asyncio.run(DOWNLOAD_MANAGER.clear())
|
||||
|
||||
assert deleted == 0
|
||||
assert queries.get_download("cl-only-live") is not None
|
||||
|
||||
|
||||
def test_clear_removes_leftover_temp_files(tmp_path):
|
||||
partial = tmp_path / "clear_partial.part"
|
||||
partial.write_bytes(b"partial data")
|
||||
finished = tmp_path / "finished.safetensors"
|
||||
finished.write_bytes(b"real model weights")
|
||||
|
||||
_insert("cl-part", DownloadStatus.FAILED, temp_path=str(partial))
|
||||
# The finished file is not the temp_path; temp_path for a completed download
|
||||
# no longer exists (already renamed), so use a non-existent path here to
|
||||
# verify clear() tolerates a missing temp file without raising.
|
||||
_insert("cl-comp", DownloadStatus.COMPLETED, temp_path=str(tmp_path / "gone.part"))
|
||||
|
||||
asyncio.run(DOWNLOAD_MANAGER.clear())
|
||||
|
||||
# Leftover .part from the failed download is cleaned up.
|
||||
assert not partial.exists()
|
||||
# Finished model file is never touched.
|
||||
assert finished.exists()
|
||||
|
||||
|
||||
def test_clear_empty_db_returns_zero():
|
||||
deleted = asyncio.run(DOWNLOAD_MANAGER.clear())
|
||||
assert deleted == 0
|
||||
@ -1,637 +0,0 @@
|
||||
"""Integration tests for the download engine against a local aiohttp server.
|
||||
|
||||
Covers single-stream and segmented transfers, deterministic resume from a
|
||||
partial file, and cancel rollback. Async tests are driven via ``asyncio.run``
|
||||
so no pytest-asyncio plugin is required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from comfy.cli_args import args
|
||||
from app.model_downloader.constants import DownloadStatus
|
||||
from app.model_downloader.database import queries
|
||||
from app.model_downloader.engine.job import DownloadJob, JobSpec
|
||||
from app.model_downloader.net.session import close_session
|
||||
from app.model_downloader.security import paths
|
||||
|
||||
PAYLOAD_ETAG = '"v1"'
|
||||
|
||||
|
||||
def _payload(n: int) -> bytes:
|
||||
return bytes((i * 37 + 11) % 256 for i in range(n))
|
||||
|
||||
|
||||
def _safetensors_payload(total: int) -> bytes:
|
||||
"""A structurally valid ``.safetensors`` blob of exactly ``total`` bytes.
|
||||
|
||||
Success-path tests download to ``.safetensors`` destinations, which the
|
||||
engine now structurally validates before the atomic rename, so their
|
||||
payloads must parse as real safetensors (header length + JSON header +
|
||||
data region whose size matches the declared ``data_offsets``).
|
||||
"""
|
||||
def _header(data_len: int) -> bytes:
|
||||
return json.dumps(
|
||||
{"w": {"dtype": "U8", "shape": [data_len], "data_offsets": [0, data_len]}}
|
||||
).encode("utf-8")
|
||||
|
||||
# The header's byte length depends on the digit count of ``data_len``, so
|
||||
# iterate until ``total == 8 + len(header) + data_len`` is self-consistent.
|
||||
data_len = total - 8 - len(_header(total))
|
||||
for _ in range(8):
|
||||
header = _header(data_len)
|
||||
new_data_len = total - 8 - len(header)
|
||||
if new_data_len == data_len:
|
||||
break
|
||||
data_len = new_data_len
|
||||
assert data_len >= 0, "total too small for a safetensors payload"
|
||||
header = _header(data_len)
|
||||
body = bytes((i * 37 + 11) % 256 for i in range(data_len))
|
||||
return struct.pack("<Q", len(header)) + header + body
|
||||
|
||||
|
||||
def _range_handler(payload: bytes):
|
||||
async def handler(request: web.Request) -> web.Response:
|
||||
rng = request.headers.get("Range")
|
||||
if rng:
|
||||
spec = rng.split("=", 1)[1]
|
||||
s, _, e = spec.partition("-")
|
||||
start = int(s)
|
||||
end = int(e) if e else len(payload) - 1
|
||||
chunk = payload[start : end + 1]
|
||||
return web.Response(
|
||||
status=206,
|
||||
body=chunk,
|
||||
headers={
|
||||
"Content-Range": f"bytes {start}-{end}/{len(payload)}",
|
||||
"Accept-Ranges": "bytes",
|
||||
"ETag": PAYLOAD_ETAG,
|
||||
},
|
||||
)
|
||||
return web.Response(
|
||||
status=200, body=payload, headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG}
|
||||
)
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def _content_disposition_handler(payload: bytes, filename: str):
|
||||
"""A range-capable server that only reveals its filename via a header.
|
||||
|
||||
Models a Civitai-style ``/api/download/...`` endpoint: the URL path has no
|
||||
extension, and the real filename (hence extension) lives in the response
|
||||
``Content-Disposition`` header.
|
||||
"""
|
||||
|
||||
async def handler(request: web.Request) -> web.Response:
|
||||
headers = {
|
||||
"Accept-Ranges": "bytes",
|
||||
"ETag": PAYLOAD_ETAG,
|
||||
"Content-Disposition": f'attachment; filename="{filename}"',
|
||||
}
|
||||
rng = request.headers.get("Range")
|
||||
if rng:
|
||||
spec = rng.split("=", 1)[1]
|
||||
s, _, e = spec.partition("-")
|
||||
start = int(s)
|
||||
end = int(e) if e else len(payload) - 1
|
||||
chunk = payload[start : end + 1]
|
||||
return web.Response(
|
||||
status=206,
|
||||
body=chunk,
|
||||
headers={**headers, "Content-Range": f"bytes {start}-{end}/{len(payload)}"},
|
||||
)
|
||||
return web.Response(status=200, body=payload, headers=headers)
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def _noranges_handler(payload: bytes):
|
||||
async def handler(request: web.Request) -> web.Response:
|
||||
# Always full body, never advertises Accept-Ranges -> single-stream.
|
||||
return web.Response(status=200, body=payload)
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def _slow_handler(payload: bytes, chunk: int = 16384, delay: float = 0.01):
|
||||
async def handler(request: web.Request) -> web.StreamResponse:
|
||||
resp = web.StreamResponse(
|
||||
status=200, headers={"Content-Length": str(len(payload))}
|
||||
)
|
||||
await resp.prepare(request)
|
||||
for i in range(0, len(payload), chunk):
|
||||
await resp.write(payload[i : i + chunk])
|
||||
await asyncio.sleep(delay)
|
||||
await resp.write_eof()
|
||||
return resp
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def _overflow_range_handler(payload: bytes, extra: int = 256 * 1024):
|
||||
"""A non-conforming 206 server that returns MORE than the requested range."""
|
||||
|
||||
async def handler(request: web.Request) -> web.Response:
|
||||
rng = request.headers.get("Range")
|
||||
if rng:
|
||||
spec = rng.split("=", 1)[1]
|
||||
s, _, e = spec.partition("-")
|
||||
start = int(s)
|
||||
end = int(e) if e else len(payload) - 1
|
||||
# Maliciously overrun: append extra bytes past the requested end.
|
||||
body = payload[start : end + 1] + bytes(extra)
|
||||
return web.Response(
|
||||
status=206,
|
||||
body=body,
|
||||
headers={
|
||||
"Content-Range": f"bytes {start}-{end}/{len(payload)}",
|
||||
"Accept-Ranges": "bytes",
|
||||
"ETag": PAYLOAD_ETAG,
|
||||
},
|
||||
)
|
||||
return web.Response(
|
||||
status=200, body=payload, headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG}
|
||||
)
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def _short_range_handler(payload: bytes, drop: int = 64 * 1024):
|
||||
"""A 206 server that returns fewer bytes than requested for later segments.
|
||||
|
||||
Simulates a server cleanly closing a range connection early. The response
|
||||
is internally consistent (Content-Length matches the short body), so the
|
||||
client sees no error and the segment just ends short, leaving a zero-filled
|
||||
hole in the preallocated file.
|
||||
"""
|
||||
|
||||
async def handler(request: web.Request) -> web.Response:
|
||||
rng = request.headers.get("Range")
|
||||
if rng:
|
||||
spec = rng.split("=", 1)[1]
|
||||
s, _, e = spec.partition("-")
|
||||
start = int(s)
|
||||
end = int(e) if e else len(payload) - 1
|
||||
chunk = payload[start : end + 1]
|
||||
if start > 0 and len(chunk) > drop:
|
||||
chunk = chunk[:-drop] # truncate a non-first segment
|
||||
return web.Response(
|
||||
status=206,
|
||||
body=chunk,
|
||||
headers={
|
||||
"Content-Range": f"bytes {start}-{end}/{len(payload)}",
|
||||
"Accept-Ranges": "bytes",
|
||||
"ETag": PAYLOAD_ETAG,
|
||||
},
|
||||
)
|
||||
return web.Response(
|
||||
status=200, body=payload, headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG}
|
||||
)
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def _unbounded_handler(total: int, chunk: int = 16384):
|
||||
"""A 200 stream with no Content-Length / Accept-Ranges (unknown length)."""
|
||||
|
||||
async def handler(request: web.Request) -> web.StreamResponse:
|
||||
resp = web.StreamResponse(status=200)
|
||||
await resp.prepare(request)
|
||||
sent = 0
|
||||
while sent < total:
|
||||
await resp.write(bytes(min(chunk, total - sent)))
|
||||
sent += chunk
|
||||
await resp.write_eof()
|
||||
return resp
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
async def _serve(handler):
|
||||
app = web.Application()
|
||||
app.router.add_route("*", "/{name:.*}", handler)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "127.0.0.1", 0)
|
||||
await site.start()
|
||||
port = site._server.sockets[0].getsockname()[1]
|
||||
return runner, port
|
||||
|
||||
|
||||
def _insert(model_id: str, url: str, status: str = DownloadStatus.QUEUED) -> tuple[str, str, str]:
|
||||
final_path, temp_path = paths.resolve_destination(model_id)
|
||||
download_id = str(uuid.uuid4())
|
||||
queries.insert_download(
|
||||
{
|
||||
"id": download_id,
|
||||
"url": url,
|
||||
"model_id": model_id,
|
||||
"dest_path": final_path,
|
||||
"temp_path": temp_path,
|
||||
"status": status,
|
||||
}
|
||||
)
|
||||
return download_id, final_path, temp_path
|
||||
|
||||
|
||||
# ----- single-stream -----
|
||||
|
||||
|
||||
def test_single_stream_download(model_root):
|
||||
payload = _safetensors_payload(300_000)
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_noranges_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, _temp = _insert("loras/single.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/single.safetensors",
|
||||
dest_path=final_path, temp_path=_temp,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.COMPLETED, queries.get_download(did).error
|
||||
assert os.path.exists(final_path)
|
||||
assert open(final_path, "rb").read() == payload
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
# ----- segmented -----
|
||||
|
||||
|
||||
def test_segmented_download(model_root):
|
||||
payload = _safetensors_payload(4 * 1024 * 1024) # 4 MiB -> multiple segments
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_range_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/seg.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/seg.safetensors",
|
||||
dest_path=final_path, temp_path=temp,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.COMPLETED, queries.get_download(did).error
|
||||
assert open(final_path, "rb").read() == payload
|
||||
# More than one segment row was planned.
|
||||
assert len(queries.list_segments(did)) > 1
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
# ----- deterministic resume from a partial file -----
|
||||
|
||||
|
||||
def test_resume_from_partial(model_root):
|
||||
payload = _safetensors_payload(512 * 1024) # < 1 MiB -> single segment
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_range_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/resume.safetensors", url)
|
||||
# Simulate a prior partial: first 200 KiB already written, offset persisted.
|
||||
prefix = 200 * 1024
|
||||
os.makedirs(os.path.dirname(temp), exist_ok=True)
|
||||
with open(temp, "wb") as f:
|
||||
f.write(payload[:prefix])
|
||||
queries.update_download(did, bytes_done=prefix, etag=PAYLOAD_ETAG)
|
||||
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/resume.safetensors",
|
||||
dest_path=final_path, temp_path=temp, etag=PAYLOAD_ETAG,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.COMPLETED, queries.get_download(did).error
|
||||
assert open(final_path, "rb").read() == payload
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
# ----- cancel rollback -----
|
||||
|
||||
|
||||
def test_cancel_rollback(model_root, monkeypatch):
|
||||
monkeypatch.setattr(args, "download_chunk_size", 16384, raising=False)
|
||||
payload = _payload(1024 * 1024)
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_slow_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/cancel.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/cancel.safetensors",
|
||||
dest_path=final_path, temp_path=temp,
|
||||
))
|
||||
task = asyncio.ensure_future(job.run())
|
||||
# Wait until some bytes have been written, then cancel.
|
||||
for _ in range(200):
|
||||
await asyncio.sleep(0.01)
|
||||
if job.state.bytes_done > 0:
|
||||
break
|
||||
job.request_cancel()
|
||||
status = await task
|
||||
assert status == DownloadStatus.CANCELLED
|
||||
assert not os.path.exists(temp)
|
||||
assert not os.path.exists(final_path)
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
# ----- size-bound enforcement (malicious / non-conforming hosts) -----
|
||||
|
||||
|
||||
def test_segment_overflow_aborts(model_root):
|
||||
"""A 206 returning more than the requested range must not overrun."""
|
||||
payload = _payload(4 * 1024 * 1024) # large enough to segment
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_overflow_range_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/overflow.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/overflow.safetensors",
|
||||
dest_path=final_path, temp_path=temp,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.FAILED
|
||||
assert not os.path.exists(final_path)
|
||||
assert not os.path.exists(temp)
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_short_segment_fails_closed(model_root):
|
||||
"""A segment that ends short must fail, not be accepted as complete.
|
||||
|
||||
The file is preallocated to total_bytes, so the on-disk size still equals
|
||||
total even with a zero-filled hole; completeness must be judged per-segment.
|
||||
"""
|
||||
payload = _safetensors_payload(4 * 1024 * 1024) # large enough to segment
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_short_range_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/short.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/short.safetensors",
|
||||
dest_path=final_path, temp_path=temp,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.FAILED, queries.get_download(did).error
|
||||
assert "incomplete" in (queries.get_download(did).error or "")
|
||||
assert not os.path.exists(final_path)
|
||||
assert not os.path.exists(temp)
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_structural_validation_rejects_corrupt(model_root):
|
||||
"""A correctly sized but structurally invalid file fails closed (not retried).
|
||||
|
||||
Regression for the dead structural gate: validation must key off the
|
||||
destination extension, not the ``.part`` temp suffix.
|
||||
"""
|
||||
payload = _payload(300_000) # right size, but not a valid safetensors blob
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_noranges_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/corrupt.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/corrupt.safetensors",
|
||||
dest_path=final_path, temp_path=temp,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.FAILED, queries.get_download(did).error
|
||||
assert not os.path.exists(final_path)
|
||||
assert not os.path.exists(temp)
|
||||
# Failed closed at first attempt, not re-queued as retryable.
|
||||
assert queries.get_download(did).attempts == 0
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_rejects_oversized_known_download(model_root, monkeypatch):
|
||||
"""A file whose advertised size exceeds the cap is rejected at probe."""
|
||||
monkeypatch.setattr(args, "download_max_bytes", 100_000, raising=False)
|
||||
payload = _payload(300_000)
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_noranges_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/toobig.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/toobig.safetensors",
|
||||
dest_path=final_path, temp_path=temp,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.FAILED
|
||||
assert not os.path.exists(final_path)
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_unknown_length_capped_by_max_bytes(model_root, monkeypatch):
|
||||
"""An unbounded unknown-length stream is capped by --download-max-bytes."""
|
||||
monkeypatch.setattr(args, "download_max_bytes", 100_000, raising=False)
|
||||
monkeypatch.setattr(args, "download_chunk_size", 16384, raising=False)
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_unbounded_handler(2 * 1024 * 1024))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/unbounded.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/unbounded.safetensors",
|
||||
dest_path=final_path, temp_path=temp,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.FAILED
|
||||
assert not os.path.exists(final_path)
|
||||
assert not os.path.exists(temp)
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
# ----- manager + scheduler end-to-end -----
|
||||
|
||||
|
||||
def test_manager_enqueue_to_completion(model_root):
|
||||
payload = _safetensors_payload(2 * 1024 * 1024)
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
from app.model_downloader.manager import DOWNLOAD_MANAGER
|
||||
|
||||
runner, port = await _serve(_range_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did = await DOWNLOAD_MANAGER.enqueue(url, "loras/e2e.safetensors")
|
||||
# Wait for completion.
|
||||
final_path, _ = paths.resolve_destination("loras/e2e.safetensors")
|
||||
for _ in range(500):
|
||||
await asyncio.sleep(0.02)
|
||||
row = queries.get_download(did)
|
||||
if row.status in DownloadStatus.TERMINAL:
|
||||
break
|
||||
row = queries.get_download(did)
|
||||
assert row.status == DownloadStatus.COMPLETED, row.error
|
||||
assert open(final_path, "rb").read() == payload
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_manager_rejects_disallowed_url(model_root):
|
||||
async def _run():
|
||||
from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError
|
||||
|
||||
with pytest.raises(DownloadError) as ei:
|
||||
await DOWNLOAD_MANAGER.enqueue(
|
||||
"https://evil.example.com/x.safetensors", "loras/bad.safetensors"
|
||||
)
|
||||
assert ei.value.code == "URL_NOT_ALLOWED"
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_manager_resolves_extensionless_url(model_root):
|
||||
"""An allowlisted URL with no extension in its path is resolved from the
|
||||
response, and the stored file adopts the resolved extension."""
|
||||
payload = _safetensors_payload(1 * 1024 * 1024)
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
from app.model_downloader.manager import DOWNLOAD_MANAGER
|
||||
|
||||
runner, port = await _serve(
|
||||
_content_disposition_handler(payload, "RealModel.safetensors")
|
||||
)
|
||||
try:
|
||||
# No extension in the path (Civitai-style) and none in the model_id.
|
||||
url = f"http://127.0.0.1:{port}/api/download/models/12345"
|
||||
did = await DOWNLOAD_MANAGER.enqueue(url, "loras/my_civitai_model")
|
||||
|
||||
row = queries.get_download(did)
|
||||
# The resolved extension was appended to the model_id + destination.
|
||||
assert row.model_id == "loras/my_civitai_model.safetensors"
|
||||
assert row.dest_path.endswith("my_civitai_model.safetensors")
|
||||
|
||||
final_path, _ = paths.resolve_destination(
|
||||
"loras/my_civitai_model.safetensors"
|
||||
)
|
||||
for _ in range(500):
|
||||
await asyncio.sleep(0.02)
|
||||
row = queries.get_download(did)
|
||||
if row.status in DownloadStatus.TERMINAL:
|
||||
break
|
||||
row = queries.get_download(did)
|
||||
assert row.status == DownloadStatus.COMPLETED, row.error
|
||||
assert open(final_path, "rb").read() == payload
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_manager_overrides_extension_from_resolution(model_root):
|
||||
"""A model_id carrying a different known extension is corrected to match
|
||||
the resolved URL's extension."""
|
||||
payload = _safetensors_payload(256 * 1024)
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
from app.model_downloader.manager import DOWNLOAD_MANAGER
|
||||
|
||||
runner, port = await _serve(
|
||||
_content_disposition_handler(payload, "weights.safetensors")
|
||||
)
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/api/download/models/777"
|
||||
# Caller guessed .ckpt; resolution says .safetensors -> corrected.
|
||||
did = await DOWNLOAD_MANAGER.enqueue(url, "loras/guessed.ckpt")
|
||||
row = queries.get_download(did)
|
||||
assert row.model_id == "loras/guessed.safetensors"
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_manager_rejects_non_model_resolution(model_root):
|
||||
"""A URL that resolves to a non-model file is rejected, not downloaded."""
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError
|
||||
|
||||
runner, port = await _serve(
|
||||
_content_disposition_handler(b"not a model", "installer.zip")
|
||||
)
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/api/download/models/999"
|
||||
with pytest.raises(DownloadError) as ei:
|
||||
await DOWNLOAD_MANAGER.enqueue(url, "loras/whatever")
|
||||
assert ei.value.code == "URL_NOT_ALLOWED"
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
@ -1,81 +0,0 @@
|
||||
"""Unit tests for the segment planner and structural safetensors validation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import struct
|
||||
|
||||
import pytest
|
||||
|
||||
from app.model_downloader.engine.planner import (
|
||||
effective_segment_count,
|
||||
plan_segments,
|
||||
)
|
||||
from app.model_downloader.verify import structural
|
||||
|
||||
|
||||
# ----- planner -----
|
||||
|
||||
|
||||
def test_plan_segments_covers_full_range_contiguously():
|
||||
total = 1000
|
||||
plans = plan_segments(total, 4)
|
||||
assert len(plans) == 4
|
||||
assert plans[0].start == 0
|
||||
assert plans[-1].end == total - 1
|
||||
# contiguous, no gaps/overlaps
|
||||
for a, b in zip(plans, plans[1:]):
|
||||
assert b.start == a.end + 1
|
||||
assert sum(p.length for p in plans) == total
|
||||
|
||||
|
||||
def test_effective_segment_count_falls_back_to_single():
|
||||
# No range support -> single
|
||||
assert effective_segment_count(10_000_000, False, 8) == 1
|
||||
# Unknown size -> single
|
||||
assert effective_segment_count(None, True, 8) == 1
|
||||
# Tiny file -> fewer segments than configured
|
||||
assert effective_segment_count(1024, True, 8) == 1
|
||||
# Large file with range support -> configured count
|
||||
assert effective_segment_count(1_000_000_000, True, 8) == 8
|
||||
|
||||
|
||||
# ----- structural -----
|
||||
|
||||
|
||||
def _make_safetensors(tensor_data_len: int, *, corrupt_size: bool = False) -> bytes:
|
||||
header = {"t": {"dtype": "F32", "shape": [tensor_data_len], "data_offsets": [0, tensor_data_len]}}
|
||||
header_bytes = json.dumps(header).encode("utf-8")
|
||||
body = b"\x00" * tensor_data_len
|
||||
if corrupt_size:
|
||||
body = body[:-1] # truncate one byte
|
||||
return struct.pack("<Q", len(header_bytes)) + header_bytes + body
|
||||
|
||||
|
||||
def test_structural_valid_safetensors(tmp_path):
|
||||
p = tmp_path / "ok.safetensors"
|
||||
p.write_bytes(_make_safetensors(256))
|
||||
structural.validate(str(p)) # no raise
|
||||
|
||||
|
||||
def test_structural_detects_truncation(tmp_path):
|
||||
p = tmp_path / "bad.safetensors"
|
||||
p.write_bytes(_make_safetensors(256, corrupt_size=True))
|
||||
with pytest.raises(structural.StructuralError):
|
||||
structural.validate(str(p))
|
||||
|
||||
|
||||
def test_structural_skips_unknown_extension(tmp_path):
|
||||
p = tmp_path / "weights.bin"
|
||||
p.write_bytes(b"anything")
|
||||
structural.validate(str(p)) # no structural check, no raise
|
||||
|
||||
|
||||
def test_structural_detects_truncation_via_name_hint(tmp_path):
|
||||
# The downloader validates the opaque temp file (a ``.part`` path) but keys
|
||||
# the format check off the final destination name via ``name_hint``, so
|
||||
# truncation must still be detected instead of silently skipped.
|
||||
p = tmp_path / "bad.comfy-download.part"
|
||||
p.write_bytes(_make_safetensors(256, corrupt_size=True))
|
||||
with pytest.raises(structural.StructuralError):
|
||||
structural.validate(str(p), name_hint="model.safetensors")
|
||||
@ -1,231 +0,0 @@
|
||||
"""Unit tests for the security layer: allowlist, SSRF checks, path safety."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.model_downloader.security import allowlist, paths
|
||||
from app.model_downloader.security.ssrf import (
|
||||
SSRFError,
|
||||
check_redirect_hop,
|
||||
is_blocked_ip,
|
||||
)
|
||||
|
||||
|
||||
# ----- allowlist -----
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"url,allowed",
|
||||
[
|
||||
("https://huggingface.co/org/repo/resolve/main/model.safetensors", True),
|
||||
("https://civitai.com/api/download/x/model.safetensors", True),
|
||||
("http://localhost/model.safetensors", True),
|
||||
# off-list host
|
||||
("https://evil.example.com/model.safetensors", False),
|
||||
# http to a non-loopback allowlisted host is not permitted (https only)
|
||||
("http://huggingface.co/org/repo/resolve/main/model.safetensors", False),
|
||||
# bad extension on an allowed host
|
||||
("https://huggingface.co/org/repo/resolve/main/config.json", False),
|
||||
# userinfo trick: real host is the metadata IP, not 127.0.0.1
|
||||
("http://127.0.0.1@169.254.169.254/x.safetensors", False),
|
||||
],
|
||||
)
|
||||
def test_is_url_allowed(url, allowed):
|
||||
assert allowlist.is_url_allowed(url) is allowed
|
||||
|
||||
|
||||
def test_allow_any_extension_relaxes_extension_only():
|
||||
url = "https://huggingface.co/org/repo/resolve/main/weights.bin"
|
||||
assert allowlist.is_url_allowed(url) is True # .bin is in the known set
|
||||
odd = "https://huggingface.co/org/repo/resolve/main/weights.zip"
|
||||
assert allowlist.is_url_allowed(odd) is False
|
||||
assert allowlist.is_url_allowed(odd, allow_any_extension=True) is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"url,downloadable",
|
||||
[
|
||||
# known model extension in the path -> allowed
|
||||
("https://civitai.com/x/model.safetensors", True),
|
||||
# no extension in the path (Civitai download API) -> allowed, resolved later
|
||||
("https://civitai.com/api/download/models/3031464?fileId=2910346", True),
|
||||
("https://civitai.com/api/download/models/3031464", True),
|
||||
# explicit non-model extension -> rejected even on an allowed host
|
||||
("https://civitai.com/api/download/models/thing.zip", False),
|
||||
("https://huggingface.co/org/repo/resolve/main/config.json", False),
|
||||
# off-list host is never downloadable
|
||||
("https://evil.example.com/api/download/models/1", False),
|
||||
# http to a non-loopback allowlisted host is not permitted
|
||||
("http://civitai.com/api/download/models/1", False),
|
||||
],
|
||||
)
|
||||
def test_is_url_downloadable(url, downloadable):
|
||||
assert allowlist.is_url_downloadable(url) is downloadable
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"name,ext",
|
||||
[
|
||||
("model.safetensors", ".safetensors"),
|
||||
("model.SAFETENSORS", ".safetensors"),
|
||||
("archive.tar.gz", ".gz"),
|
||||
("noext", ""),
|
||||
(".safetensors", ""), # leading-dot dotfile -> no extension
|
||||
("a/b/c/model.ckpt", ".ckpt"),
|
||||
],
|
||||
)
|
||||
def test_filename_extension(name, ext):
|
||||
assert allowlist.filename_extension(name) == ext
|
||||
|
||||
|
||||
# ----- SSRF: blocked IPs -----
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ip,blocked",
|
||||
[
|
||||
("169.254.169.254", True), # cloud metadata / link-local
|
||||
("127.0.0.1", True),
|
||||
("10.0.0.5", True),
|
||||
("192.168.1.1", True),
|
||||
("172.16.0.1", True),
|
||||
("::1", True),
|
||||
("0.0.0.0", True),
|
||||
# IPv4-mapped IPv6: must see through the mapping even on CPython
|
||||
# versions predating the gh-113171 is_* property fix.
|
||||
("::ffff:169.254.169.254", True), # mapped cloud metadata
|
||||
("::ffff:127.0.0.1", True), # mapped loopback
|
||||
("::ffff:10.0.0.1", True), # mapped RFC1918
|
||||
("::ffff:8.8.8.8", False), # mapped public address stays allowed
|
||||
("8.8.8.8", False),
|
||||
("1.1.1.1", False),
|
||||
("not-an-ip", True), # unparseable -> refuse
|
||||
],
|
||||
)
|
||||
def test_is_blocked_ip(ip, blocked):
|
||||
assert is_blocked_ip(ip) is blocked
|
||||
|
||||
|
||||
# ----- SSRF: redirect hop validation -----
|
||||
|
||||
|
||||
def test_check_redirect_hop_rejects_bad_scheme_and_userinfo():
|
||||
with pytest.raises(SSRFError):
|
||||
check_redirect_hop("ftp://huggingface.co/x.safetensors")
|
||||
with pytest.raises(SSRFError):
|
||||
check_redirect_hop("https://user:pass@cdn.example.com/x")
|
||||
# A CDN host that is NOT on the allowlist is allowed as a redirect target
|
||||
# (private-IP protection is the resolver's job; credential leak is prevented
|
||||
# by exact host matching).
|
||||
assert check_redirect_hop("https://cdn-lfs.huggingface.co/abc") is not None
|
||||
|
||||
|
||||
def test_check_redirect_hop_http_only_for_loopback():
|
||||
# Plain http to an external host is rejected (no plaintext downgrade).
|
||||
with pytest.raises(SSRFError):
|
||||
check_redirect_hop("http://cdn-lfs.huggingface.co/abc")
|
||||
# http is honored for loopback only on the initial user-supplied URL (the
|
||||
# "download a local model" feature).
|
||||
assert (
|
||||
check_redirect_hop("http://localhost/x.safetensors", is_initial_url=True)
|
||||
is not None
|
||||
)
|
||||
assert (
|
||||
check_redirect_hop("http://127.0.0.1/x.safetensors", is_initial_url=True)
|
||||
is not None
|
||||
)
|
||||
|
||||
|
||||
def test_check_redirect_hop_blocks_loopback_and_ip_literals_on_redirect():
|
||||
# A redirect (is_initial_url=False, the default) must never reach loopback,
|
||||
# whether by hostname or by IP literal, nor any other internal IP literal.
|
||||
for target in (
|
||||
"http://localhost/x.safetensors",
|
||||
"http://127.0.0.1/x.safetensors",
|
||||
"https://[::1]/x.safetensors",
|
||||
"https://169.254.169.254/x.safetensors", # cloud metadata
|
||||
"https://10.0.0.5/x.safetensors", # RFC1918
|
||||
):
|
||||
with pytest.raises(SSRFError):
|
||||
check_redirect_hop(target)
|
||||
# Off-allowlist public CDN hosts (hostnames) remain valid redirect targets;
|
||||
# their resolved IPs are screened by the connector's resolver.
|
||||
assert check_redirect_hop("https://cdn-lfs.huggingface.co/abc") is not None
|
||||
|
||||
|
||||
# ----- path safety -----
|
||||
|
||||
|
||||
def test_parse_model_id_valid(model_root):
|
||||
directory, filename = paths.parse_model_id("loras/my_lora.safetensors")
|
||||
assert directory == "loras"
|
||||
assert filename == "my_lora.safetensors"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_id",
|
||||
[
|
||||
"loras/../etc/passwd.safetensors", # traversal
|
||||
"loras/sub/dir.safetensors", # nested
|
||||
"unknownfolder/x.safetensors", # unknown folder
|
||||
"loras/model.txt", # bad extension
|
||||
"noslash.safetensors", # missing directory
|
||||
"loras/", # empty filename
|
||||
],
|
||||
)
|
||||
def test_parse_model_id_rejects(model_root, model_id):
|
||||
with pytest.raises(paths.InvalidModelId):
|
||||
paths.parse_model_id(model_id)
|
||||
|
||||
|
||||
def test_resolve_destination_stays_in_root(model_root):
|
||||
final_path, temp_path = paths.resolve_destination("loras/x.safetensors")
|
||||
assert final_path.startswith(model_root)
|
||||
assert temp_path.startswith(model_root)
|
||||
assert temp_path != final_path
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_id,ext,expected",
|
||||
[
|
||||
# no extension -> append the resolved one
|
||||
("loras/my_civitai_model", ".safetensors", "loras/my_civitai_model.safetensors"),
|
||||
# different known extension -> replace it
|
||||
("loras/mymodel.ckpt", ".safetensors", "loras/mymodel.safetensors"),
|
||||
# same extension -> unchanged
|
||||
("loras/mymodel.safetensors", ".safetensors", "loras/mymodel.safetensors"),
|
||||
# non-model suffix is treated as a stem, extension appended
|
||||
("loras/my.model.v2", ".safetensors", "loras/my.model.v2.safetensors"),
|
||||
# malformed (no slash) is returned untouched for parse_model_id to reject
|
||||
("noslash", ".safetensors", "noslash"),
|
||||
],
|
||||
)
|
||||
def test_apply_extension(model_id, ext, expected):
|
||||
assert paths.apply_extension(model_id, ext) == expected
|
||||
|
||||
|
||||
# ----- Content-Disposition filename parsing -----
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"header,expected",
|
||||
[
|
||||
('attachment; filename="model.safetensors"', "model.safetensors"),
|
||||
("attachment; filename=model.ckpt", "model.ckpt"),
|
||||
# RFC 5987 form is preferred and percent-decoded
|
||||
(
|
||||
"attachment; filename=\"fallback.bin\"; filename*=UTF-8''my%20model.safetensors",
|
||||
"my model.safetensors",
|
||||
),
|
||||
# directory components in a hostile header are stripped to the basename
|
||||
('attachment; filename="../../etc/passwd"', "passwd"),
|
||||
('attachment; filename="a\\\\b\\\\model.pt"', "model.pt"),
|
||||
("inline", None),
|
||||
(None, None),
|
||||
],
|
||||
)
|
||||
def test_filename_from_content_disposition(header, expected):
|
||||
from app.model_downloader.net.http import filename_from_content_disposition
|
||||
|
||||
assert filename_from_content_disposition(header) == expected
|
||||
0
tests-unit/security_test/__init__.py
Normal file
0
tests-unit/security_test/__init__.py
Normal file
192
tests-unit/security_test/test_ghsa_779p_02_preview_traversal.py
Normal file
192
tests-unit/security_test/test_ghsa_779p_02_preview_traversal.py
Normal file
@ -0,0 +1,192 @@
|
||||
"""CI unit tests for FIX #2 of GHSA-779p-m5rp-r4h4.
|
||||
|
||||
Path traversal / hardening in app/model_manager.py get_model_preview
|
||||
(route /experiment/models/preview/{folder}/{path_index}/{filename:.*}).
|
||||
|
||||
Reference: https://github.com/Comfy-Org/ComfyUI/security/advisories/GHSA-779p-m5rp-r4h4
|
||||
"""
|
||||
import pytest
|
||||
import yarl
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from aiohttp import web
|
||||
from unittest.mock import patch
|
||||
from app.model_manager import ModelFileManager
|
||||
|
||||
pytestmark = (
|
||||
pytest.mark.asyncio
|
||||
) # This applies the asyncio mark to all test functions in the module
|
||||
|
||||
@pytest.fixture
|
||||
def model_manager():
|
||||
return ModelFileManager()
|
||||
|
||||
@pytest.fixture
|
||||
def app(model_manager):
|
||||
app = web.Application()
|
||||
routes = web.RouteTableDef()
|
||||
model_manager.add_routes(routes)
|
||||
app.add_routes(routes)
|
||||
return app
|
||||
|
||||
|
||||
async def test_legit_preview_returns_200(aiohttp_client, app, tmp_path):
|
||||
"""Sanity: a real preview PNG inside the model folder is served as webp 200."""
|
||||
img = Image.new('RGB', (16, 16), color=(255, 0, 128))
|
||||
img.save(tmp_path / "test_model.png", format='PNG')
|
||||
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(tmp_path)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get('/experiment/models/preview/test_folder/0/test_model.png')
|
||||
|
||||
assert response.status == 200
|
||||
assert response.content_type == 'image/webp'
|
||||
|
||||
img_bytes = BytesIO(await response.read())
|
||||
served = Image.open(img_bytes)
|
||||
assert served.format
|
||||
assert served.format.lower() == 'webp'
|
||||
served.close()
|
||||
|
||||
|
||||
async def test_non_integer_path_index_returns_400(aiohttp_client, app, tmp_path):
|
||||
"""A non-integer path_index segment must be rejected with 400."""
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(tmp_path)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get('/experiment/models/preview/test_folder/abc/test_model.png')
|
||||
|
||||
assert response.status == 400
|
||||
|
||||
|
||||
async def test_out_of_range_path_index_returns_404(aiohttp_client, app, tmp_path):
|
||||
"""A path_index beyond the configured folder list must return 404."""
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(tmp_path)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get('/experiment/models/preview/test_folder/99/test_model.png')
|
||||
|
||||
assert response.status == 404
|
||||
|
||||
|
||||
async def test_empty_filename_returns_400(aiohttp_client, app, tmp_path):
|
||||
"""The "{filename:.*}" capture also matches the empty string (trailing
|
||||
slash). It would resolve to the folder itself and must be rejected with 400."""
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(tmp_path)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get('/experiment/models/preview/test_folder/0/')
|
||||
|
||||
assert response.status == 400
|
||||
|
||||
|
||||
async def test_path_traversal_in_filename_returns_403(aiohttp_client, app, tmp_path):
|
||||
"""Path traversal in {filename} must be rejected with 403 and must NOT read
|
||||
a file outside the configured model directory.
|
||||
|
||||
GOTCHA: aiohttp/yarl collapses literal ``../`` dot-segments out of the URL
|
||||
path before it reaches the handler, which would make this test vacuously
|
||||
pass (the request would hit a different/non-existent route). We percent-encode
|
||||
the dots and slashes (``%2e%2e%2f``) and send the URL with
|
||||
``yarl.URL(..., encoded=True)`` so the bytes survive client-side normalization
|
||||
untouched; aiohttp's router then percent-decodes them into ``match_info``,
|
||||
delivering the literal ``../`` traversal to the handler's ``{filename:.*}``
|
||||
capture.
|
||||
|
||||
Without the fix the handler computes
|
||||
``os.path.normpath(os.path.join(folder, "../../../../etc/hosts"))``, which
|
||||
escapes ``tmp_path`` and would be passed straight to get_model_previews ->
|
||||
Image.open, serving bytes from outside the model dir (200/served bytes). The
|
||||
is_within_directory() containment check is the load-bearing fix that turns
|
||||
that escape into a 403.
|
||||
"""
|
||||
# Sanity-anchor: a legit preview exists inside tmp_path, so a 200 path is
|
||||
# genuinely reachable — proving the 403 below is the containment check
|
||||
# firing, not an unrelated 404.
|
||||
img = Image.new('RGB', (16, 16), color=(255, 0, 128))
|
||||
img.save(tmp_path / "test_model.png", format='PNG')
|
||||
|
||||
# Percent-encoded "../../../../etc/hosts" so yarl does not collapse the
|
||||
# dot-segments before the request leaves the client.
|
||||
encoded_traversal = '%2e%2e%2f' * 4 + 'etc%2fhosts'
|
||||
raw_path = '/experiment/models/preview/test_folder/0/' + encoded_traversal
|
||||
url = yarl.URL(raw_path, encoded=True)
|
||||
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(tmp_path)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get(url)
|
||||
|
||||
# Confirm the traversal actually reached the handler intact: a 200 here
|
||||
# would mean either normalization stripped the ``../`` (vacuous pass) or
|
||||
# the containment check failed open and served outside-dir bytes.
|
||||
assert response.status == 403, (
|
||||
f"expected 403 from is_within_directory() containment check, "
|
||||
f"got {response.status}; traversal may have been normalized away "
|
||||
f"or the fix failed open"
|
||||
)
|
||||
body = await response.read()
|
||||
assert body == b"", "403 response must not carry any file bytes"
|
||||
|
||||
|
||||
async def test_symlink_companion_preview_returns_403(aiohttp_client, app, tmp_path):
|
||||
"""A companion preview file is selected by a glob inside get_model_previews
|
||||
and then opened. If that companion is a symlink whose path is in-dir but
|
||||
whose target escapes the model folder, it must be rejected with 403 — not
|
||||
served. The requested path itself stays in-dir (so the first containment
|
||||
check passes); the load-bearing fix is the SECOND is_within_directory check
|
||||
on the file actually opened.
|
||||
"""
|
||||
model_dir = tmp_path / "models"
|
||||
model_dir.mkdir()
|
||||
secret_dir = tmp_path / "secret"
|
||||
secret_dir.mkdir()
|
||||
# A real image OUTSIDE the model dir — valid, so without the fix Image.open
|
||||
# would succeed and its bytes would be served (200).
|
||||
secret = secret_dir / "secret.png"
|
||||
Image.new('RGB', (8, 8), color=(0, 0, 0)).save(secret, format='PNG')
|
||||
# Companion preview, in-dir by name but a symlink escaping the model dir.
|
||||
# (No real model file is needed — get_model_previews globs companions by
|
||||
# basename, and omitting a .safetensors avoids the metadata-header read.)
|
||||
companion = model_dir / "model.preview.png"
|
||||
try:
|
||||
companion.symlink_to(secret)
|
||||
except (OSError, NotImplementedError):
|
||||
pytest.skip("symlinks not supported on this platform/filesystem")
|
||||
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(model_dir)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get('/experiment/models/preview/test_folder/0/model.safetensors')
|
||||
|
||||
assert response.status == 403, (
|
||||
f"expected 403 — the globbed companion preview is a symlink resolving "
|
||||
f"outside the model dir and must not be served; got {response.status}"
|
||||
)
|
||||
assert await response.read() == b""
|
||||
|
||||
|
||||
async def test_null_byte_in_filename_no_500(aiohttp_client, app, tmp_path):
|
||||
"""A NUL byte in the filename must yield a clean client rejection, not a 500
|
||||
from an uncaught ValueError in is_within_directory's realpath() call."""
|
||||
raw_path = '/experiment/models/preview/test_folder/0/' + 'a%00b'
|
||||
url = yarl.URL(raw_path, encoded=True)
|
||||
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(tmp_path)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get(url)
|
||||
|
||||
assert response.status != 500, (
|
||||
f"NUL byte produced a 500 (uncaught ValueError); expected a clean "
|
||||
f"4xx rejection, got {response.status}"
|
||||
)
|
||||
assert 400 <= response.status < 500
|
||||
@ -0,0 +1,165 @@
|
||||
"""Security tests for GHSA-779p-m5rp-r4h4 — FIX #3.
|
||||
|
||||
Path traversal in folder_paths.get_annotated_filepath / exists_annotated_filepath,
|
||||
plus the shared is_within_directory() containment helper.
|
||||
|
||||
These are pure-function tests (no running server). The input/output/temp
|
||||
directories are pointed at tmp_path via the folder_paths setters, so a crafted
|
||||
name containing `../`, an absolute path, or a symlink that escapes the base
|
||||
directory must be rejected.
|
||||
|
||||
Reference: https://github.com/Comfy-Org/ComfyUI/security/advisories/GHSA-779p-m5rp-r4h4
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
import folder_paths
|
||||
from comfy.options import enable_args_parsing
|
||||
enable_args_parsing()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sandbox(tmp_path):
|
||||
"""Point folder_paths' input/output/temp dirs at a real temp sandbox.
|
||||
|
||||
Yields the realpath'd base, input, output and temp directories. The original
|
||||
directory values are restored afterward so tests stay isolated.
|
||||
"""
|
||||
base = os.path.realpath(str(tmp_path))
|
||||
input_dir = os.path.join(base, "input")
|
||||
output_dir = os.path.join(base, "output")
|
||||
temp_dir = os.path.join(base, "temp")
|
||||
for d in (input_dir, output_dir, temp_dir):
|
||||
os.makedirs(d, exist_ok=True)
|
||||
|
||||
orig_input = folder_paths.get_input_directory()
|
||||
orig_output = folder_paths.get_output_directory()
|
||||
orig_temp = folder_paths.get_temp_directory()
|
||||
|
||||
folder_paths.set_input_directory(input_dir)
|
||||
folder_paths.set_output_directory(output_dir)
|
||||
folder_paths.set_temp_directory(temp_dir)
|
||||
|
||||
yield {
|
||||
"base": base,
|
||||
"input": input_dir,
|
||||
"output": output_dir,
|
||||
"temp": temp_dir,
|
||||
}
|
||||
|
||||
folder_paths.set_input_directory(orig_input)
|
||||
folder_paths.set_output_directory(orig_output)
|
||||
folder_paths.set_temp_directory(orig_temp)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_within_directory() — the shared containment helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_is_within_directory_legit_child(sandbox):
|
||||
base = sandbox["input"]
|
||||
child = os.path.join(base, "sub", "image.png")
|
||||
assert folder_paths.is_within_directory(base, child) is True
|
||||
|
||||
|
||||
def test_is_within_directory_dotdot_escape(sandbox):
|
||||
base = sandbox["input"]
|
||||
escape = os.path.join(base, "..", "..", "etc", "passwd")
|
||||
assert folder_paths.is_within_directory(base, escape) is False
|
||||
|
||||
|
||||
def test_is_within_directory_symlink_escape(sandbox):
|
||||
"""A symlink created INSIDE base that points OUTSIDE base must not pass.
|
||||
|
||||
This is the key new hardening: is_within_directory realpath()s both operands,
|
||||
so a symlink planted in the base directory can't be used to read files
|
||||
elsewhere. We create a real on-disk symlink and a real secret target to
|
||||
verify the check actually resolves the link.
|
||||
"""
|
||||
base = sandbox["input"]
|
||||
|
||||
# A directory living outside the base, holding a secret file.
|
||||
outside = os.path.join(sandbox["base"], "outside_secret_dir")
|
||||
os.makedirs(outside, exist_ok=True)
|
||||
secret = os.path.join(outside, "secret.txt")
|
||||
with open(secret, "w") as f:
|
||||
f.write("top secret")
|
||||
|
||||
# Plant a symlink inside base that points at the outside directory.
|
||||
# symlink creation can require elevated privileges / Developer Mode on
|
||||
# Windows, so skip cleanly where it isn't available (same guard as the
|
||||
# sibling test in test_ghsa_779p_02_preview_traversal.py).
|
||||
link = os.path.join(base, "escape_link")
|
||||
try:
|
||||
os.symlink(outside, link)
|
||||
except (OSError, NotImplementedError):
|
||||
pytest.skip("symlinks not supported on this platform/filesystem")
|
||||
|
||||
# Accessing the secret "through" the in-base symlink must be rejected.
|
||||
target_via_link = os.path.join(link, "secret.txt")
|
||||
assert folder_paths.is_within_directory(base, target_via_link) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_annotated_filepath()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_get_annotated_filepath_legit_name(sandbox):
|
||||
result = folder_paths.get_annotated_filepath("image.png")
|
||||
assert result == os.path.join(sandbox["input"], "image.png")
|
||||
assert folder_paths.is_within_directory(sandbox["input"], result)
|
||||
|
||||
|
||||
def test_get_annotated_filepath_input_annotation(sandbox):
|
||||
result = folder_paths.get_annotated_filepath("image.png [input]")
|
||||
assert result == os.path.join(sandbox["input"], "image.png")
|
||||
|
||||
|
||||
def test_get_annotated_filepath_output_annotation(sandbox):
|
||||
result = folder_paths.get_annotated_filepath("image.png [output]")
|
||||
assert result == os.path.join(sandbox["output"], "image.png")
|
||||
|
||||
|
||||
def test_get_annotated_filepath_temp_annotation(sandbox):
|
||||
result = folder_paths.get_annotated_filepath("image.png [temp]")
|
||||
assert result == os.path.join(sandbox["temp"], "image.png")
|
||||
|
||||
|
||||
def test_get_annotated_filepath_dotdot_raises(sandbox):
|
||||
with pytest.raises(ValueError):
|
||||
folder_paths.get_annotated_filepath("../etc/passwd")
|
||||
|
||||
|
||||
def test_get_annotated_filepath_dotdot_with_annotation_raises(sandbox):
|
||||
with pytest.raises(ValueError):
|
||||
folder_paths.get_annotated_filepath("../../etc/passwd [output]")
|
||||
|
||||
|
||||
def test_get_annotated_filepath_absolute_escape_raises(sandbox):
|
||||
with pytest.raises(ValueError):
|
||||
folder_paths.get_annotated_filepath("/etc/passwd")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# exists_annotated_filepath()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_exists_annotated_filepath_existing_legit_file(sandbox):
|
||||
real = os.path.join(sandbox["input"], "real.png")
|
||||
with open(real, "w") as f:
|
||||
f.write("data")
|
||||
assert folder_paths.exists_annotated_filepath("real.png") is True
|
||||
|
||||
|
||||
def test_exists_annotated_filepath_traversal_returns_false(sandbox):
|
||||
"""A traversal name must return False without raising and without probing
|
||||
outside the base directory (must never reach os.path.exists for the escape).
|
||||
"""
|
||||
# /etc/passwd exists on POSIX; the function must still report False because
|
||||
# the resolved path escapes the input directory.
|
||||
assert folder_paths.exists_annotated_filepath("../../../../../../etc/passwd") is False
|
||||
|
||||
|
||||
def test_exists_annotated_filepath_absolute_returns_false(sandbox):
|
||||
assert folder_paths.exists_annotated_filepath("/etc/passwd") is False
|
||||
147
tests-unit/security_test/test_ghsa_779p_04_userdata_xss.py
Normal file
147
tests-unit/security_test/test_ghsa_779p_04_userdata_xss.py
Normal file
@ -0,0 +1,147 @@
|
||||
"""
|
||||
CI unit tests for FIX #4 of GHSA-779p-m5rp-r4h4.
|
||||
|
||||
Stored-XSS hardening on GET /userdata/{file} in app/user_manager.py.
|
||||
|
||||
User data files are arbitrary user-supplied content and must never render
|
||||
inline in the app origin. The getuserdata handler:
|
||||
- forces Content-Type to application/octet-stream for any type in
|
||||
folder_paths.DANGEROUS_CONTENT_TYPES (text/html, image/svg+xml,
|
||||
text/javascript, ...),
|
||||
- sets X-Content-Type-Options: nosniff,
|
||||
- sets Content-Disposition: attachment.
|
||||
|
||||
These tests pre-create files in tmp_path and GET them back, asserting the
|
||||
secure response headers. They mirror the aiohttp_client pattern in
|
||||
tests-unit/prompt_server_test/user_manager_test.py.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import os
|
||||
from aiohttp import web
|
||||
from app.user_manager import UserManager
|
||||
|
||||
pytestmark = (
|
||||
pytest.mark.asyncio
|
||||
) # This applies the asyncio mark to all test functions in the module
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_manager(tmp_path):
|
||||
um = UserManager()
|
||||
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
|
||||
tmp_path, file
|
||||
) if file else tmp_path
|
||||
return um
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(user_manager):
|
||||
app = web.Application()
|
||||
routes = web.RouteTableDef()
|
||||
user_manager.add_routes(routes)
|
||||
app.add_routes(routes)
|
||||
return app
|
||||
|
||||
|
||||
async def test_html_served_as_octet_stream(aiohttp_client, app, tmp_path):
|
||||
(tmp_path / "evil.html").write_text(
|
||||
"<script>console.log('xss-marker-ghsa-779p')</script>"
|
||||
)
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata/evil.html")
|
||||
|
||||
assert resp.status == 200
|
||||
ct = resp.headers.get("Content-Type", "")
|
||||
# The load-bearing assertion: a .html file must NOT be served as text/html.
|
||||
assert "text/html" not in ct.lower(), (
|
||||
f"Content-Type {ct!r} would let a browser render/execute the file (stored XSS)."
|
||||
)
|
||||
assert ct == "application/octet-stream"
|
||||
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
|
||||
assert "attachment" in resp.headers.get("Content-Disposition", "")
|
||||
|
||||
|
||||
async def test_svg_served_as_octet_stream(aiohttp_client, app, tmp_path):
|
||||
(tmp_path / "evil.svg").write_text(
|
||||
'<?xml version="1.0"?>'
|
||||
'<svg xmlns="http://www.w3.org/2000/svg">'
|
||||
'<script>console.log("xss-marker-ghsa-779p")</script>'
|
||||
"</svg>"
|
||||
)
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata/evil.svg")
|
||||
|
||||
assert resp.status == 200
|
||||
ct = resp.headers.get("Content-Type", "")
|
||||
# SVG can carry inline <script>; it must not be served as image/svg+xml.
|
||||
assert "svg" not in ct.lower(), (
|
||||
f"Content-Type {ct!r} would let a browser render the SVG and execute embedded scripts."
|
||||
)
|
||||
assert ct == "application/octet-stream"
|
||||
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
|
||||
assert "attachment" in resp.headers.get("Content-Disposition", "")
|
||||
|
||||
|
||||
async def test_js_served_as_octet_stream(aiohttp_client, app, tmp_path):
|
||||
(tmp_path / "evil.js").write_text("alert('xss-marker-ghsa-779p')")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata/evil.js")
|
||||
|
||||
assert resp.status == 200
|
||||
ct = resp.headers.get("Content-Type", "").lower()
|
||||
# Must not be served as any executable JavaScript content type.
|
||||
assert "javascript" not in ct, (
|
||||
f"Content-Type {ct!r} is an executable JS type."
|
||||
)
|
||||
assert "ecmascript" not in ct, (
|
||||
f"Content-Type {ct!r} is an executable JS type."
|
||||
)
|
||||
assert ct == "application/octet-stream"
|
||||
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
|
||||
assert "attachment" in resp.headers.get("Content-Disposition", "")
|
||||
|
||||
|
||||
async def test_xml_dialect_served_as_octet_stream(aiohttp_client, app, tmp_path):
|
||||
"""An XML dialect outside the original blocklist (.xslt -> application/xslt+xml)
|
||||
must still be forced to download. This pins the normalised *+xml family rule
|
||||
in folder_paths.is_dangerous_content_type(); a plain set-membership test would
|
||||
have served this inline."""
|
||||
(tmp_path / "evil.xslt").write_text(
|
||||
'<?xml version="1.0"?>'
|
||||
'<xsl:stylesheet version="1.0" '
|
||||
'xmlns:xsl="http://www.w3.org/1999/XSL/Transform">'
|
||||
"<!-- xss-marker-ghsa-779p -->"
|
||||
"</xsl:stylesheet>"
|
||||
)
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata/evil.xslt")
|
||||
|
||||
assert resp.status == 200
|
||||
ct = resp.headers.get("Content-Type", "")
|
||||
assert ct == "application/octet-stream", (
|
||||
f"Content-Type {ct!r}: an *+xml dialect must be forced to octet-stream "
|
||||
f"(it can carry inline script via stylesheet/entity tricks)."
|
||||
)
|
||||
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
|
||||
assert "attachment" in resp.headers.get("Content-Disposition", "")
|
||||
|
||||
|
||||
async def test_benign_txt_still_served(aiohttp_client, app, tmp_path):
|
||||
(tmp_path / "note.txt").write_text("just a harmless note")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata/note.txt")
|
||||
|
||||
assert resp.status == 200
|
||||
assert await resp.text() == "just a harmless note"
|
||||
ct = resp.headers.get("Content-Type", "")
|
||||
# text/plain is not in the dangerous set, so it is acceptable here. The
|
||||
# defence-in-depth headers must still be present regardless.
|
||||
assert "text/plain" in ct.lower()
|
||||
assert resp.headers.get("X-Content-Type-Options") == "nosniff"
|
||||
assert "attachment" in resp.headers.get("Content-Disposition", "")
|
||||
@ -0,0 +1,138 @@
|
||||
"""CI unit guard for FIX #5 of GHSA-779p-m5rp-r4h4 — the /view forced-download set.
|
||||
|
||||
Vuln #5 was stored XSS via SVG upload: the /view endpoint's Content-Type
|
||||
blocklist covered text/html, text/javascript, etc. but was missing
|
||||
image/svg+xml, so an uploaded SVG carrying an inline <script> was served as
|
||||
image/svg+xml and executed in the page origin when rendered.
|
||||
|
||||
The /view forced-download decision lives in the view_image closure registered by
|
||||
server.PromptServer.add_routes (server.py ~line 596), which calls
|
||||
`folder_paths.is_dangerous_content_type(content_type)` — a normalising check that
|
||||
strips charset/boundary parameters and casing and folds in the whole */xml and
|
||||
*+xml dialect family — rather than a bypassable raw
|
||||
`content_type in folder_paths.DANGEROUS_CONTENT_TYPES` membership test. On a match
|
||||
it rewrites the response to application/octet-stream with a
|
||||
Content-Disposition: attachment header. server.py cannot be imported in a unit
|
||||
test (importing it spins up the full PromptServer/aiohttp app and its global side
|
||||
effects), so these tests pin the underlying dangerous-content data
|
||||
(folder_paths.DANGEROUS_CONTENT_TYPES) and the normalising is_dangerous_content_type()
|
||||
helper that the closure actually calls.
|
||||
|
||||
The end-to-end /view assertion (upload an SVG, GET /view, confirm the response
|
||||
is not served as image/svg+xml) lives in the live POC at
|
||||
.security/pocs/test_security_ghsa_779p.py::TestViewSvgContentType, which
|
||||
requires a running server. This file is the fast, server-free CI guard on the
|
||||
set contents so the blocklist can't silently regress.
|
||||
"""
|
||||
|
||||
import folder_paths
|
||||
|
||||
|
||||
# Active/renderable content types that must be forced to download. Each of these
|
||||
# can carry an inline <script> (or otherwise execute) in the page origin if a
|
||||
# browser renders it. image/svg+xml is the original missing item that caused
|
||||
# vuln #5.
|
||||
DANGEROUS = [
|
||||
'image/svg+xml',
|
||||
'application/xml',
|
||||
'text/xml',
|
||||
'text/html',
|
||||
'text/html-sandboxed',
|
||||
'application/xhtml+xml',
|
||||
'text/javascript',
|
||||
'application/javascript',
|
||||
'application/x-javascript',
|
||||
'application/ecmascript',
|
||||
'text/css',
|
||||
]
|
||||
|
||||
# Benign image types that browsers display inline and that must keep rendering;
|
||||
# forcing these to download would break legitimate previews.
|
||||
BENIGN_INLINE_IMAGES = [
|
||||
'image/png',
|
||||
'image/jpeg',
|
||||
'image/webp',
|
||||
'image/gif',
|
||||
]
|
||||
|
||||
|
||||
def test_dangerous_content_types_is_a_set():
|
||||
assert isinstance(folder_paths.DANGEROUS_CONTENT_TYPES, set)
|
||||
|
||||
|
||||
def test_svg_is_in_the_blocklist():
|
||||
"""The specific item whose absence caused vuln #5."""
|
||||
assert 'image/svg+xml' in folder_paths.DANGEROUS_CONTENT_TYPES, (
|
||||
"image/svg+xml missing from DANGEROUS_CONTENT_TYPES — this is exactly "
|
||||
"the regression that reopens GHSA-779p-m5rp-r4h4 vuln #5 (stored XSS "
|
||||
"via SVG upload on /view)."
|
||||
)
|
||||
|
||||
|
||||
def test_all_dangerous_types_present():
|
||||
missing = [ct for ct in DANGEROUS if ct not in folder_paths.DANGEROUS_CONTENT_TYPES]
|
||||
assert not missing, (
|
||||
f"DANGEROUS_CONTENT_TYPES is missing required active/renderable types: "
|
||||
f"{missing}. The /view closure only forces a download for content types "
|
||||
f"in this set; anything missing here is served inline and can execute."
|
||||
)
|
||||
|
||||
|
||||
def test_benign_inline_image_types_absent():
|
||||
leaked = [ct for ct in BENIGN_INLINE_IMAGES if ct in folder_paths.DANGEROUS_CONTENT_TYPES]
|
||||
assert not leaked, (
|
||||
f"Benign inline-displayable image types found in DANGEROUS_CONTENT_TYPES: "
|
||||
f"{leaked}. Forcing these to download would break legitimate image "
|
||||
f"previews in /view — they must keep rendering inline."
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_dangerous_content_type() — the normalising check the /view and /userdata
|
||||
# handlers now call instead of a raw `in DANGEROUS_CONTENT_TYPES` membership
|
||||
# test. An exact-string membership test was bypassable with a charset parameter
|
||||
# or odd casing, and missed the wider XML dialect family; these tests pin the
|
||||
# normalisation so that bypass can't reopen.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_function_matches_plain_dangerous_types():
|
||||
for ct in DANGEROUS:
|
||||
assert folder_paths.is_dangerous_content_type(ct) is True, ct
|
||||
|
||||
|
||||
def test_function_strips_parameters_and_casing():
|
||||
"""A charset/boundary parameter or casing must not slip a type past the check.
|
||||
|
||||
This is the bypass surfaced by review: the /view blake3 branch can serve an
|
||||
attacker-controlled, unvalidated asset mime_type like 'text/html; charset=utf-8',
|
||||
which an exact-string set test missed.
|
||||
"""
|
||||
for ct in (
|
||||
'text/html; charset=utf-8',
|
||||
'TEXT/HTML',
|
||||
'Text/HTML; charset=UTF-8',
|
||||
'image/svg+xml; charset=utf-8',
|
||||
' text/html ',
|
||||
):
|
||||
assert folder_paths.is_dangerous_content_type(ct) is True, ct
|
||||
|
||||
|
||||
def test_function_covers_xml_dialect_family():
|
||||
"""Any *+xml / */xml dialect is dangerous without enumerating each one."""
|
||||
for ct in (
|
||||
'application/xslt+xml',
|
||||
'application/rss+xml',
|
||||
'application/atom+xml',
|
||||
'application/rdf+xml',
|
||||
'application/mathml+xml',
|
||||
'message/rfc822',
|
||||
):
|
||||
assert folder_paths.is_dangerous_content_type(ct) is True, ct
|
||||
|
||||
|
||||
def test_function_allows_benign_and_empty():
|
||||
for ct in BENIGN_INLINE_IMAGES + ['application/octet-stream', 'text/plain']:
|
||||
assert folder_paths.is_dangerous_content_type(ct) is False, ct
|
||||
# None / empty (mimetypes.guess_type miss) must not be treated as dangerous.
|
||||
assert folder_paths.is_dangerous_content_type(None) is False
|
||||
assert folder_paths.is_dangerous_content_type('') is False
|
||||
Reference in New Issue
Block a user