mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-02 02:37:08 +08:00
Compare commits
47 Commits
ListInput
...
model_down
| Author | SHA1 | Date | |
|---|---|---|---|
| c98a212589 | |||
| 4c82c708a7 | |||
| fe4d0c9722 | |||
| 64c5853631 | |||
| 8a6e7906f7 | |||
| e4b0a72e83 | |||
| 28b41d4d6d | |||
| 3eb36377a8 | |||
| 27fd68a533 | |||
| 61816d436b | |||
| abc0b728ab | |||
| 1bbd4a57db | |||
| b419fd8399 | |||
| e77983ca28 | |||
| 893ba2ad37 | |||
| 312b282ca8 | |||
| 7690c52a34 | |||
| e326ef3b16 | |||
| 115a4305ea | |||
| 9be31a4b7e | |||
| 7dba134cda | |||
| e0b07014c0 | |||
| 3eade55077 | |||
| c785130223 | |||
| 4ccaaa6f37 | |||
| 58392bf7a6 | |||
| 4ae294d2d5 | |||
| 95b0758a88 | |||
| dba2bfbc02 | |||
| 4bd7cc153e | |||
| 2b708d5af7 | |||
| 9c845eeb9e | |||
| e70f524d5f | |||
| d058ffb761 | |||
| 53ec95b87e | |||
| e02c7a0890 | |||
| 1744026eca | |||
| f660307489 | |||
| c7c18377a3 | |||
| 8fe0243d97 | |||
| ba3f697dbb | |||
| 510ed5c384 | |||
| 7851410511 | |||
| a58473fd9b | |||
| 79c555ce6b | |||
| f19735759e | |||
| a95e461916 |
38
.github/workflows/ci-cursor-review.yml
vendored
Normal file
38
.github/workflows/ci-cursor-review.yml
vendored
Normal file
@ -0,0 +1,38 @@
|
||||
name: CI - Cursor Review
|
||||
|
||||
# Thin caller for the shared reusable cursor-review workflow in
|
||||
# Comfy-Org/github-workflows. The review logic (panel matrix, judge
|
||||
# consolidation, prompts, extract/post/notify scripts) lives there as the
|
||||
# single source of truth, so this repo only carries the repo-specific diff
|
||||
# excludes.
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [labeled, unlabeled]
|
||||
|
||||
concurrency:
|
||||
group: cursor-review-pr-${{ github.event.pull_request.number }}-${{ github.event.label.name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
cursor-review:
|
||||
if: github.event.label.name == 'cursor-review'
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
# SHA-pinned per zizmor `unpinned-uses: hash-pin`. Bump this SHA to pick up
|
||||
# upstream changes; keep `workflows_ref` matching so prompts/scripts load
|
||||
# from the same commit as the workflow definition.
|
||||
uses: Comfy-Org/github-workflows/.github/workflows/cursor-review.yml@047ca48febe3a6647608ed2e0c4331b491cb9d6a # github-workflows#9
|
||||
with:
|
||||
workflows_ref: 047ca48febe3a6647608ed2e0c4331b491cb9d6a
|
||||
diff_excludes: >-
|
||||
:!**/.claude/**
|
||||
:!**/dist/**
|
||||
:!**/vendor/**
|
||||
:!**/*.generated.*
|
||||
:!**/*.min.js
|
||||
:!**/*.min.css
|
||||
secrets:
|
||||
CURSOR_API_KEY: ${{ secrets.CURSOR_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
115
alembic_db/versions/0005_download_manager.py
Normal file
115
alembic_db/versions/0005_download_manager.py
Normal file
@ -0,0 +1,115 @@
|
||||
"""
|
||||
Download manager schema.
|
||||
|
||||
Adds the three tables that back the server-side model download manager
|
||||
: transient job/queue state (``downloads`` + per-segment
|
||||
``download_segments``) and one-API-key-per-host auth (``host_credentials``).
|
||||
|
||||
Revision ID: 0005_download_manager
|
||||
Revises: 0004_drop_tag_type
|
||||
Create Date: 2026-06-27
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "0005_download_manager"
|
||||
down_revision = "0004_drop_tag_type"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"downloads",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("url", sa.Text(), nullable=False),
|
||||
sa.Column("final_url", sa.Text(), nullable=True),
|
||||
sa.Column("model_id", sa.String(length=1024), nullable=False),
|
||||
sa.Column("dest_path", sa.Text(), nullable=False),
|
||||
sa.Column("temp_path", sa.Text(), nullable=False),
|
||||
sa.Column("status", sa.String(length=16), nullable=False),
|
||||
sa.Column("priority", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("total_bytes", sa.BigInteger(), nullable=True),
|
||||
sa.Column("bytes_done", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.Column("etag", sa.String(length=512), nullable=True),
|
||||
sa.Column("last_modified", sa.String(length=128), nullable=True),
|
||||
sa.Column(
|
||||
"accept_ranges", sa.Boolean(), nullable=False, server_default=sa.text("false")
|
||||
),
|
||||
sa.Column("expected_sha256", sa.String(length=64), nullable=True),
|
||||
sa.Column("credential_id", sa.String(length=36), nullable=True),
|
||||
sa.Column(
|
||||
"allow_any_extension",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
sa.Column("attempts", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("error", sa.Text(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
||||
sa.CheckConstraint("bytes_done >= 0", name="ck_downloads_bytes_done_nonneg"),
|
||||
sa.CheckConstraint(
|
||||
"total_bytes IS NULL OR total_bytes >= 0",
|
||||
name="ck_downloads_total_bytes_nonneg",
|
||||
),
|
||||
)
|
||||
op.create_index("ix_downloads_status", "downloads", ["status"])
|
||||
op.create_index("ix_downloads_priority", "downloads", ["priority"])
|
||||
op.create_index("ix_downloads_model_id", "downloads", ["model_id"])
|
||||
|
||||
op.create_table(
|
||||
"download_segments",
|
||||
sa.Column(
|
||||
"download_id",
|
||||
sa.String(length=36),
|
||||
sa.ForeignKey("downloads.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("idx", sa.Integer(), nullable=False),
|
||||
sa.Column("start_offset", sa.BigInteger(), nullable=False),
|
||||
sa.Column("end_offset", sa.BigInteger(), nullable=False),
|
||||
sa.Column("bytes_done", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.PrimaryKeyConstraint("download_id", "idx", name="pk_download_segments"),
|
||||
sa.CheckConstraint("bytes_done >= 0", name="ck_segments_bytes_done_nonneg"),
|
||||
sa.CheckConstraint("end_offset >= start_offset", name="ck_segments_range"),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"host_credentials",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("host", sa.String(length=255), nullable=False),
|
||||
sa.Column(
|
||||
"match_subdomains",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
sa.Column("label", sa.String(length=255), nullable=True),
|
||||
sa.Column(
|
||||
"auth_scheme", sa.String(length=16), nullable=False, server_default="bearer"
|
||||
),
|
||||
sa.Column("header_name", sa.String(length=255), nullable=True),
|
||||
sa.Column("query_param", sa.String(length=255), nullable=True),
|
||||
sa.Column("secret", sa.Text(), nullable=False),
|
||||
sa.Column("secret_last4", sa.String(length=4), nullable=True),
|
||||
sa.Column("enabled", sa.Boolean(), nullable=False, server_default=sa.text("true")),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
||||
)
|
||||
op.create_index(
|
||||
"uq_host_credentials_host", "host_credentials", ["host"], unique=True
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("uq_host_credentials_host", table_name="host_credentials")
|
||||
op.drop_table("host_credentials")
|
||||
|
||||
op.drop_table("download_segments")
|
||||
|
||||
op.drop_index("ix_downloads_model_id", table_name="downloads")
|
||||
op.drop_index("ix_downloads_priority", table_name="downloads")
|
||||
op.drop_index("ix_downloads_status", table_name="downloads")
|
||||
op.drop_table("downloads")
|
||||
@ -4,7 +4,11 @@ import shutil
|
||||
from app.logger import log_startup_warning
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
from filelock import FileLock, Timeout
|
||||
from comfy.cli_args import args
|
||||
# NOTE: import the module (not `from ... import args`) so we always read the
|
||||
# live `args` object. Tests reload `comfy.cli_args`, which replaces the module
|
||||
# global; a bound `args` reference would go stale and point at the default
|
||||
# database URL instead of the one configured for the test.
|
||||
import comfy.cli_args
|
||||
|
||||
_DB_AVAILABLE = False
|
||||
Session = None
|
||||
@ -21,6 +25,7 @@ try:
|
||||
|
||||
from app.database.models import Base
|
||||
import app.assets.database.models # noqa: F401 — register models with Base.metadata
|
||||
import app.model_downloader.database.models # noqa: F401 — register models with Base.metadata
|
||||
|
||||
_DB_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
@ -57,13 +62,13 @@ def get_alembic_config():
|
||||
|
||||
config = Config(config_path)
|
||||
config.set_main_option("script_location", scripts_path)
|
||||
config.set_main_option("sqlalchemy.url", args.database_url)
|
||||
config.set_main_option("sqlalchemy.url", comfy.cli_args.args.database_url)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_db_path():
|
||||
url = args.database_url
|
||||
url = comfy.cli_args.args.database_url
|
||||
if url.startswith("sqlite:///"):
|
||||
return url.split("///")[1]
|
||||
else:
|
||||
@ -97,7 +102,7 @@ def _is_memory_db(db_url):
|
||||
|
||||
|
||||
def init_db():
|
||||
db_url = args.database_url
|
||||
db_url = comfy.cli_args.args.database_url
|
||||
logging.debug(f"Database URL: {db_url}")
|
||||
|
||||
if _is_memory_db(db_url):
|
||||
|
||||
220
app/model_downloader/api/routes.py
Normal file
220
app/model_downloader/api/routes.py
Normal file
@ -0,0 +1,220 @@
|
||||
"""aiohttp routes for the download manager.
|
||||
|
||||
Endpoint surface (all under ``/api/download``), mirroring the response
|
||||
envelope used by ``app/assets/api/routes.py``:
|
||||
|
||||
POST /api/download/enqueue
|
||||
GET /api/download
|
||||
POST /api/download/availability
|
||||
POST /api/download/clear
|
||||
POST /api/download/credentials
|
||||
GET /api/download/credentials
|
||||
GET /api/download/credentials/{id}
|
||||
DELETE /api/download/credentials/{id}
|
||||
GET /api/download/{id}
|
||||
DELETE /api/download/{id}
|
||||
POST /api/download/{id}/pause
|
||||
POST /api/download/{id}/resume
|
||||
POST /api/download/{id}/cancel
|
||||
POST /api/download/{id}/priority
|
||||
|
||||
Note on ordering: the static ``credentials`` routes are registered before the
|
||||
dynamic ``/api/download/{id}`` route so a request to ``.../credentials`` is not
|
||||
captured as ``id == "credentials"``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from aiohttp import web
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from app.model_downloader.api import schemas_in, schemas_out
|
||||
from app.model_downloader.credentials.store import (
|
||||
CREDENTIAL_STORE,
|
||||
CredentialValidationError,
|
||||
)
|
||||
from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
|
||||
|
||||
def register_routes(app: web.Application) -> None:
|
||||
"""Wire the download-manager routes into the running aiohttp app."""
|
||||
app.add_routes(ROUTES)
|
||||
|
||||
|
||||
# ----- envelope helpers (same shape as app/assets/api/routes.py) -----
|
||||
|
||||
|
||||
def _error(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
|
||||
return web.json_response(
|
||||
{"error": {"code": code, "message": message, "details": details or {}}},
|
||||
status=status,
|
||||
)
|
||||
|
||||
|
||||
def _ok(payload, status: int = 200) -> web.Response:
|
||||
return web.json_response(payload, status=status)
|
||||
|
||||
|
||||
async def _parse(request: web.Request, model: type[BaseModel]):
|
||||
try:
|
||||
raw = await request.json()
|
||||
except json.JSONDecodeError:
|
||||
return _error(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
try:
|
||||
return model.model_validate(raw)
|
||||
except ValidationError as ve:
|
||||
return _error(400, "INVALID_BODY", "Validation failed.", {"errors": json.loads(ve.json())})
|
||||
|
||||
|
||||
def _from_download_error(e: DownloadError) -> web.Response:
|
||||
return _error(e.http_status, e.code, e.message)
|
||||
|
||||
|
||||
# ----- downloads: collection + enqueue + availability -----
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/enqueue")
|
||||
async def enqueue(request: web.Request) -> web.Response:
|
||||
parsed = await _parse(request, schemas_in.EnqueueRequest)
|
||||
if isinstance(parsed, web.Response):
|
||||
return parsed
|
||||
try:
|
||||
download_id = await DOWNLOAD_MANAGER.enqueue(
|
||||
parsed.url,
|
||||
parsed.model_id,
|
||||
priority=parsed.priority,
|
||||
expected_sha256=parsed.expected_sha256,
|
||||
allow_any_extension=parsed.allow_any_extension,
|
||||
credential_id=parsed.credential_id,
|
||||
)
|
||||
except DownloadError as e:
|
||||
return _from_download_error(e)
|
||||
return _ok({"download_id": download_id, "accepted": True}, status=202)
|
||||
|
||||
|
||||
@ROUTES.get("/api/download")
|
||||
async def list_downloads(request: web.Request) -> web.Response:
|
||||
return _ok({"downloads": await DOWNLOAD_MANAGER.list()})
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/availability")
|
||||
async def availability(request: web.Request) -> web.Response:
|
||||
parsed = await _parse(request, schemas_in.AvailabilityRequest)
|
||||
if isinstance(parsed, web.Response):
|
||||
return parsed
|
||||
return _ok({"models": await DOWNLOAD_MANAGER.availability(parsed.models)})
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/clear")
|
||||
async def clear(request: web.Request) -> web.Response:
|
||||
deleted = await DOWNLOAD_MANAGER.clear()
|
||||
return _ok({"deleted": deleted})
|
||||
|
||||
|
||||
# ----- credentials (secrets are write-only) — must precede /{id} -----
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/credentials")
|
||||
async def upsert_credential(request: web.Request) -> web.Response:
|
||||
parsed = await _parse(request, schemas_in.CredentialUpsertRequest)
|
||||
if isinstance(parsed, web.Response):
|
||||
return parsed
|
||||
try:
|
||||
view = await CREDENTIAL_STORE.upsert(
|
||||
parsed.host,
|
||||
parsed.secret,
|
||||
auth_scheme=parsed.auth_scheme,
|
||||
header_name=parsed.header_name,
|
||||
query_param=parsed.query_param,
|
||||
label=parsed.label,
|
||||
match_subdomains=parsed.match_subdomains,
|
||||
enabled=parsed.enabled,
|
||||
)
|
||||
except CredentialValidationError as e:
|
||||
return _error(400, "INVALID_CREDENTIAL", str(e))
|
||||
return _ok(schemas_out.credential_to_dict(view), status=201)
|
||||
|
||||
|
||||
@ROUTES.get("/api/download/credentials")
|
||||
async def list_credentials(request: web.Request) -> web.Response:
|
||||
views = await CREDENTIAL_STORE.list()
|
||||
return _ok({"credentials": [schemas_out.credential_to_dict(v) for v in views]})
|
||||
|
||||
|
||||
@ROUTES.get("/api/download/credentials/{id}")
|
||||
async def get_credential(request: web.Request) -> web.Response:
|
||||
view = await CREDENTIAL_STORE.get(request.match_info["id"])
|
||||
if view is None:
|
||||
return _error(404, "NOT_FOUND", "No such credential.")
|
||||
return _ok(schemas_out.credential_to_dict(view))
|
||||
|
||||
|
||||
@ROUTES.delete("/api/download/credentials/{id}")
|
||||
async def delete_credential(request: web.Request) -> web.Response:
|
||||
deleted = await CREDENTIAL_STORE.delete(request.match_info["id"])
|
||||
if not deleted:
|
||||
return _error(404, "NOT_FOUND", "No such credential.")
|
||||
return _ok({"deleted": True})
|
||||
|
||||
|
||||
# ----- single download by id (dynamic; registered last) -----
|
||||
|
||||
|
||||
@ROUTES.get("/api/download/{id}")
|
||||
async def get_download(request: web.Request) -> web.Response:
|
||||
view = await DOWNLOAD_MANAGER.status(request.match_info["id"])
|
||||
if view is None:
|
||||
return _error(404, "NOT_FOUND", "No such download.")
|
||||
return _ok(view)
|
||||
|
||||
|
||||
@ROUTES.delete("/api/download/{id}")
|
||||
async def delete_download(request: web.Request) -> web.Response:
|
||||
try:
|
||||
await DOWNLOAD_MANAGER.delete(request.match_info["id"])
|
||||
except DownloadError as e:
|
||||
return _from_download_error(e)
|
||||
return _ok({"deleted": True})
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/{id}/pause")
|
||||
async def pause(request: web.Request) -> web.Response:
|
||||
try:
|
||||
await DOWNLOAD_MANAGER.pause(request.match_info["id"])
|
||||
except DownloadError as e:
|
||||
return _from_download_error(e)
|
||||
return _ok({"ok": True})
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/{id}/resume")
|
||||
async def resume(request: web.Request) -> web.Response:
|
||||
try:
|
||||
await DOWNLOAD_MANAGER.resume(request.match_info["id"])
|
||||
except DownloadError as e:
|
||||
return _from_download_error(e)
|
||||
return _ok({"ok": True})
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/{id}/cancel")
|
||||
async def cancel(request: web.Request) -> web.Response:
|
||||
try:
|
||||
await DOWNLOAD_MANAGER.cancel(request.match_info["id"])
|
||||
except DownloadError as e:
|
||||
return _from_download_error(e)
|
||||
return _ok({"ok": True})
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/{id}/priority")
|
||||
async def set_priority(request: web.Request) -> web.Response:
|
||||
parsed = await _parse(request, schemas_in.PriorityRequest)
|
||||
if isinstance(parsed, web.Response):
|
||||
return parsed
|
||||
try:
|
||||
await DOWNLOAD_MANAGER.set_priority(request.match_info["id"], parsed.priority)
|
||||
except DownloadError as e:
|
||||
return _from_download_error(e)
|
||||
return _ok({"ok": True})
|
||||
51
app/model_downloader/api/schemas_in.py
Normal file
51
app/model_downloader/api/schemas_in.py
Normal file
@ -0,0 +1,51 @@
|
||||
"""Request schemas for the download manager API.
|
||||
|
||||
Pydantic enforces shape at the boundary; handlers operate only on validated
|
||||
values past that point.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.model_downloader.constants import AUTH_SCHEME_BEARER
|
||||
|
||||
|
||||
class EnqueueRequest(BaseModel):
|
||||
url: str
|
||||
model_id: str
|
||||
priority: int = 0
|
||||
expected_sha256: Optional[str] = None
|
||||
allow_any_extension: bool = False
|
||||
credential_id: Optional[str] = None
|
||||
|
||||
|
||||
class PriorityRequest(BaseModel):
|
||||
priority: int
|
||||
|
||||
|
||||
class AvailabilityRequest(BaseModel):
|
||||
"""``{model_id: url}`` — the URLs declared in the workflow JSON."""
|
||||
|
||||
models: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class CredentialUpsertRequest(BaseModel):
|
||||
host: str
|
||||
secret: str
|
||||
auth_scheme: str = AUTH_SCHEME_BEARER
|
||||
header_name: Optional[str] = None
|
||||
query_param: Optional[str] = None
|
||||
label: Optional[str] = None
|
||||
match_subdomains: bool = False
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
__all__ = [
|
||||
"EnqueueRequest",
|
||||
"PriorityRequest",
|
||||
"AvailabilityRequest",
|
||||
"CredentialUpsertRequest",
|
||||
]
|
||||
26
app/model_downloader/api/schemas_out.py
Normal file
26
app/model_downloader/api/schemas_out.py
Normal file
@ -0,0 +1,26 @@
|
||||
"""Response helpers for the download manager API.
|
||||
|
||||
The download/status read models are plain dicts produced by the manager. This
|
||||
module only needs to mask credentials for output (the secret is never returned).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.model_downloader.credentials.store import CredentialView
|
||||
|
||||
|
||||
def credential_to_dict(view: CredentialView) -> dict:
|
||||
"""API-safe credential representation — never includes the secret."""
|
||||
return {
|
||||
"id": view.id,
|
||||
"host": view.host,
|
||||
"auth_scheme": view.auth_scheme,
|
||||
"header_name": view.header_name,
|
||||
"query_param": view.query_param,
|
||||
"label": view.label,
|
||||
"match_subdomains": view.match_subdomains,
|
||||
"enabled": view.enabled,
|
||||
"secret_last4": view.secret_last4,
|
||||
"created_at": view.created_at,
|
||||
"updated_at": view.updated_at,
|
||||
}
|
||||
47
app/model_downloader/constants.py
Normal file
47
app/model_downloader/constants.py
Normal file
@ -0,0 +1,47 @@
|
||||
"""Shared constants for the download manager.
|
||||
|
||||
Status values are persisted as TEXT in the ``downloads`` table; keep them
|
||||
stable. The lifecycle is:
|
||||
|
||||
queued -> active -> verifying -> completed
|
||||
| |-> paused -> (resume) -> active
|
||||
| |-> failed (network, retryable) -> queued (backoff)
|
||||
|-> cancelled
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Auth schemes for HostCredential
|
||||
AUTH_SCHEME_BEARER = "bearer"
|
||||
AUTH_SCHEME_HEADER = "header"
|
||||
AUTH_SCHEME_QUERY = "query"
|
||||
AUTH_SCHEMES = (AUTH_SCHEME_BEARER, AUTH_SCHEME_HEADER, AUTH_SCHEME_QUERY)
|
||||
|
||||
# Hosts for which a bearer token can be sourced from the environment when no
|
||||
# stored credential matches. Values are the env var names to try, in order.
|
||||
# Only consulted during auto-resolve for an exact host match over https, so the
|
||||
# same per-hop boundary rules apply (e.g. the token is dropped on a redirect to
|
||||
# a CDN host). Kept here so the host->env-var mapping lives in one place.
|
||||
ENV_TOKEN_HOSTS = {
|
||||
"huggingface.co": ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"),
|
||||
}
|
||||
|
||||
|
||||
class DownloadStatus:
|
||||
QUEUED = "queued"
|
||||
ACTIVE = "active"
|
||||
PAUSED = "paused"
|
||||
VERIFYING = "verifying"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
#: States from which a worker is doing (or about to do) network I/O.
|
||||
LIVE = (QUEUED, ACTIVE, VERIFYING)
|
||||
#: Terminal states — the job will not transition again on its own.
|
||||
TERMINAL = (COMPLETED, FAILED, CANCELLED)
|
||||
|
||||
|
||||
# Default temp-file suffix. Distinctive so the startup orphan sweep only
|
||||
# removes files THIS subsystem created, never unrelated *.tmp files.
|
||||
TMP_SUFFIX = ".comfy-download.part"
|
||||
111
app/model_downloader/credentials/resolver.py
Normal file
111
app/model_downloader/credentials/resolver.py
Normal file
@ -0,0 +1,111 @@
|
||||
"""Turn a stored credential into a per-hop request modifier (PRD section 9.4.2).
|
||||
|
||||
The critical rule: a credential is only ever attached when *the current hop's
|
||||
host* matches a stored credential, and only over https. This is recomputed
|
||||
from scratch on every redirect hop, so a token bound to ``huggingface.co`` is
|
||||
silently dropped when the request is redirected to a presigned CDN host —
|
||||
which is exactly what these hubs expect.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode, urlsplit, urlunsplit
|
||||
|
||||
from app.model_downloader.constants import (
|
||||
AUTH_SCHEME_BEARER,
|
||||
AUTH_SCHEME_HEADER,
|
||||
AUTH_SCHEME_QUERY,
|
||||
ENV_TOKEN_HOSTS,
|
||||
)
|
||||
from app.model_downloader.credentials.store import normalize_host
|
||||
from app.model_downloader.database import queries
|
||||
from app.model_downloader.database.models import HostCredential
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestAuth:
|
||||
"""How to modify a single request to carry a credential."""
|
||||
|
||||
headers: dict[str, str] = field(default_factory=dict)
|
||||
query: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def apply_to_url(self, url: str) -> str:
|
||||
if not self.query:
|
||||
return url
|
||||
parts = urlsplit(url)
|
||||
# Append only the credential params, leaving the original query string
|
||||
# (including any repeated keys and existing encoding) untouched.
|
||||
creds = urlencode(self.query)
|
||||
query = f"{parts.query}&{creds}" if parts.query else creds
|
||||
return urlunsplit(parts._replace(query=query))
|
||||
|
||||
|
||||
def _matches(cred: HostCredential, hop_host: str) -> bool:
|
||||
cred_host = cred.host
|
||||
if hop_host == cred_host:
|
||||
return True
|
||||
if cred.match_subdomains:
|
||||
# Label-boundary suffix: api.example.com matches example.com, but
|
||||
# evil-example.com does NOT.
|
||||
return hop_host.endswith("." + cred_host)
|
||||
return False
|
||||
|
||||
|
||||
def _build_auth(cred: HostCredential) -> RequestAuth:
|
||||
if cred.auth_scheme == AUTH_SCHEME_BEARER:
|
||||
return RequestAuth(headers={"Authorization": f"Bearer {cred.secret}"})
|
||||
if cred.auth_scheme == AUTH_SCHEME_HEADER:
|
||||
name = cred.header_name or "Authorization"
|
||||
return RequestAuth(headers={name: cred.secret})
|
||||
if cred.auth_scheme == AUTH_SCHEME_QUERY and cred.query_param:
|
||||
return RequestAuth(query={cred.query_param: cred.secret})
|
||||
return RequestAuth()
|
||||
|
||||
|
||||
def _resolve_sync(
|
||||
host: str, scheme: str, explicit_credential_id: Optional[str]
|
||||
) -> Optional[RequestAuth]:
|
||||
# Never attach a secret over a non-https hop (PRD section 9.4.2).
|
||||
if scheme.lower() != "https":
|
||||
return None
|
||||
hop_host = normalize_host(host)
|
||||
if not hop_host:
|
||||
return None
|
||||
|
||||
if explicit_credential_id is not None:
|
||||
cred = queries.get_credential(explicit_credential_id)
|
||||
# An explicit credential is still subject to the per-hop host check —
|
||||
# it is not forced onto a non-matching host.
|
||||
if cred is None or not cred.enabled or not _matches(cred, hop_host):
|
||||
return None
|
||||
return _build_auth(cred)
|
||||
|
||||
# Auto-resolve: exact host first, then any subdomain-matching credential.
|
||||
cred = queries.get_credential_by_host(hop_host)
|
||||
if cred is not None and cred.enabled:
|
||||
return _build_auth(cred)
|
||||
for sub in queries.list_subdomain_credentials():
|
||||
if sub.enabled and _matches(sub, hop_host):
|
||||
return _build_auth(sub)
|
||||
|
||||
# Env fallback: only for an exact host match, and only after the DB lookups
|
||||
# miss, so a user-set credential always takes precedence. The token is never
|
||||
# persisted; it is read fresh from the environment on each hop.
|
||||
for var in ENV_TOKEN_HOSTS.get(hop_host, ()):
|
||||
token = os.environ.get(var)
|
||||
if token:
|
||||
return RequestAuth(headers={"Authorization": f"Bearer {token}"})
|
||||
return None
|
||||
|
||||
|
||||
async def resolve_auth_for_hop(
|
||||
host: str, scheme: str, *, explicit_credential_id: Optional[str] = None
|
||||
) -> Optional[RequestAuth]:
|
||||
"""Resolve the credential (if any) to attach for one request hop."""
|
||||
return await asyncio.to_thread(
|
||||
_resolve_sync, host, scheme, explicit_credential_id
|
||||
)
|
||||
141
app/model_downloader/credentials/store.py
Normal file
141
app/model_downloader/credentials/store.py
Normal file
@ -0,0 +1,141 @@
|
||||
"""The credential store: one API key per host.
|
||||
|
||||
Secrets are write-only over the API — :class:`CredentialView` carries only
|
||||
masked metadata (``secret_last4`` + scheme + label), never the secret itself.
|
||||
At-rest protection for v1 is filesystem permissions on the shared DB (the DB
|
||||
is the trust boundary); encryption-at-rest is a noted future seam.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
from app.model_downloader.constants import (
|
||||
AUTH_SCHEME_BEARER,
|
||||
AUTH_SCHEME_HEADER,
|
||||
AUTH_SCHEME_QUERY,
|
||||
AUTH_SCHEMES,
|
||||
)
|
||||
from app.model_downloader.database import queries
|
||||
from app.model_downloader.database.models import HostCredential
|
||||
|
||||
|
||||
def normalize_host(host: str) -> str:
|
||||
"""Lowercase, strip port, IDNA-encode."""
|
||||
if not host:
|
||||
return ""
|
||||
host = host.strip()
|
||||
if "://" in host: # a full URL was pasted — extract just the host
|
||||
host = urlsplit(host).hostname or ""
|
||||
host = host.lower()
|
||||
if host.startswith("[") and "]" in host: # bracketed IPv6 literal
|
||||
host = host[1 : host.index("]")]
|
||||
elif host.count(":") == 1: # host:port (not IPv6)
|
||||
host = host.split(":", 1)[0]
|
||||
try:
|
||||
host = host.encode("idna").decode("ascii")
|
||||
except (UnicodeError, ValueError):
|
||||
pass
|
||||
return host
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CredentialView:
|
||||
"""Masked, API-safe view of a credential — never includes the secret."""
|
||||
|
||||
id: str
|
||||
host: str
|
||||
auth_scheme: str
|
||||
header_name: Optional[str]
|
||||
query_param: Optional[str]
|
||||
label: Optional[str]
|
||||
match_subdomains: bool
|
||||
enabled: bool
|
||||
secret_last4: Optional[str]
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
|
||||
def _to_view(row: HostCredential) -> CredentialView:
|
||||
return CredentialView(
|
||||
id=row.id,
|
||||
host=row.host,
|
||||
auth_scheme=row.auth_scheme,
|
||||
header_name=row.header_name,
|
||||
query_param=row.query_param,
|
||||
label=row.label,
|
||||
match_subdomains=row.match_subdomains,
|
||||
enabled=row.enabled,
|
||||
secret_last4=row.secret_last4,
|
||||
created_at=row.created_at,
|
||||
updated_at=row.updated_at,
|
||||
)
|
||||
|
||||
|
||||
class CredentialValidationError(ValueError):
|
||||
"""A credential upsert had inconsistent fields."""
|
||||
|
||||
|
||||
class CredentialStore:
|
||||
"""Async facade over the ``host_credentials`` table.
|
||||
|
||||
DB access is synchronous (SQLite) and offloaded via ``asyncio.to_thread``.
|
||||
"""
|
||||
|
||||
async def upsert(
|
||||
self,
|
||||
host: str,
|
||||
secret: str,
|
||||
*,
|
||||
auth_scheme: str = AUTH_SCHEME_BEARER,
|
||||
header_name: Optional[str] = None,
|
||||
query_param: Optional[str] = None,
|
||||
label: Optional[str] = None,
|
||||
match_subdomains: bool = False,
|
||||
enabled: bool = True,
|
||||
) -> CredentialView:
|
||||
host = normalize_host(host)
|
||||
if not host:
|
||||
raise CredentialValidationError("host is required")
|
||||
if not secret:
|
||||
raise CredentialValidationError("secret is required")
|
||||
if auth_scheme not in AUTH_SCHEMES:
|
||||
raise CredentialValidationError(
|
||||
f"auth_scheme must be one of {AUTH_SCHEMES}, got {auth_scheme!r}"
|
||||
)
|
||||
if auth_scheme == AUTH_SCHEME_HEADER and not header_name:
|
||||
header_name = "Authorization"
|
||||
if auth_scheme == AUTH_SCHEME_QUERY and not query_param:
|
||||
raise CredentialValidationError(
|
||||
"query_param is required when auth_scheme='query'"
|
||||
)
|
||||
values = {
|
||||
"host": host,
|
||||
"secret": secret,
|
||||
"secret_last4": secret[-4:] if len(secret) > 4 else None,
|
||||
"auth_scheme": auth_scheme,
|
||||
"header_name": header_name,
|
||||
"query_param": query_param,
|
||||
"label": label,
|
||||
"match_subdomains": match_subdomains,
|
||||
"enabled": enabled,
|
||||
}
|
||||
row = await asyncio.to_thread(queries.upsert_credential, values)
|
||||
return _to_view(row)
|
||||
|
||||
async def list(self) -> list[CredentialView]:
|
||||
rows = await asyncio.to_thread(queries.list_credentials)
|
||||
return [_to_view(r) for r in rows]
|
||||
|
||||
async def get(self, credential_id: str) -> Optional[CredentialView]:
|
||||
row = await asyncio.to_thread(queries.get_credential, credential_id)
|
||||
return _to_view(row) if row is not None else None
|
||||
|
||||
async def delete(self, credential_id: str) -> bool:
|
||||
return await asyncio.to_thread(queries.delete_credential, credential_id)
|
||||
|
||||
|
||||
CREDENTIAL_STORE = CredentialStore()
|
||||
173
app/model_downloader/database/models.py
Normal file
173
app/model_downloader/database/models.py
Normal file
@ -0,0 +1,173 @@
|
||||
"""SQLAlchemy models for the download manager.
|
||||
|
||||
Three tables:
|
||||
|
||||
- ``downloads`` one row per requested file (job + queue state).
|
||||
- ``download_segments`` per-segment byte progress, for segmented resume.
|
||||
- ``host_credentials`` one API key per host, reused across downloads.
|
||||
|
||||
On completion a finished file is registered into the assets catalog;
|
||||
``downloads`` is kept only as job history.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Boolean,
|
||||
CheckConstraint,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database.models import Base
|
||||
|
||||
|
||||
def _uuid() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def _now() -> int:
|
||||
return int(time.time())
|
||||
|
||||
|
||||
class Download(Base):
|
||||
__tablename__ = "downloads"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
|
||||
# Original requested URL and the final URL after validated redirects.
|
||||
url: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
final_url: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
# Canonical "<directory>/<filename>" identifier (resolved via folder_paths).
|
||||
model_id: Mapped[str] = mapped_column(String(1024), nullable=False)
|
||||
# Final on-disk location and the .part write target.
|
||||
dest_path: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
temp_path: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
|
||||
status: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
priority: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
total_bytes: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||
bytes_done: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
|
||||
etag: Mapped[str | None] = mapped_column(String(512), nullable=True)
|
||||
last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
accept_ranges: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
# Optional hub-provided checksum to verify against (NOT the dedup key).
|
||||
expected_sha256: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
|
||||
# Explicit credential override; otherwise auto-resolved by host.
|
||||
# RESTRICT keeps a credential from being deleted while a download references it.
|
||||
credential_id: Mapped[str | None] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("host_credentials.id", ondelete="RESTRICT"),
|
||||
nullable=True,
|
||||
)
|
||||
allow_any_extension: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
# How many retryable failures we have seen (for backoff capping).
|
||||
attempts: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
error: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[int] = mapped_column(BigInteger, nullable=False, default=_now)
|
||||
updated_at: Mapped[int] = mapped_column(
|
||||
BigInteger, nullable=False, default=_now, onupdate=_now
|
||||
)
|
||||
|
||||
segments: Mapped[list[DownloadSegment]] = relationship(
|
||||
"DownloadSegment",
|
||||
back_populates="download",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
order_by="DownloadSegment.idx",
|
||||
)
|
||||
|
||||
credential: Mapped[HostCredential | None] = relationship(
|
||||
"HostCredential", back_populates="downloads"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_downloads_status", "status"),
|
||||
Index("ix_downloads_priority", "priority"),
|
||||
Index("ix_downloads_model_id", "model_id"),
|
||||
CheckConstraint("bytes_done >= 0", name="ck_downloads_bytes_done_nonneg"),
|
||||
CheckConstraint(
|
||||
"total_bytes IS NULL OR total_bytes >= 0",
|
||||
name="ck_downloads_total_bytes_nonneg",
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Download id={self.id} model_id={self.model_id!r} status={self.status}>"
|
||||
|
||||
|
||||
class DownloadSegment(Base):
|
||||
__tablename__ = "download_segments"
|
||||
|
||||
download_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("downloads.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
idx: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
start_offset: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
end_offset: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
bytes_done: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
|
||||
download: Mapped[Download] = relationship("Download", back_populates="segments")
|
||||
|
||||
__table_args__ = (
|
||||
CheckConstraint("bytes_done >= 0", name="ck_segments_bytes_done_nonneg"),
|
||||
CheckConstraint("end_offset >= start_offset", name="ck_segments_range"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<DownloadSegment {self.download_id}#{self.idx} "
|
||||
f"{self.start_offset}-{self.end_offset} done={self.bytes_done}>"
|
||||
)
|
||||
|
||||
|
||||
class HostCredential(Base):
|
||||
__tablename__ = "host_credentials"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
|
||||
# Normalized lowercase hostname, e.g. "civitai.com".
|
||||
host: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
match_subdomains: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
label: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
auth_scheme: Mapped[str] = mapped_column(
|
||||
String(16), nullable=False, default="bearer"
|
||||
)
|
||||
header_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
query_param: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
# The API key itself. Write-only over the API; never returned. See PRD 9.4.4.
|
||||
secret: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
secret_last4: Mapped[str | None] = mapped_column(String(4), nullable=True)
|
||||
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
created_at: Mapped[int] = mapped_column(BigInteger, nullable=False, default=_now)
|
||||
updated_at: Mapped[int] = mapped_column(
|
||||
BigInteger, nullable=False, default=_now, onupdate=_now
|
||||
)
|
||||
|
||||
downloads: Mapped[list[Download]] = relationship(
|
||||
"Download", back_populates="credential"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("uq_host_credentials_host", "host", unique=True),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<HostCredential id={self.id} host={self.host!r} scheme={self.auth_scheme}>"
|
||||
272
app/model_downloader/database/queries.py
Normal file
272
app/model_downloader/database/queries.py
Normal file
@ -0,0 +1,272 @@
|
||||
"""Synchronous DB access for the download manager.
|
||||
|
||||
All functions open their own short-lived session via ``create_session`` and
|
||||
commit before returning, mirroring ``app/assets`` usage. They are blocking
|
||||
(SQLite) and should be called from async code through ``asyncio.to_thread``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from app.database.db import create_session
|
||||
from app.model_downloader.constants import DownloadStatus
|
||||
from app.model_downloader.database.models import (
|
||||
Download,
|
||||
DownloadSegment,
|
||||
HostCredential,
|
||||
)
|
||||
|
||||
|
||||
# ----- downloads -----
|
||||
|
||||
|
||||
def insert_download(values: dict) -> None:
|
||||
with create_session() as session:
|
||||
session.add(Download(**values))
|
||||
session.commit()
|
||||
|
||||
|
||||
def get_download(download_id: str) -> Optional[Download]:
|
||||
with create_session() as session:
|
||||
row = session.get(Download, download_id)
|
||||
if row is not None:
|
||||
session.expunge_all()
|
||||
return row
|
||||
|
||||
|
||||
def list_downloads() -> list[Download]:
|
||||
with create_session() as session:
|
||||
rows = list(
|
||||
session.execute(
|
||||
select(Download).order_by(Download.created_at.desc())
|
||||
).scalars()
|
||||
)
|
||||
session.expunge_all()
|
||||
return rows
|
||||
|
||||
|
||||
def list_segments(download_id: str) -> list[DownloadSegment]:
|
||||
with create_session() as session:
|
||||
rows = list(
|
||||
session.execute(
|
||||
select(DownloadSegment)
|
||||
.where(DownloadSegment.download_id == download_id)
|
||||
.order_by(DownloadSegment.idx)
|
||||
).scalars()
|
||||
)
|
||||
session.expunge_all()
|
||||
return rows
|
||||
|
||||
|
||||
def update_download(download_id: str, **fields) -> None:
|
||||
if not fields:
|
||||
return
|
||||
fields.setdefault("updated_at", int(time.time()))
|
||||
with create_session() as session:
|
||||
row = session.get(Download, download_id)
|
||||
if row is None:
|
||||
return
|
||||
for key, value in fields.items():
|
||||
setattr(row, key, value)
|
||||
session.commit()
|
||||
|
||||
|
||||
def delete_download(download_id: str) -> None:
|
||||
with create_session() as session:
|
||||
row = session.get(Download, download_id)
|
||||
if row is not None:
|
||||
session.delete(row)
|
||||
session.commit()
|
||||
|
||||
|
||||
def delete_downloads(download_ids: list[str]) -> int:
|
||||
"""Delete many downloads in one transaction; returns the number removed.
|
||||
|
||||
Uses a bulk ``DELETE ... WHERE id IN (...)``. Segment rows are removed by
|
||||
the ``ON DELETE CASCADE`` foreign key (SQLite ``PRAGMA foreign_keys=ON`` is
|
||||
set in ``app/database/db.py``), so this stays consistent without loading the
|
||||
ORM relationship.
|
||||
"""
|
||||
if not download_ids:
|
||||
return 0
|
||||
with create_session() as session:
|
||||
result = session.execute(
|
||||
delete(Download).where(Download.id.in_(download_ids))
|
||||
)
|
||||
session.commit()
|
||||
return result.rowcount or 0
|
||||
|
||||
|
||||
def replace_segments(download_id: str, segments: list[dict]) -> None:
|
||||
"""Atomically replace the segment plan for a download."""
|
||||
with create_session() as session:
|
||||
session.query(DownloadSegment).filter(
|
||||
DownloadSegment.download_id == download_id
|
||||
).delete()
|
||||
for seg in segments:
|
||||
session.add(DownloadSegment(download_id=download_id, **seg))
|
||||
session.commit()
|
||||
|
||||
|
||||
def update_segment_progress(download_id: str, idx: int, bytes_done: int) -> None:
|
||||
with create_session() as session:
|
||||
row = session.get(DownloadSegment, {"download_id": download_id, "idx": idx})
|
||||
if row is None:
|
||||
return
|
||||
row.bytes_done = bytes_done
|
||||
session.commit()
|
||||
|
||||
|
||||
def list_queued_downloads() -> list[Download]:
|
||||
"""Queued rows ordered for admission (priority desc, then FIFO)."""
|
||||
with create_session() as session:
|
||||
rows = list(
|
||||
session.execute(
|
||||
select(Download)
|
||||
.where(Download.status == DownloadStatus.QUEUED)
|
||||
.order_by(Download.priority.desc(), Download.created_at.asc())
|
||||
).scalars()
|
||||
)
|
||||
session.expunge_all()
|
||||
return rows
|
||||
|
||||
|
||||
def reconcile_live_downloads() -> list[Download]:
|
||||
"""Reset any ``active``/``verifying`` rows left by a previous run.
|
||||
|
||||
On a clean restart there can be no live worker, so anything still marked
|
||||
live is stale. Move it back to ``queued`` (offsets are preserved on the
|
||||
segment rows) so the scheduler re-admits it. Returns the rows that should
|
||||
be re-queued by the scheduler (queued + paused).
|
||||
"""
|
||||
with create_session() as session:
|
||||
stale = list(
|
||||
session.execute(
|
||||
select(Download).where(
|
||||
Download.status.in_([DownloadStatus.ACTIVE, DownloadStatus.VERIFYING])
|
||||
)
|
||||
).scalars()
|
||||
)
|
||||
now = int(time.time())
|
||||
for row in stale:
|
||||
row.status = DownloadStatus.QUEUED
|
||||
row.updated_at = now
|
||||
session.commit()
|
||||
|
||||
resumable = list(
|
||||
session.execute(
|
||||
select(Download)
|
||||
.where(Download.status == DownloadStatus.QUEUED)
|
||||
.order_by(Download.priority.desc(), Download.created_at.asc())
|
||||
).scalars()
|
||||
)
|
||||
session.expunge_all()
|
||||
return resumable
|
||||
|
||||
|
||||
# ----- host credentials -----
|
||||
|
||||
|
||||
def get_credential(credential_id: str) -> Optional[HostCredential]:
|
||||
with create_session() as session:
|
||||
row = session.get(HostCredential, credential_id)
|
||||
if row is not None:
|
||||
session.expunge_all()
|
||||
return row
|
||||
|
||||
|
||||
def get_credential_by_host(host: str) -> Optional[HostCredential]:
|
||||
with create_session() as session:
|
||||
row = (
|
||||
session.execute(
|
||||
select(HostCredential).where(HostCredential.host == host).limit(1)
|
||||
)
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
if row is not None:
|
||||
session.expunge_all()
|
||||
return row
|
||||
|
||||
|
||||
def list_credentials() -> list[HostCredential]:
|
||||
with create_session() as session:
|
||||
rows = list(
|
||||
session.execute(
|
||||
select(HostCredential).order_by(HostCredential.host)
|
||||
).scalars()
|
||||
)
|
||||
session.expunge_all()
|
||||
return rows
|
||||
|
||||
|
||||
def list_subdomain_credentials() -> list[HostCredential]:
|
||||
"""Credentials that opted into subdomain matching, for suffix checks."""
|
||||
with create_session() as session:
|
||||
rows = list(
|
||||
session.execute(
|
||||
select(HostCredential).where(HostCredential.match_subdomains.is_(True))
|
||||
).scalars()
|
||||
)
|
||||
session.expunge_all()
|
||||
return rows
|
||||
|
||||
|
||||
def upsert_credential(values: dict) -> HostCredential:
|
||||
"""Insert or update a credential keyed by ``host``.
|
||||
|
||||
Callers can target the same host concurrently (each runs in its own
|
||||
short-lived session on a separate connection), so the read-then-write here
|
||||
can race: two callers both see no existing row and both attempt an insert.
|
||||
The ``host`` column is uniquely indexed, so the loser's insert raises
|
||||
``IntegrityError``. We recover by rolling back and retrying, at which point
|
||||
the now-committed row is found and updated in place, letting concurrent
|
||||
calls converge instead of failing or creating duplicates.
|
||||
"""
|
||||
host = values["host"]
|
||||
now = int(time.time())
|
||||
last_error: IntegrityError | None = None
|
||||
for _ in range(2):
|
||||
with create_session() as session:
|
||||
row = (
|
||||
session.execute(
|
||||
select(HostCredential).where(HostCredential.host == host).limit(1)
|
||||
)
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
if row is None:
|
||||
row = HostCredential(**values)
|
||||
row.created_at = now
|
||||
row.updated_at = now
|
||||
session.add(row)
|
||||
else:
|
||||
for key, value in values.items():
|
||||
setattr(row, key, value)
|
||||
row.updated_at = now
|
||||
try:
|
||||
session.commit()
|
||||
except IntegrityError as exc:
|
||||
session.rollback()
|
||||
last_error = exc
|
||||
continue
|
||||
session.refresh(row)
|
||||
session.expunge(row)
|
||||
return row
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
|
||||
|
||||
def delete_credential(credential_id: str) -> bool:
|
||||
with create_session() as session:
|
||||
row = session.get(HostCredential, credential_id)
|
||||
if row is None:
|
||||
return False
|
||||
session.delete(row)
|
||||
session.commit()
|
||||
return True
|
||||
615
app/model_downloader/engine/job.py
Normal file
615
app/model_downloader/engine/job.py
Normal file
@ -0,0 +1,615 @@
|
||||
"""The per-download worker.
|
||||
|
||||
One :class:`DownloadJob` drives a single file from probe to verified, cataloged
|
||||
completion. It supports cooperative pause / resume / cancel, segmented
|
||||
multi-connection transfer with positioned writes, and a verification gate
|
||||
(size + structural + optional sha256) before the atomic rename into place.
|
||||
|
||||
Control is cooperative: external callers flip ``_control`` via
|
||||
:meth:`request_pause` / :meth:`request_cancel`; segment loops observe it between
|
||||
chunks and raise, which unwinds cleanly and persists resume offsets.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Optional
|
||||
|
||||
from comfy.cli_args import args
|
||||
from app.model_downloader.constants import DownloadStatus
|
||||
from app.model_downloader.database import queries
|
||||
from app.model_downloader.engine.planner import (
|
||||
effective_segment_count,
|
||||
plan_segments,
|
||||
)
|
||||
from app.model_downloader.engine.writer import FileWriter
|
||||
from app.model_downloader.net.http import open_validated, redact_url
|
||||
from app.model_downloader.net.probe import probe
|
||||
from app.model_downloader.verify import checksum, dedup, structural
|
||||
|
||||
_RETRYABLE_STATUSES = {408, 429, 500, 502, 503, 504}
|
||||
_PERSIST_INTERVAL = 2.0 # seconds between throttled progress persists
|
||||
|
||||
|
||||
class Paused(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Cancelled(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class RemoteChanged(Exception):
|
||||
"""The remote file changed under a resume (got 200 where 206 expected)."""
|
||||
|
||||
|
||||
class RetryableError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FatalError(Exception):
|
||||
"""Non-retryable: 4xx, checksum mismatch, structural failure, gated, etc."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class SegmentRuntime:
|
||||
idx: int
|
||||
start: int
|
||||
end: int # inclusive; may be -1 for unknown-size single stream
|
||||
bytes_done: int = 0
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
return self.end - self.start + 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeState:
|
||||
download_id: str
|
||||
model_id: str
|
||||
url: str
|
||||
priority: int
|
||||
status: str
|
||||
total_bytes: Optional[int] = None
|
||||
bytes_done: int = 0
|
||||
error: Optional[str] = None
|
||||
segments: list[SegmentRuntime] = field(default_factory=list)
|
||||
started_at: float = field(default_factory=time.monotonic)
|
||||
_last_bytes: int = 0
|
||||
_last_time: float = field(default_factory=time.monotonic)
|
||||
speed_bps: float = 0.0
|
||||
|
||||
@property
|
||||
def progress(self) -> Optional[float]:
|
||||
if not self.total_bytes:
|
||||
return None
|
||||
return min(1.0, self.bytes_done / self.total_bytes)
|
||||
|
||||
@property
|
||||
def eta_seconds(self) -> Optional[float]:
|
||||
if not self.total_bytes or self.speed_bps <= 0:
|
||||
return None
|
||||
remaining = max(0, self.total_bytes - self.bytes_done)
|
||||
return remaining / self.speed_bps
|
||||
|
||||
|
||||
@dataclass
|
||||
class JobSpec:
|
||||
download_id: str
|
||||
url: str
|
||||
model_id: str
|
||||
dest_path: str
|
||||
temp_path: str
|
||||
priority: int = 0
|
||||
credential_id: Optional[str] = None
|
||||
expected_sha256: Optional[str] = None
|
||||
allow_any_extension: bool = False
|
||||
etag: Optional[str] = None
|
||||
attempts: int = 0
|
||||
|
||||
|
||||
class DownloadJob:
|
||||
def __init__(
|
||||
self, spec: JobSpec, notify_cb: Optional[Callable[[str], None]] = None
|
||||
) -> None:
|
||||
self.spec = spec
|
||||
self._notify = notify_cb
|
||||
self._control = "run" # run | pause | cancel
|
||||
self.state = RuntimeState(
|
||||
download_id=spec.download_id,
|
||||
model_id=spec.model_id,
|
||||
url=spec.url,
|
||||
priority=spec.priority,
|
||||
status=DownloadStatus.QUEUED,
|
||||
)
|
||||
self._writer: Optional[FileWriter] = None
|
||||
self._etag: Optional[str] = spec.etag
|
||||
self._last_persist = 0.0
|
||||
|
||||
# ----- external control -----
|
||||
|
||||
def request_pause(self) -> None:
|
||||
if self._control == "run":
|
||||
self._control = "pause"
|
||||
|
||||
def request_cancel(self) -> None:
|
||||
self._control = "cancel"
|
||||
|
||||
def _check_control(self) -> None:
|
||||
if self._control == "cancel":
|
||||
raise Cancelled()
|
||||
if self._control == "pause":
|
||||
raise Paused()
|
||||
|
||||
# ----- lifecycle -----
|
||||
|
||||
async def run(self) -> str:
|
||||
"""Run to a terminal/paused state; returns the final status string."""
|
||||
await self._set_status(DownloadStatus.ACTIVE, error=None)
|
||||
try:
|
||||
pr = await self._probe_and_plan()
|
||||
await self._transfer(pr)
|
||||
await self._finalize()
|
||||
await self._set_status(DownloadStatus.COMPLETED)
|
||||
except Paused:
|
||||
await self._persist_progress(force=True)
|
||||
await self._set_status(DownloadStatus.PAUSED)
|
||||
except Cancelled:
|
||||
await self._close_writer()
|
||||
self._remove_temp()
|
||||
await self._set_status(DownloadStatus.CANCELLED)
|
||||
except RemoteChanged:
|
||||
await self._reset_for_restart()
|
||||
await self._set_status(
|
||||
DownloadStatus.QUEUED, error="remote file changed; restarting"
|
||||
)
|
||||
except RetryableError as e:
|
||||
await self._persist_progress(force=True)
|
||||
await self._set_status(DownloadStatus.QUEUED, error=str(e))
|
||||
except FatalError as e:
|
||||
await self._close_writer()
|
||||
self._remove_temp()
|
||||
await self._set_status(DownloadStatus.FAILED, error=str(e))
|
||||
except Exception as e: # unexpected -> treat as retryable
|
||||
logging.warning(
|
||||
"[model_downloader] %s unexpected error: %s",
|
||||
self.spec.model_id, e, exc_info=True,
|
||||
)
|
||||
await self._persist_progress(force=True)
|
||||
await self._set_status(DownloadStatus.QUEUED, error=f"{type(e).__name__}: {e}")
|
||||
finally:
|
||||
await self._close_writer()
|
||||
return self.state.status
|
||||
|
||||
# ----- probe + plan -----
|
||||
|
||||
async def _probe_and_plan(self):
|
||||
pr = await probe(self.spec.url, credential_id=self.spec.credential_id)
|
||||
if not pr.ok:
|
||||
if pr.gated:
|
||||
raise FatalError(
|
||||
f"{redact_url(self.spec.url)} requires authentication. Add an API key for "
|
||||
f"this host at /api/download/credentials and retry."
|
||||
)
|
||||
if pr.status == 0 or pr.status in _RETRYABLE_STATUSES:
|
||||
raise RetryableError(pr.error or "probe failed")
|
||||
raise FatalError(pr.error or f"probe returned HTTP {pr.status}")
|
||||
|
||||
max_bytes = self._max_download_bytes()
|
||||
if max_bytes is not None and pr.total_bytes is not None and pr.total_bytes > max_bytes:
|
||||
raise FatalError(
|
||||
f"file size {pr.total_bytes} exceeds the maximum allowed "
|
||||
f"download size {max_bytes} (--download-max-bytes)"
|
||||
)
|
||||
|
||||
self._etag = pr.etag or self._etag
|
||||
self.state.total_bytes = pr.total_bytes
|
||||
await asyncio.to_thread(
|
||||
queries.update_download,
|
||||
self.spec.download_id,
|
||||
final_url=pr.final_url,
|
||||
total_bytes=pr.total_bytes,
|
||||
accept_ranges=pr.accept_ranges,
|
||||
etag=pr.etag,
|
||||
last_modified=pr.last_modified,
|
||||
)
|
||||
|
||||
seg_count = effective_segment_count(
|
||||
pr.total_bytes, pr.accept_ranges, max(1, args.download_segments)
|
||||
)
|
||||
existing = await asyncio.to_thread(queries.list_segments, self.spec.download_id)
|
||||
can_resume_segmented = (
|
||||
seg_count > 1
|
||||
and existing
|
||||
and pr.total_bytes is not None
|
||||
and existing[-1].end_offset == pr.total_bytes - 1
|
||||
)
|
||||
if can_resume_segmented and not self._segmented_part_valid(pr.total_bytes):
|
||||
# The persisted per-segment offsets describe bytes in a preallocated
|
||||
# .part that is now gone or the wrong size (e.g. the partial of a
|
||||
# failed download was swept on restart, or removed by a fatal
|
||||
# error). Trusting them would skip already-"complete" segments and
|
||||
# leave zero-filled holes. Discard the offsets and re-plan fresh.
|
||||
logging.info(
|
||||
"[model_downloader] %s discarding segmented resume offsets "
|
||||
"(preallocated .part missing or wrong size); restarting",
|
||||
self.spec.model_id,
|
||||
)
|
||||
self._remove_temp()
|
||||
await asyncio.to_thread(
|
||||
queries.replace_segments, self.spec.download_id, []
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
queries.update_download, self.spec.download_id, bytes_done=0
|
||||
)
|
||||
existing = []
|
||||
can_resume_segmented = False
|
||||
|
||||
if can_resume_segmented:
|
||||
# Resume an existing segmented plan.
|
||||
self.state.segments = [
|
||||
SegmentRuntime(s.idx, s.start_offset, s.end_offset, s.bytes_done)
|
||||
for s in existing
|
||||
]
|
||||
elif seg_count > 1 and pr.total_bytes is not None:
|
||||
plans = plan_segments(pr.total_bytes, seg_count)
|
||||
await asyncio.to_thread(
|
||||
queries.replace_segments,
|
||||
self.spec.download_id,
|
||||
[
|
||||
{"idx": p.idx, "start_offset": p.start, "end_offset": p.end, "bytes_done": 0}
|
||||
for p in plans
|
||||
],
|
||||
)
|
||||
self.state.segments = [SegmentRuntime(p.idx, p.start, p.end, 0) for p in plans]
|
||||
else:
|
||||
# Single-stream: one logical segment; bytes_done tracked on the row.
|
||||
row = await asyncio.to_thread(queries.get_download, self.spec.download_id)
|
||||
resume_from = row.bytes_done if row else 0
|
||||
end = (pr.total_bytes - 1) if pr.total_bytes else -1
|
||||
# ``row.bytes_done`` may be the SUM of per-segment offsets from a
|
||||
# prior segmented run (a preallocated, non-contiguous .part). A
|
||||
# single-stream resume writes a contiguous prefix, so the offset is
|
||||
# only trustworthy when the on-disk file is exactly that many
|
||||
# contiguous bytes. This guards the case where a download that ran
|
||||
# segmented now resolves to one segment (server dropped
|
||||
# Accept-Ranges, or --download-segments was lowered between runs):
|
||||
# resuming over non-contiguous data would corrupt the output.
|
||||
if resume_from > 0 and not self._contiguous_prefix_valid(resume_from):
|
||||
logging.info(
|
||||
"[model_downloader] %s discarding untrusted resume offset "
|
||||
"%d (on-disk .part not a contiguous prefix); restarting",
|
||||
self.spec.model_id, resume_from,
|
||||
)
|
||||
resume_from = 0
|
||||
self._remove_temp()
|
||||
if await asyncio.to_thread(queries.list_segments, self.spec.download_id):
|
||||
await asyncio.to_thread(
|
||||
queries.replace_segments, self.spec.download_id, []
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
queries.update_download, self.spec.download_id, bytes_done=0
|
||||
)
|
||||
self.state.segments = [SegmentRuntime(0, 0, end, resume_from)]
|
||||
self._recompute_bytes_done()
|
||||
return pr
|
||||
|
||||
# ----- transfer -----
|
||||
|
||||
async def _transfer(self, pr) -> None:
|
||||
self._writer = FileWriter(self.spec.temp_path)
|
||||
await self._writer.open()
|
||||
|
||||
segmented = len(self.state.segments) > 1
|
||||
if segmented and self.state.total_bytes:
|
||||
await self._writer.preallocate(self.state.total_bytes)
|
||||
await self._run_segmented()
|
||||
else:
|
||||
await self._run_single()
|
||||
|
||||
await self._writer.flush()
|
||||
|
||||
async def _run_segmented(self) -> None:
|
||||
pending = [
|
||||
asyncio.ensure_future(self._run_segment(seg))
|
||||
for seg in self.state.segments
|
||||
if seg.bytes_done < seg.length
|
||||
]
|
||||
if not pending:
|
||||
return
|
||||
done, not_done = await asyncio.wait(
|
||||
pending, return_when=asyncio.FIRST_EXCEPTION
|
||||
)
|
||||
first_exc: Optional[BaseException] = None
|
||||
for task in done:
|
||||
exc = task.exception()
|
||||
if exc is not None and first_exc is None:
|
||||
first_exc = exc
|
||||
if first_exc is not None:
|
||||
for task in not_done:
|
||||
task.cancel()
|
||||
await asyncio.gather(*not_done, return_exceptions=True)
|
||||
raise first_exc
|
||||
|
||||
async def _run_segment(self, seg: SegmentRuntime) -> None:
|
||||
offset = seg.start + seg.bytes_done
|
||||
headers = {
|
||||
"Range": f"bytes={offset}-{seg.end}",
|
||||
"Accept-Encoding": "identity",
|
||||
}
|
||||
if self._etag:
|
||||
headers["If-Range"] = self._etag
|
||||
async with open_validated(
|
||||
"GET", self.spec.url, credential_id=self.spec.credential_id, headers=headers
|
||||
) as (resp, _final):
|
||||
if resp.status == 200:
|
||||
# Server ignored the range -> remote changed / no resume support.
|
||||
raise RemoteChanged()
|
||||
if resp.status not in (206,):
|
||||
self._raise_for_status(resp.status)
|
||||
async for chunk in resp.content.iter_chunked(args.download_chunk_size):
|
||||
self._check_control()
|
||||
# Never write past this segment's planned range: a
|
||||
# non-conforming 206 that returns more than the requested
|
||||
# bytes would otherwise overrun adjacent segments and the
|
||||
# preallocated file. Cap the write and abort on overflow.
|
||||
remaining = seg.length - seg.bytes_done
|
||||
if remaining <= 0:
|
||||
raise FatalError(
|
||||
f"segment {seg.idx}: server returned more than the "
|
||||
f"requested {seg.length} bytes"
|
||||
)
|
||||
overflow = len(chunk) > remaining
|
||||
if overflow:
|
||||
chunk = chunk[:remaining]
|
||||
await self._writer.write_at(offset, chunk)
|
||||
offset += len(chunk)
|
||||
seg.bytes_done += len(chunk)
|
||||
self._recompute_bytes_done()
|
||||
await self._persist_progress()
|
||||
if overflow:
|
||||
raise FatalError(
|
||||
f"segment {seg.idx}: server returned more than the "
|
||||
f"requested {seg.length} bytes"
|
||||
)
|
||||
|
||||
async def _run_single(self) -> None:
|
||||
seg = self.state.segments[0]
|
||||
offset = seg.bytes_done # resume from here for single-stream
|
||||
headers = {"Accept-Encoding": "identity"}
|
||||
if offset > 0:
|
||||
headers["Range"] = f"bytes={offset}-"
|
||||
if self._etag:
|
||||
headers["If-Range"] = self._etag
|
||||
async with open_validated(
|
||||
"GET", self.spec.url, credential_id=self.spec.credential_id, headers=headers
|
||||
) as (resp, _final):
|
||||
if offset > 0 and resp.status == 200:
|
||||
# Resume not honoured -> start over from the beginning. Truncate
|
||||
# the existing partial so stale trailing bytes from the prior
|
||||
# attempt cannot survive past the new (possibly shorter) end.
|
||||
offset = 0
|
||||
seg.bytes_done = 0
|
||||
self.state.bytes_done = 0
|
||||
await self._writer.truncate(0)
|
||||
elif offset > 0 and resp.status != 206:
|
||||
self._raise_for_status(resp.status)
|
||||
elif offset == 0 and resp.status != 200:
|
||||
self._raise_for_status(resp.status)
|
||||
# Byte ceiling for this stream: the known total when the server
|
||||
# reported a size, otherwise the configured maximum download size.
|
||||
# Without a bound, a non-conforming response or an unknown-length
|
||||
# stream (end == -1) that never closes could fill the disk (DoS).
|
||||
limit = (seg.end + 1) if seg.end >= 0 else self._max_download_bytes()
|
||||
async for chunk in resp.content.iter_chunked(args.download_chunk_size):
|
||||
self._check_control()
|
||||
overflow = False
|
||||
if limit is not None:
|
||||
remaining = limit - offset
|
||||
if remaining <= 0:
|
||||
raise FatalError(
|
||||
f"download exceeded the maximum size {limit} bytes"
|
||||
)
|
||||
if len(chunk) > remaining:
|
||||
chunk = chunk[:remaining]
|
||||
overflow = True
|
||||
await self._writer.write_at(offset, chunk)
|
||||
offset += len(chunk)
|
||||
seg.bytes_done = offset
|
||||
self.state.bytes_done = offset
|
||||
await self._persist_progress()
|
||||
if overflow:
|
||||
raise FatalError(
|
||||
f"download exceeded the maximum size {limit} bytes"
|
||||
)
|
||||
|
||||
def _max_download_bytes(self) -> Optional[int]:
|
||||
"""Configured maximum download size in bytes, or ``None`` if disabled."""
|
||||
cap = getattr(args, "download_max_bytes", 0)
|
||||
return cap if cap and cap > 0 else None
|
||||
|
||||
def _raise_for_status(self, status: int) -> None:
|
||||
if status in (401, 403):
|
||||
raise FatalError(
|
||||
f"{redact_url(self.spec.url)} returned {status}; add/update an API key for "
|
||||
f"this host at /api/download/credentials."
|
||||
)
|
||||
if status in _RETRYABLE_STATUSES:
|
||||
raise RetryableError(f"HTTP {status}")
|
||||
raise FatalError(f"unexpected HTTP {status}")
|
||||
|
||||
# ----- finalize / verify (PRD section 8.4) -----
|
||||
|
||||
async def _finalize(self) -> None:
|
||||
self._check_control()
|
||||
await self._close_writer()
|
||||
await self._set_status(DownloadStatus.VERIFYING)
|
||||
|
||||
total = self.state.total_bytes
|
||||
segmented = len(self.state.segments) > 1
|
||||
if segmented:
|
||||
# The .part was preallocated to total_bytes, so its on-disk size is
|
||||
# not evidence of completeness: a segment that ends short (truncated
|
||||
# 206 / server closes mid-range) leaves a zero-filled hole while the
|
||||
# file size still equals total. Verify each segment wrote its full
|
||||
# planned range, and trust the byte counter (== sum of segments)
|
||||
# rather than os.path.getsize for the total check.
|
||||
for seg in self.state.segments:
|
||||
if seg.bytes_done != seg.length:
|
||||
raise FatalError(
|
||||
f"segment {seg.idx} incomplete: wrote {seg.bytes_done} "
|
||||
f"of {seg.length} bytes"
|
||||
)
|
||||
observed = self.state.bytes_done
|
||||
else:
|
||||
# Single-stream writes a contiguous prefix, so the on-disk size is
|
||||
# an independent witness of how much actually landed.
|
||||
observed = os.path.getsize(self.spec.temp_path)
|
||||
if total is not None and observed != total:
|
||||
raise FatalError(
|
||||
f"size mismatch: wrote {observed} of {total} bytes"
|
||||
)
|
||||
|
||||
# Structural gate (cheap, no full read) then optional sha256 (full read).
|
||||
# Both failures are non-retryable (a truncated/corrupt or mismatched file
|
||||
# will not heal on retry), so surface them as FatalError rather than
|
||||
# letting the plain Exceptions fall through to the retryable handler.
|
||||
# ``temp_path`` carries the ``.part`` suffix; pass ``dest_path`` so the
|
||||
# structural check detects the real file format instead of skipping it.
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
structural.validate, self.spec.temp_path, self.spec.dest_path
|
||||
)
|
||||
if self.spec.expected_sha256:
|
||||
await asyncio.to_thread(
|
||||
checksum.verify_sha256,
|
||||
self.spec.temp_path,
|
||||
self.spec.expected_sha256,
|
||||
)
|
||||
except (structural.StructuralError, checksum.ChecksumError) as e:
|
||||
raise FatalError(str(e)) from e
|
||||
|
||||
os.makedirs(os.path.dirname(self.spec.dest_path), exist_ok=True)
|
||||
os.replace(self.spec.temp_path, self.spec.dest_path)
|
||||
logging.info(
|
||||
"[model_downloader] completed %s (%d bytes)",
|
||||
self.spec.model_id, observed,
|
||||
)
|
||||
# Catalog into the assets system (blake3 dedup identity). Best-effort.
|
||||
await dedup.register_completed(self.spec.dest_path)
|
||||
|
||||
# ----- helpers -----
|
||||
|
||||
def _recompute_bytes_done(self) -> None:
|
||||
self.state.bytes_done = sum(s.bytes_done for s in self.state.segments)
|
||||
now = time.monotonic()
|
||||
dt = now - self.state._last_time
|
||||
if dt >= 0.5:
|
||||
self.state.speed_bps = (self.state.bytes_done - self.state._last_bytes) / dt
|
||||
self.state._last_bytes = self.state.bytes_done
|
||||
self.state._last_time = now
|
||||
|
||||
async def _persist_progress(self, force: bool = False) -> None:
|
||||
# Both the DB write and the websocket notify are gated by the same
|
||||
# throttle: persisting hits SQLite, and notifying broadcasts to every
|
||||
# client, so doing either per-chunk (small --download-chunk-size or
|
||||
# many concurrent segments) would overwhelm both. Skip entirely inside
|
||||
# the window; the next persist (or a forced one) ships the latest bytes.
|
||||
now = time.monotonic()
|
||||
if not force and now - self._last_persist < _PERSIST_INTERVAL:
|
||||
return
|
||||
self._last_persist = now
|
||||
# SQLite is blocking; run it off the event loop per the queries module
|
||||
# contract so progress persists don't stall the web server.
|
||||
await asyncio.to_thread(self._write_progress)
|
||||
if self._notify:
|
||||
self._notify(self.spec.download_id)
|
||||
|
||||
def _write_progress(self) -> None:
|
||||
queries.update_download(self.spec.download_id, bytes_done=self.state.bytes_done)
|
||||
for seg in self.state.segments:
|
||||
if seg.end >= seg.start: # skip unknown-size sentinel
|
||||
queries.update_segment_progress(
|
||||
self.spec.download_id, seg.idx, seg.bytes_done
|
||||
)
|
||||
|
||||
async def _reset_for_restart(self) -> None:
|
||||
await self._close_writer()
|
||||
self._remove_temp()
|
||||
for seg in self.state.segments:
|
||||
seg.bytes_done = 0
|
||||
self.state.bytes_done = 0
|
||||
await asyncio.to_thread(
|
||||
queries.update_download, self.spec.download_id, bytes_done=0
|
||||
)
|
||||
if await asyncio.to_thread(queries.list_segments, self.spec.download_id):
|
||||
await asyncio.to_thread(
|
||||
queries.replace_segments, self.spec.download_id, []
|
||||
)
|
||||
|
||||
async def _close_writer(self) -> None:
|
||||
if self._writer is not None:
|
||||
try:
|
||||
await self._writer.close()
|
||||
except Exception:
|
||||
logging.debug("[model_downloader] writer close error", exc_info=True)
|
||||
self._writer = None
|
||||
|
||||
def _segmented_part_valid(self, total_bytes: int) -> bool:
|
||||
"""True when the temp file is the preallocated segmented ``.part``.
|
||||
|
||||
A segmented transfer preallocates the .part to ``total_bytes`` up front
|
||||
and tracks how much of each range landed via per-segment offsets. Those
|
||||
offsets are only trustworthy when the file they describe is still on
|
||||
disk at its full preallocated size. A missing file (swept after a
|
||||
failure, removed on a fatal error, deleted by hand) or a wrong-sized one
|
||||
means the persisted offsets no longer correspond to real bytes and must
|
||||
not be resumed over. Doing so would skip "complete" segments and leave
|
||||
zero-filled holes that pass the size-only verification gate.
|
||||
"""
|
||||
try:
|
||||
return os.path.getsize(self.spec.temp_path) == total_bytes
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
def _contiguous_prefix_valid(self, prefix_len: int) -> bool:
|
||||
"""True when the temp file is exactly ``prefix_len`` contiguous bytes.
|
||||
|
||||
Single-stream resume appends sequentially, so a valid resume point
|
||||
implies the .part size equals the persisted offset. A larger file (e.g.
|
||||
one preallocated to ``total_bytes`` by a previous segmented run) or a
|
||||
missing/short file means the persisted offset is not a trustworthy
|
||||
contiguous prefix and must not be resumed over.
|
||||
"""
|
||||
try:
|
||||
return os.path.getsize(self.spec.temp_path) == prefix_len
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
def _remove_temp(self) -> None:
|
||||
try:
|
||||
os.remove(self.spec.temp_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except OSError as e:
|
||||
logging.warning(
|
||||
"[model_downloader] could not remove %s: %s", self.spec.temp_path, e
|
||||
)
|
||||
|
||||
async def _set_status(self, status: str, error: Optional[str] = None) -> None:
|
||||
# ``error`` is authoritative: passing None clears any prior failure
|
||||
# text so transitions out of a failure state (retry/success) don't
|
||||
# leave stale messages on RuntimeState or in the persisted row.
|
||||
self.state.status = status
|
||||
self.state.error = error
|
||||
fields = {"status": status, "bytes_done": self.state.bytes_done, "error": error}
|
||||
if status == DownloadStatus.QUEUED:
|
||||
fields["attempts"] = self.spec.attempts + 1
|
||||
self.spec.attempts += 1
|
||||
await asyncio.to_thread(queries.update_download, self.spec.download_id, **fields)
|
||||
if self._notify:
|
||||
self._notify(self.spec.download_id)
|
||||
51
app/model_downloader/engine/planner.py
Normal file
51
app/model_downloader/engine/planner.py
Normal file
@ -0,0 +1,51 @@
|
||||
"""Segment planning.
|
||||
|
||||
Split a known byte range into S roughly-equal segments, each fetched by its
|
||||
own coroutine with ``Range: bytes=start-end``. Falls back to a single segment
|
||||
when the server doesn't support ranges or the size is unknown/too small for
|
||||
segmentation to be worthwhile.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
# Below this size, the per-connection setup cost outweighs any parallelism.
|
||||
_MIN_SEGMENT_BYTES = 1 * 1024 * 1024
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SegmentPlan:
|
||||
idx: int
|
||||
start: int
|
||||
end: int # inclusive
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
return self.end - self.start + 1
|
||||
|
||||
|
||||
def effective_segment_count(
|
||||
total_bytes: int | None, accept_ranges: bool, configured: int
|
||||
) -> int:
|
||||
"""How many segments to actually use for this file."""
|
||||
if not accept_ranges or total_bytes is None or total_bytes <= 0:
|
||||
return 1
|
||||
by_size = max(1, total_bytes // _MIN_SEGMENT_BYTES)
|
||||
return max(1, min(configured, by_size))
|
||||
|
||||
|
||||
def plan_segments(total_bytes: int, num_segments: int) -> list[SegmentPlan]:
|
||||
"""Return ``num_segments`` contiguous, inclusive byte ranges covering [0, total)."""
|
||||
if total_bytes <= 0 or num_segments <= 1:
|
||||
return [SegmentPlan(idx=0, start=0, end=max(0, total_bytes - 1))]
|
||||
base = total_bytes // num_segments
|
||||
plans: list[SegmentPlan] = []
|
||||
start = 0
|
||||
for i in range(num_segments):
|
||||
# Last segment soaks up the remainder.
|
||||
length = base if i < num_segments - 1 else total_bytes - start
|
||||
end = start + length - 1
|
||||
plans.append(SegmentPlan(idx=i, start=start, end=end))
|
||||
start = end + 1
|
||||
return plans
|
||||
110
app/model_downloader/engine/writer.py
Normal file
110
app/model_downloader/engine/writer.py
Normal file
@ -0,0 +1,110 @@
|
||||
"""Positioned, off-loop file writes.
|
||||
|
||||
Network I/O stays on the event loop; every blocking disk op (preallocate,
|
||||
positioned write, fsync) is run in a bounded thread pool via
|
||||
``run_in_executor`` so downloads never stall inference or the web server.
|
||||
|
||||
A single file descriptor is opened for the whole download. Segments write to
|
||||
their own offsets with ``os.pwrite`` — which is offset-addressed and atomic
|
||||
per call, so concurrent segment writers need no extra locking. Per-chunk
|
||||
fsync is avoided; we fsync once at completion.
|
||||
|
||||
``os.pwrite`` is unavailable on Windows, so there we fall back to
|
||||
``os.lseek`` + ``os.write`` guarded by a per-writer lock (the seek/write pair
|
||||
is not atomic, so concurrent segment writers must be serialized).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
|
||||
# One shared, bounded pool for all download disk I/O.
|
||||
_EXECUTOR = ThreadPoolExecutor(max_workers=8, thread_name_prefix="dl-writer")
|
||||
|
||||
_HAS_PWRITE = hasattr(os, "pwrite")
|
||||
|
||||
# On Windows ``os.open`` defaults to text mode, which translates every ``\n``
|
||||
# byte into ``\r\n`` on write and corrupts binary payloads (the file grows by
|
||||
# one byte per 0x0A). ``O_BINARY`` disables that translation; it does not exist
|
||||
# on POSIX, where the default is already binary.
|
||||
_O_BINARY = getattr(os, "O_BINARY", 0)
|
||||
|
||||
|
||||
class FileWriter:
|
||||
"""Owns the ``.part`` file descriptor for one download."""
|
||||
|
||||
def __init__(self, path: str) -> None:
|
||||
self.path = path
|
||||
self._fd: Optional[int] = None
|
||||
# Serializes lseek+write on platforms without os.pwrite (Windows).
|
||||
self._seek_lock = threading.Lock()
|
||||
|
||||
def _open(self) -> None:
|
||||
os.makedirs(os.path.dirname(self.path), exist_ok=True)
|
||||
self._fd = os.open(self.path, os.O_RDWR | os.O_CREAT | _O_BINARY, 0o644)
|
||||
|
||||
async def open(self) -> None:
|
||||
await asyncio.get_running_loop().run_in_executor(_EXECUTOR, self._open)
|
||||
|
||||
async def preallocate(self, size: int) -> None:
|
||||
"""Grow the file to ``size`` so segments write to their offsets."""
|
||||
if self._fd is None or size <= 0:
|
||||
return
|
||||
await asyncio.get_running_loop().run_in_executor(
|
||||
_EXECUTOR, os.ftruncate, self._fd, size
|
||||
)
|
||||
|
||||
async def truncate(self, size: int = 0) -> None:
|
||||
"""Truncate the file to ``size`` bytes (default: empty it)."""
|
||||
if self._fd is None:
|
||||
return
|
||||
await asyncio.get_running_loop().run_in_executor(
|
||||
_EXECUTOR, os.ftruncate, self._fd, size
|
||||
)
|
||||
|
||||
def _pwrite_all(self, data: bytes, offset: int) -> None:
|
||||
"""A positioned write may write fewer bytes than requested (signal
|
||||
interruption, near-ENOSPC); loop until every byte lands so we never
|
||||
leave a gap while the caller advances by the full chunk length.
|
||||
|
||||
Uses ``os.pwrite`` where available (offset-addressed, atomic per call).
|
||||
On Windows it falls back to ``os.lseek`` + ``os.write`` under a lock,
|
||||
since that pair is not atomic across concurrent segment writers."""
|
||||
assert self._fd is not None, "writer not opened"
|
||||
view = memoryview(data)
|
||||
written = 0
|
||||
total = len(view)
|
||||
while written < total:
|
||||
if _HAS_PWRITE:
|
||||
n = os.pwrite(self._fd, view[written:], offset + written)
|
||||
else:
|
||||
with self._seek_lock:
|
||||
os.lseek(self._fd, offset + written, os.SEEK_SET)
|
||||
n = os.write(self._fd, view[written:])
|
||||
if n == 0:
|
||||
raise OSError(
|
||||
f"positioned write wrote 0 bytes at offset {offset + written} "
|
||||
f"({written}/{total} bytes written)"
|
||||
)
|
||||
written += n
|
||||
|
||||
async def write_at(self, offset: int, data: bytes) -> None:
|
||||
assert self._fd is not None, "writer not opened"
|
||||
await asyncio.get_running_loop().run_in_executor(
|
||||
_EXECUTOR, self._pwrite_all, data, offset
|
||||
)
|
||||
|
||||
async def flush(self) -> None:
|
||||
if self._fd is None:
|
||||
return
|
||||
await asyncio.get_running_loop().run_in_executor(_EXECUTOR, os.fsync, self._fd)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._fd is None:
|
||||
return
|
||||
fd, self._fd = self._fd, None
|
||||
await asyncio.get_running_loop().run_in_executor(_EXECUTOR, os.close, fd)
|
||||
455
app/model_downloader/manager.py
Normal file
455
app/model_downloader/manager.py
Normal file
@ -0,0 +1,455 @@
|
||||
"""Public facade for the download manager.
|
||||
|
||||
This is the only object the server imports. It validates requests, owns the
|
||||
:class:`Scheduler`, and exposes a small async API plus read models for status.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Callable, Optional
|
||||
|
||||
from app.model_downloader.constants import DownloadStatus
|
||||
from app.model_downloader.database import queries
|
||||
from app.model_downloader.net.probe import probe
|
||||
from app.model_downloader.scheduler import SCHEDULER
|
||||
from app.model_downloader.security import paths
|
||||
from app.model_downloader.net.http import redact_url
|
||||
from app.model_downloader.security.allowlist import (
|
||||
ALLOWED_MODEL_EXTENSIONS,
|
||||
filename_extension,
|
||||
is_host_allowed_url,
|
||||
is_url_downloadable,
|
||||
url_path_extension,
|
||||
)
|
||||
from app.model_downloader.security.paths import InvalidModelId
|
||||
|
||||
# Non-terminal statuses: an existing row in one of these blocks a re-enqueue.
|
||||
_LIVE_STATUSES = (
|
||||
DownloadStatus.QUEUED,
|
||||
DownloadStatus.ACTIVE,
|
||||
DownloadStatus.PAUSED,
|
||||
DownloadStatus.VERIFYING,
|
||||
)
|
||||
|
||||
|
||||
class DownloadError(Exception):
|
||||
"""A user-facing error with a stable machine-readable code."""
|
||||
|
||||
def __init__(self, code: str, message: str, status: int = 400) -> None:
|
||||
super().__init__(message)
|
||||
self.code = code
|
||||
self.message = message
|
||||
self.http_status = status
|
||||
|
||||
|
||||
class DownloadManager:
|
||||
def __init__(self) -> None:
|
||||
self._scheduler = SCHEDULER
|
||||
self._notify_cb: Optional[Callable[[str], None]] = None
|
||||
# Serializes the "check for a live download, then write" critical section
|
||||
# per model_id. ``downloads`` has no uniqueness constraint on model_id
|
||||
# (history rows are kept), so without this two concurrent enqueue/resume
|
||||
# calls could both pass the live check and admit two jobs sharing one
|
||||
# temp/dest path. The manager is a process singleton over a local SQLite
|
||||
# DB, so an in-process lock is sufficient (and avoids a migration).
|
||||
self._model_locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
def set_notify(self, cb: Optional[Callable[[str], None]]) -> None:
|
||||
self._notify_cb = cb
|
||||
self._scheduler.set_notify(cb)
|
||||
|
||||
async def start(self) -> None:
|
||||
await self._scheduler.start()
|
||||
|
||||
# ----- enqueue -----
|
||||
|
||||
async def enqueue(
|
||||
self,
|
||||
url: str,
|
||||
model_id: str,
|
||||
*,
|
||||
priority: int = 0,
|
||||
expected_sha256: Optional[str] = None,
|
||||
allow_any_extension: bool = False,
|
||||
credential_id: Optional[str] = None,
|
||||
) -> str:
|
||||
# Coarse gate first: host/scheme must be allowlisted, and any extension
|
||||
# present in the URL path must be a known model type. A URL whose path
|
||||
# carries NO extension (e.g. Civitai's ``/api/download/models/<id>``) is
|
||||
# admitted here and its real extension is resolved from the network
|
||||
# below before the download is finally accepted.
|
||||
if allow_any_extension:
|
||||
if not is_host_allowed_url(url):
|
||||
raise DownloadError(
|
||||
"URL_NOT_ALLOWED",
|
||||
"URL is not on the download allowlist (host/scheme).",
|
||||
)
|
||||
elif not is_url_downloadable(url):
|
||||
raise DownloadError(
|
||||
"URL_NOT_ALLOWED",
|
||||
"URL is not on the download allowlist (host/scheme/extension).",
|
||||
)
|
||||
|
||||
# When the URL path has no extension, follow it to where it resolves and
|
||||
# adopt the real extension from the response, forcing the stored
|
||||
# filename to match. Skipped when the caller opted into any extension.
|
||||
if not allow_any_extension and url_path_extension(url) == "":
|
||||
resolved_ext = await self._resolve_extension(url, credential_id)
|
||||
model_id = paths.apply_extension(model_id, resolved_ext)
|
||||
|
||||
try:
|
||||
paths.parse_model_id(model_id, allow_any_extension)
|
||||
dest_path, temp_path = paths.resolve_destination(model_id, allow_any_extension)
|
||||
except InvalidModelId as e:
|
||||
raise DownloadError("INVALID_MODEL_ID", str(e))
|
||||
|
||||
if await asyncio.to_thread(
|
||||
paths.resolve_existing, model_id, allow_any_extension
|
||||
):
|
||||
raise DownloadError(
|
||||
"ALREADY_AVAILABLE",
|
||||
f"Model already exists on disk: {model_id}",
|
||||
status=409,
|
||||
)
|
||||
download_id = str(uuid.uuid4())
|
||||
# Hold the per-model lock across the live check and the insert so a
|
||||
# concurrent enqueue/resume for the same model_id cannot interleave
|
||||
# between them and create a second job against the same temp/dest path.
|
||||
async with self._model_lock(model_id):
|
||||
if await self._has_live_download(model_id):
|
||||
raise DownloadError(
|
||||
"ALREADY_DOWNLOADING",
|
||||
f"A download for {model_id} is already in progress.",
|
||||
status=409,
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
queries.insert_download,
|
||||
{
|
||||
"id": download_id,
|
||||
"url": url,
|
||||
"model_id": model_id,
|
||||
"dest_path": dest_path,
|
||||
"temp_path": temp_path,
|
||||
"status": DownloadStatus.QUEUED,
|
||||
"priority": priority,
|
||||
"expected_sha256": expected_sha256,
|
||||
"credential_id": credential_id,
|
||||
"allow_any_extension": allow_any_extension,
|
||||
},
|
||||
)
|
||||
logging.info("[model_downloader] enqueued %s -> %s", redact_url(url), model_id)
|
||||
await self._scheduler.pump()
|
||||
return download_id
|
||||
|
||||
async def _resolve_extension(
|
||||
self, url: str, credential_id: Optional[str]
|
||||
) -> str:
|
||||
"""Follow ``url`` to its final response and return the real extension.
|
||||
|
||||
Used for allowlisted URLs whose path has no extension (e.g. Civitai
|
||||
download endpoints): the filename lives in the ``Content-Disposition``
|
||||
header or the post-redirect URL. Raises :class:`DownloadError` when the
|
||||
URL can't be resolved, needs credentials, or resolves to something that
|
||||
is not a known model file — so we never persist a bogus destination.
|
||||
"""
|
||||
pr = await probe(url, credential_id=credential_id)
|
||||
if not pr.ok:
|
||||
if pr.gated:
|
||||
raise DownloadError(
|
||||
"CREDENTIALS_REQUIRED",
|
||||
f"{redact_url(url)} requires authentication to resolve. Add an "
|
||||
f"API key for this host at /api/download/credentials and retry.",
|
||||
status=401,
|
||||
)
|
||||
raise DownloadError(
|
||||
"URL_RESOLVE_FAILED",
|
||||
f"Could not resolve {redact_url(url)}: {pr.error or 'unknown error'}",
|
||||
status=502,
|
||||
)
|
||||
ext = filename_extension(pr.filename) if pr.filename else ""
|
||||
if ext not in ALLOWED_MODEL_EXTENSIONS:
|
||||
raise DownloadError(
|
||||
"URL_NOT_ALLOWED",
|
||||
f"URL resolves to {pr.filename or '<unknown>'!r}, which is not a "
|
||||
f"known model file type {ALLOWED_MODEL_EXTENSIONS}.",
|
||||
)
|
||||
return ext
|
||||
|
||||
def _model_lock(self, model_id: str) -> asyncio.Lock:
|
||||
# Lazily create one lock per model_id. There is no ``await`` between the
|
||||
# lookup and the insert, so under the single asyncio thread this is
|
||||
# atomic and cannot hand out two different locks for the same model_id.
|
||||
lock = self._model_locks.get(model_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
self._model_locks[model_id] = lock
|
||||
return lock
|
||||
|
||||
async def _has_live_download(
|
||||
self, model_id: str, *, exclude_id: Optional[str] = None
|
||||
) -> bool:
|
||||
rows = await asyncio.to_thread(queries.list_downloads)
|
||||
return any(
|
||||
r.model_id == model_id
|
||||
and r.id != exclude_id
|
||||
and r.status in _LIVE_STATUSES
|
||||
for r in rows
|
||||
)
|
||||
|
||||
# ----- control -----
|
||||
|
||||
async def pause(self, download_id: str) -> None:
|
||||
job = self._scheduler.get_job(download_id)
|
||||
if job is not None:
|
||||
job.request_pause()
|
||||
return
|
||||
row = await asyncio.to_thread(queries.get_download, download_id)
|
||||
if row is None:
|
||||
raise DownloadError("NOT_FOUND", "No such download.", status=404)
|
||||
if row.status == DownloadStatus.QUEUED:
|
||||
await asyncio.to_thread(
|
||||
queries.update_download, download_id, status=DownloadStatus.PAUSED
|
||||
)
|
||||
|
||||
async def resume(self, download_id: str) -> None:
|
||||
row = await asyncio.to_thread(queries.get_download, download_id)
|
||||
if row is None:
|
||||
raise DownloadError("NOT_FOUND", "No such download.", status=404)
|
||||
if row.status not in (DownloadStatus.PAUSED, DownloadStatus.FAILED):
|
||||
return
|
||||
# Re-queueing a paused/failed row must respect the single-live-per-model
|
||||
# invariant: another download (e.g. a newer enqueue) may already be live
|
||||
# for this model_id and would share this row's temp/dest path. Hold the
|
||||
# per-model lock across the check and the status flip, and exclude this
|
||||
# row itself (a paused row is already a "live" status).
|
||||
async with self._model_lock(row.model_id):
|
||||
if await self._has_live_download(row.model_id, exclude_id=download_id):
|
||||
raise DownloadError(
|
||||
"ALREADY_DOWNLOADING",
|
||||
f"A download for {row.model_id} is already in progress.",
|
||||
status=409,
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
queries.update_download,
|
||||
download_id,
|
||||
status=DownloadStatus.QUEUED,
|
||||
error=None,
|
||||
)
|
||||
await self._scheduler.pump()
|
||||
|
||||
async def cancel(self, download_id: str) -> None:
|
||||
job = self._scheduler.get_job(download_id)
|
||||
if job is not None:
|
||||
job.request_cancel()
|
||||
return
|
||||
row = await asyncio.to_thread(queries.get_download, download_id)
|
||||
if row is None:
|
||||
raise DownloadError("NOT_FOUND", "No such download.", status=404)
|
||||
if row.status in _LIVE_STATUSES:
|
||||
import os
|
||||
|
||||
try:
|
||||
os.remove(row.temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
await asyncio.to_thread(
|
||||
queries.update_download, download_id, status=DownloadStatus.CANCELLED
|
||||
)
|
||||
|
||||
async def set_priority(self, download_id: str, priority: int) -> None:
|
||||
row = await asyncio.to_thread(queries.get_download, download_id)
|
||||
if row is None:
|
||||
raise DownloadError("NOT_FOUND", "No such download.", status=404)
|
||||
await asyncio.to_thread(
|
||||
queries.update_download, download_id, priority=priority
|
||||
)
|
||||
# Admission-order only; a higher priority is
|
||||
# picked up the next time a slot frees. Pump in case a slot is free now.
|
||||
await self._scheduler.pump()
|
||||
|
||||
async def delete(self, download_id: str) -> None:
|
||||
"""Delete a terminal download so it stays gone from history.
|
||||
|
||||
Refuses to delete a live download so a record is never removed out from
|
||||
under a running worker; cancel it first. Any leftover ``.part`` temp
|
||||
file (e.g. from a failed transfer) is removed, but the finished model
|
||||
file on disk is never touched.
|
||||
"""
|
||||
if self._scheduler.get_job(download_id) is not None:
|
||||
raise DownloadError(
|
||||
"DOWNLOAD_ACTIVE",
|
||||
"Cannot delete a download that is still in progress.",
|
||||
status=409,
|
||||
)
|
||||
row = await asyncio.to_thread(queries.get_download, download_id)
|
||||
if row is None:
|
||||
raise DownloadError("NOT_FOUND", "No such download.", status=404)
|
||||
if row.status in _LIVE_STATUSES:
|
||||
raise DownloadError(
|
||||
"DOWNLOAD_ACTIVE",
|
||||
"Cannot delete a download that is still in progress.",
|
||||
status=409,
|
||||
)
|
||||
|
||||
try:
|
||||
os.remove(row.temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
await asyncio.to_thread(queries.delete_download, download_id)
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Delete all terminal downloads from history in one transaction.
|
||||
|
||||
Skips anything still live (queued/active/paused/verifying, or a running
|
||||
job) so an in-flight download is never removed out from under a worker.
|
||||
Finished model files on disk are never touched; only leftover ``.part``
|
||||
temp files from failed/cancelled transfers are removed. Returns the
|
||||
number of history rows deleted.
|
||||
"""
|
||||
|
||||
rows = await asyncio.to_thread(queries.list_downloads)
|
||||
deletable = [
|
||||
r
|
||||
for r in rows
|
||||
if r.status not in _LIVE_STATUSES
|
||||
and self._scheduler.get_job(r.id) is None
|
||||
]
|
||||
if not deletable:
|
||||
return 0
|
||||
for r in deletable:
|
||||
try:
|
||||
os.remove(r.temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
return await asyncio.to_thread(
|
||||
queries.delete_downloads, [r.id for r in deletable]
|
||||
)
|
||||
|
||||
# ----- read models -----
|
||||
|
||||
def _view(self, row) -> dict:
|
||||
"""Combine the persisted row with live in-memory progress, if running."""
|
||||
job = self._scheduler.get_job(row.id)
|
||||
bytes_done = row.bytes_done
|
||||
total = row.total_bytes
|
||||
speed = None
|
||||
eta = None
|
||||
segments = None
|
||||
if job is not None:
|
||||
st = job.state
|
||||
bytes_done = st.bytes_done
|
||||
total = st.total_bytes if st.total_bytes is not None else total
|
||||
speed = st.speed_bps
|
||||
eta = st.eta_seconds
|
||||
segments = [
|
||||
{"idx": s.idx, "bytes_done": s.bytes_done, "length": s.length}
|
||||
for s in st.segments
|
||||
if s.end >= s.start
|
||||
]
|
||||
progress = (bytes_done / total) if total else None
|
||||
return {
|
||||
"download_id": row.id,
|
||||
"model_id": row.model_id,
|
||||
"url": redact_url(row.url),
|
||||
"status": row.status,
|
||||
"priority": row.priority,
|
||||
"total_bytes": total,
|
||||
"bytes_done": bytes_done,
|
||||
"progress": progress,
|
||||
"speed_bps": speed,
|
||||
"eta_seconds": eta,
|
||||
"segments": segments,
|
||||
"error": row.error,
|
||||
"created_at": row.created_at,
|
||||
"updated_at": row.updated_at,
|
||||
}
|
||||
|
||||
def _view_from_state(self, job) -> dict:
|
||||
"""Build a view purely from the live in-memory job state (no DB)."""
|
||||
st = job.state
|
||||
return {
|
||||
"download_id": st.download_id,
|
||||
"model_id": st.model_id,
|
||||
"url": redact_url(st.url),
|
||||
"status": st.status,
|
||||
"priority": st.priority,
|
||||
"total_bytes": st.total_bytes,
|
||||
"bytes_done": st.bytes_done,
|
||||
"progress": st.progress,
|
||||
"speed_bps": st.speed_bps,
|
||||
"eta_seconds": st.eta_seconds,
|
||||
"segments": [
|
||||
{"idx": s.idx, "bytes_done": s.bytes_done, "length": s.length}
|
||||
for s in st.segments
|
||||
if s.end >= s.start
|
||||
],
|
||||
"error": st.error,
|
||||
}
|
||||
|
||||
def status_sync(self, download_id: str) -> Optional[dict]:
|
||||
"""Synchronous status read for the websocket notify path.
|
||||
|
||||
Uses live in-memory state when the job is running (no DB round-trip on
|
||||
the hot path); falls back to a quick DB read otherwise.
|
||||
"""
|
||||
job = self._scheduler.get_job(download_id)
|
||||
if job is not None:
|
||||
return self._view_from_state(job)
|
||||
row = queries.get_download(download_id)
|
||||
return self._view(row) if row is not None else None
|
||||
|
||||
async def status(self, download_id: str) -> Optional[dict]:
|
||||
row = await asyncio.to_thread(queries.get_download, download_id)
|
||||
return self._view(row) if row is not None else None
|
||||
|
||||
async def list(self) -> list[dict]:
|
||||
rows = await asyncio.to_thread(queries.list_downloads)
|
||||
return [self._view(r) for r in rows]
|
||||
|
||||
async def availability(self, models: dict[str, str]) -> dict[str, dict]:
|
||||
"""Bulk per-id ``{state, progress, ...}`` for the frontend poll.
|
||||
|
||||
``state`` is ``available`` (on disk), ``downloading`` (live row), or
|
||||
``missing``. Cheap: a path lookup plus an in-memory/DB status check.
|
||||
"""
|
||||
rows = await asyncio.to_thread(queries.list_downloads)
|
||||
by_model: dict[str, object] = {}
|
||||
for r in rows:
|
||||
if r.status in _LIVE_STATUSES or r.model_id not in by_model:
|
||||
by_model[r.model_id] = r
|
||||
|
||||
# ``url_allowed`` mirrors the coarse enqueue gate (host/scheme + a
|
||||
# non-disallowed extension); URLs whose extension is only known after a
|
||||
# network resolve — e.g. Civitai download endpoints — report allowed.
|
||||
out: dict[str, dict] = {}
|
||||
for model_id, url in models.items():
|
||||
try:
|
||||
exists = await asyncio.to_thread(paths.resolve_existing, model_id)
|
||||
except InvalidModelId:
|
||||
out[model_id] = {"state": "missing", "url_allowed": is_url_downloadable(url)}
|
||||
continue
|
||||
if exists:
|
||||
out[model_id] = {"state": "available", "url_allowed": is_url_downloadable(url)}
|
||||
continue
|
||||
row = by_model.get(model_id)
|
||||
if row is not None and row.status in _LIVE_STATUSES:
|
||||
view = self._view(row)
|
||||
out[model_id] = {
|
||||
"state": "downloading",
|
||||
"url_allowed": is_url_downloadable(url),
|
||||
"download_id": view["download_id"],
|
||||
"progress": view["progress"],
|
||||
"bytes_done": view["bytes_done"],
|
||||
"total_bytes": view["total_bytes"],
|
||||
"speed_bps": view["speed_bps"],
|
||||
}
|
||||
else:
|
||||
out[model_id] = {"state": "missing", "url_allowed": is_url_downloadable(url)}
|
||||
return out
|
||||
|
||||
|
||||
DOWNLOAD_MANAGER = DownloadManager()
|
||||
148
app/model_downloader/net/http.py
Normal file
148
app/model_downloader/net/http.py
Normal file
@ -0,0 +1,148 @@
|
||||
"""Manual, validated redirect-following request opener.
|
||||
|
||||
Automatic redirects are disabled. We follow hops ourselves
|
||||
so that on *every* hop we (a) re-validate scheme + reject credentials-in-URL,
|
||||
(b) recompute which stored credential — if any — applies to that hop's host,
|
||||
and (c) let the connector's resolver screen the IP. This is the single place
|
||||
that attaches credentials, so a token can never ride a redirect to a CDN host.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncIterator, Optional
|
||||
from urllib.parse import unquote, urljoin, urlsplit, urlunsplit
|
||||
|
||||
import aiohttp
|
||||
|
||||
from app.model_downloader.credentials.resolver import resolve_auth_for_hop
|
||||
from app.model_downloader.net.session import get_session
|
||||
from app.model_downloader.security.ssrf import (
|
||||
MAX_REDIRECTS,
|
||||
SSRFError,
|
||||
check_redirect_hop,
|
||||
)
|
||||
|
||||
_REDIRECT_CODES = {301, 302, 303, 307, 308}
|
||||
DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=None, sock_connect=30, sock_read=120)
|
||||
|
||||
|
||||
def redact_url(url: str) -> str:
|
||||
"""Drop the query string so a query-scheme secret is never logged/stored."""
|
||||
try:
|
||||
parts = urlsplit(url)
|
||||
except ValueError:
|
||||
return "<unparseable-url>"
|
||||
return urlunsplit(parts._replace(query=""))
|
||||
|
||||
|
||||
_CD_FILENAME_STAR = re.compile(
|
||||
r"filename\*\s*=\s*[^']*'[^']*'([^;]+)", re.IGNORECASE
|
||||
)
|
||||
_CD_FILENAME_QUOTED = re.compile(r'filename\s*=\s*"([^"]+)"', re.IGNORECASE)
|
||||
_CD_FILENAME_BARE = re.compile(r"filename\s*=\s*([^;]+)", re.IGNORECASE)
|
||||
|
||||
|
||||
def filename_from_content_disposition(value: Optional[str]) -> Optional[str]:
|
||||
"""Extract the download filename from a ``Content-Disposition`` header.
|
||||
|
||||
Prefers the RFC 5987 ``filename*=`` form (percent-decoded) over the plain
|
||||
``filename=`` form. Any directory components in the value are stripped so a
|
||||
hostile header can only influence the *name*, never the target directory.
|
||||
Returns ``None`` when no filename is present.
|
||||
"""
|
||||
if not value:
|
||||
return None
|
||||
for pat, decode in (
|
||||
(_CD_FILENAME_STAR, True),
|
||||
(_CD_FILENAME_QUOTED, False),
|
||||
(_CD_FILENAME_BARE, False),
|
||||
):
|
||||
m = pat.search(value)
|
||||
if not m:
|
||||
continue
|
||||
raw = m.group(1).strip().strip('"')
|
||||
if decode:
|
||||
try:
|
||||
raw = unquote(raw)
|
||||
except Exception:
|
||||
pass
|
||||
name = raw.replace("\\", "/").rsplit("/", 1)[-1].strip()
|
||||
if name:
|
||||
return name
|
||||
return None
|
||||
|
||||
|
||||
async def _resolve_final_response(
|
||||
method: str,
|
||||
url: str,
|
||||
credential_id: Optional[str],
|
||||
base_headers: dict[str, str],
|
||||
timeout: aiohttp.ClientTimeout,
|
||||
) -> tuple[aiohttp.ClientResponse, str]:
|
||||
"""Follow redirects manually until a non-redirect response.
|
||||
|
||||
Each intermediate redirect response is released before the next hop.
|
||||
Returns the final ``(response, final_url)``; the caller owns releasing it.
|
||||
"""
|
||||
session = await get_session()
|
||||
current = url
|
||||
hops = 0
|
||||
while True:
|
||||
check_redirect_hop(current, is_initial_url=(hops == 0))
|
||||
parts = urlsplit(current)
|
||||
auth = await resolve_auth_for_hop(
|
||||
parts.hostname or "", parts.scheme, explicit_credential_id=credential_id
|
||||
)
|
||||
req_headers = dict(base_headers)
|
||||
req_url = current
|
||||
if auth is not None:
|
||||
req_headers.update(auth.headers)
|
||||
req_url = auth.apply_to_url(current)
|
||||
|
||||
resp = await session.request(
|
||||
method,
|
||||
req_url,
|
||||
allow_redirects=False,
|
||||
headers=req_headers,
|
||||
timeout=timeout,
|
||||
)
|
||||
if resp.status in _REDIRECT_CODES and resp.headers.get("Location"):
|
||||
next_url = urljoin(str(resp.url), resp.headers["Location"])
|
||||
await resp.release()
|
||||
hops += 1
|
||||
if hops > MAX_REDIRECTS:
|
||||
raise SSRFError(
|
||||
f"too many redirects (> {MAX_REDIRECTS}) for {redact_url(url)}"
|
||||
)
|
||||
current = next_url
|
||||
continue
|
||||
return resp, redact_url(str(resp.url))
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def open_validated(
|
||||
method: str,
|
||||
url: str,
|
||||
*,
|
||||
credential_id: Optional[str] = None,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
timeout: aiohttp.ClientTimeout = DEFAULT_TIMEOUT,
|
||||
) -> AsyncIterator[tuple[aiohttp.ClientResponse, str]]:
|
||||
"""Open ``method url`` following redirects manually and validated.
|
||||
|
||||
Yields ``(response, final_url)`` where ``final_url`` is redacted of any
|
||||
query string. The response is released automatically on exit.
|
||||
"""
|
||||
resp, final_url = await _resolve_final_response(
|
||||
method, url, credential_id, dict(headers or {}), timeout
|
||||
)
|
||||
try:
|
||||
yield resp, final_url
|
||||
finally:
|
||||
try:
|
||||
await resp.release()
|
||||
except Exception: # pragma: no cover - best-effort cleanup
|
||||
logging.debug("[model_downloader] response release error", exc_info=True)
|
||||
116
app/model_downloader/net/probe.py
Normal file
116
app/model_downloader/net/probe.py
Normal file
@ -0,0 +1,116 @@
|
||||
"""Pre-download probe.
|
||||
|
||||
Issues a tiny ranged GET (``Range: bytes=0-0``) — which doubles as a
|
||||
range-support test — to discover ``Content-Length``, ``Accept-Ranges``,
|
||||
``ETag``/``Last-Modified``, and the final post-redirect URL. For HuggingFace
|
||||
LFS files the true size also appears in the non-standard ``X-Linked-Size``
|
||||
header, which we read as a fallback.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse, urlsplit
|
||||
|
||||
import aiohttp
|
||||
|
||||
from app.model_downloader.net.http import (
|
||||
filename_from_content_disposition,
|
||||
open_validated,
|
||||
)
|
||||
from app.model_downloader.net.session import parse_int_header
|
||||
|
||||
_PROBE_TIMEOUT = aiohttp.ClientTimeout(total=60, sock_connect=30, sock_read=30)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProbeResult:
|
||||
ok: bool
|
||||
status: int
|
||||
final_url: Optional[str] = None
|
||||
total_bytes: Optional[int] = None
|
||||
accept_ranges: bool = False
|
||||
etag: Optional[str] = None
|
||||
last_modified: Optional[str] = None
|
||||
gated: bool = False # 401/403 — needs (or has wrong) credentials
|
||||
error: Optional[str] = None
|
||||
# Filename the server intends this response to be saved as: the
|
||||
# ``Content-Disposition`` name if present, else the post-redirect URL's
|
||||
# basename. Used to resolve the real extension for URLs (e.g. Civitai's
|
||||
# ``/api/download`` endpoints) that carry no extension in their path.
|
||||
filename: Optional[str] = None
|
||||
|
||||
|
||||
def _total_from_content_range(value: Optional[str]) -> Optional[int]:
|
||||
# "bytes 0-0/12345" -> 12345 ; "bytes 0-0/*" -> None
|
||||
if not value or "/" not in value:
|
||||
return None
|
||||
total = value.rsplit("/", 1)[1].strip()
|
||||
return parse_int_header(total)
|
||||
|
||||
|
||||
def _filename_from_response(
|
||||
content_disposition: Optional[str], final_url: Optional[str]
|
||||
) -> Optional[str]:
|
||||
name = filename_from_content_disposition(content_disposition)
|
||||
if name:
|
||||
return name
|
||||
if final_url:
|
||||
base = urlsplit(final_url).path.rsplit("/", 1)[-1]
|
||||
if base:
|
||||
return base
|
||||
return None
|
||||
|
||||
|
||||
async def probe(url: str, *, credential_id: Optional[str] = None) -> ProbeResult:
|
||||
"""Probe ``url`` and return discovered metadata, failing soft."""
|
||||
try:
|
||||
async with open_validated(
|
||||
"GET",
|
||||
url,
|
||||
credential_id=credential_id,
|
||||
headers={"Range": "bytes=0-0", "Accept-Encoding": "identity"},
|
||||
timeout=_PROBE_TIMEOUT,
|
||||
) as (resp, final_url):
|
||||
if resp.status in (401, 403):
|
||||
return ProbeResult(
|
||||
ok=False, status=resp.status, final_url=final_url, gated=True,
|
||||
error=f"host returned {resp.status} (authentication required)",
|
||||
)
|
||||
if resp.status not in (200, 206):
|
||||
return ProbeResult(
|
||||
ok=False, status=resp.status, final_url=final_url,
|
||||
error=f"probe returned HTTP {resp.status}",
|
||||
)
|
||||
|
||||
headers = resp.headers
|
||||
accept_ranges = False
|
||||
total: Optional[int] = None
|
||||
if resp.status == 206:
|
||||
accept_ranges = True
|
||||
total = _total_from_content_range(headers.get("Content-Range"))
|
||||
else: # 200: server ignored the range
|
||||
accept_ranges = headers.get("Accept-Ranges", "").lower() == "bytes"
|
||||
total = parse_int_header(headers.get("Content-Length"))
|
||||
|
||||
if total is None:
|
||||
total = parse_int_header(headers.get("X-Linked-Size"))
|
||||
|
||||
return ProbeResult(
|
||||
ok=True,
|
||||
status=resp.status,
|
||||
final_url=final_url,
|
||||
total_bytes=total,
|
||||
accept_ranges=accept_ranges,
|
||||
etag=headers.get("ETag"),
|
||||
last_modified=headers.get("Last-Modified"),
|
||||
filename=_filename_from_response(
|
||||
headers.get("Content-Disposition"), final_url
|
||||
),
|
||||
)
|
||||
except Exception as e: # network / SSRF / timeout
|
||||
host = urlparse(url).netloc or "<unknown>"
|
||||
logging.debug("[model_downloader] probe failed for %s: %s", host, type(e).__name__)
|
||||
return ProbeResult(ok=False, status=0, error="probe failed: network error")
|
||||
72
app/model_downloader/net/session.py
Normal file
72
app/model_downloader/net/session.py
Normal file
@ -0,0 +1,72 @@
|
||||
"""Lazily-created shared :class:`aiohttp.ClientSession`.
|
||||
|
||||
A single session reuses TLS handshakes and TCP connections across the probe
|
||||
and the many segment GETs to the same host (HuggingFace is the dominant
|
||||
case), which is a large speedup on cold connections and exactly the
|
||||
connection-reuse strategy that lets us match aria2c.
|
||||
|
||||
The connector uses :class:`ValidatingResolver` so every connection — initial
|
||||
or post-redirect — is screened for private/special-use IPs at connect time.
|
||||
TLS is pinned to certifi's CA bundle because the OS trust store is not wired
|
||||
up on some Python installs (python.org macOS, slim containers).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import ssl
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
try:
|
||||
import certifi
|
||||
_CA_FILE = certifi.where()
|
||||
except Exception: # pragma: no cover - certifi is a transitive dep of aiohttp
|
||||
_CA_FILE = None
|
||||
|
||||
from comfy.cli_args import args
|
||||
from app.model_downloader.security.ssrf import ValidatingResolver
|
||||
|
||||
_session: Optional[aiohttp.ClientSession] = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def ssl_context() -> ssl.SSLContext:
|
||||
if _CA_FILE is not None:
|
||||
return ssl.create_default_context(cafile=_CA_FILE)
|
||||
return ssl.create_default_context()
|
||||
|
||||
|
||||
async def get_session() -> aiohttp.ClientSession:
|
||||
"""Return the shared session, creating it on first use."""
|
||||
global _session
|
||||
if _session is not None and not _session.closed:
|
||||
return _session
|
||||
async with _lock:
|
||||
if _session is None or _session.closed:
|
||||
connector = aiohttp.TCPConnector(
|
||||
limit_per_host=max(1, getattr(args, "download_max_connections_per_host", 16)),
|
||||
ssl=ssl_context(),
|
||||
resolver=ValidatingResolver(),
|
||||
)
|
||||
_session = aiohttp.ClientSession(connector=connector)
|
||||
return _session
|
||||
|
||||
|
||||
async def close_session() -> None:
|
||||
global _session
|
||||
if _session is not None and not _session.closed:
|
||||
await _session.close()
|
||||
_session = None
|
||||
|
||||
|
||||
def parse_int_header(value: Optional[str]) -> Optional[int]:
|
||||
"""Parse a non-negative integer header value, or None if bad/absent."""
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
n = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
return n if n >= 0 else None
|
||||
177
app/model_downloader/scheduler.py
Normal file
177
app/model_downloader/scheduler.py
Normal file
@ -0,0 +1,177 @@
|
||||
"""Priority scheduler + lifecycle.
|
||||
|
||||
Owns the set of running jobs and admits queued downloads up to a global
|
||||
concurrency limit (K), highest priority first, FIFO within a priority. Runs
|
||||
entirely on the existing ComfyUI asyncio loop; blocking work (disk, hashing,
|
||||
DB) is offloaded by the job/writer layers.
|
||||
|
||||
On startup it reconciles DB vs. disk: ``active``/``verifying`` rows left by a
|
||||
previous run are reset to ``queued`` and resumed from persisted offsets, and
|
||||
orphaned ``.part`` files with no live download row are swept.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
|
||||
from comfy.cli_args import args
|
||||
from app.model_downloader.constants import DownloadStatus
|
||||
from app.model_downloader.database import queries
|
||||
from app.model_downloader.engine.job import DownloadJob, JobSpec
|
||||
from app.model_downloader.security import paths
|
||||
|
||||
# Backoff for retryable failures
|
||||
_BACKOFF_BASE = 2.0
|
||||
_BACKOFF_CAP = 300.0
|
||||
_MAX_ATTEMPTS = 6
|
||||
|
||||
|
||||
class Scheduler:
|
||||
def __init__(self) -> None:
|
||||
self._jobs: dict[str, DownloadJob] = {}
|
||||
self._tasks: dict[str, asyncio.Task] = {}
|
||||
self._backoff_until: dict[str, float] = {}
|
||||
self._pump_lock = asyncio.Lock()
|
||||
self._notify_cb: Optional[Callable[[str], None]] = None
|
||||
self._started = False
|
||||
|
||||
@property
|
||||
def max_active(self) -> int:
|
||||
return max(1, getattr(args, "download_max_active", 3))
|
||||
|
||||
def set_notify(self, cb: Optional[Callable[[str], None]]) -> None:
|
||||
self._notify_cb = cb
|
||||
|
||||
def get_job(self, download_id: str) -> Optional[DownloadJob]:
|
||||
return self._jobs.get(download_id)
|
||||
|
||||
def is_active(self, download_id: str) -> bool:
|
||||
return download_id in self._tasks
|
||||
|
||||
# ----- startup -----
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._started:
|
||||
return
|
||||
self._started = True
|
||||
try:
|
||||
await asyncio.to_thread(queries.reconcile_live_downloads)
|
||||
await asyncio.to_thread(self._sweep_orphan_temp_files)
|
||||
except Exception as e:
|
||||
logging.warning("[model_downloader] startup reconcile failed: %s", e)
|
||||
await self.pump()
|
||||
|
||||
@staticmethod
|
||||
def _sweep_orphan_temp_files() -> None:
|
||||
"""Remove ``.part`` files not referenced by a resumable download row.
|
||||
|
||||
Resumable partials are preserved; only truly orphaned temp files from
|
||||
crashed runs are deleted. ``FAILED`` is included because
|
||||
:meth:`DownloadManager.resume` explicitly permits resuming a
|
||||
retry-exhausted failed row: deleting its partial here while the
|
||||
per-segment offsets survive in the DB would make the next resume
|
||||
preallocate a fresh sparse file, skip every "complete" segment, and
|
||||
leave zero-filled holes that pass the size-only verification gate.
|
||||
"""
|
||||
live = {
|
||||
row.temp_path
|
||||
for row in queries.list_downloads()
|
||||
if row.status
|
||||
in (
|
||||
DownloadStatus.QUEUED,
|
||||
DownloadStatus.PAUSED,
|
||||
DownloadStatus.FAILED,
|
||||
)
|
||||
}
|
||||
for path in paths.iter_all_tmp_paths():
|
||||
if path in live:
|
||||
continue
|
||||
try:
|
||||
os.remove(path)
|
||||
logging.info("[model_downloader] removed orphan temp file: %s", path)
|
||||
except OSError as e:
|
||||
logging.warning("[model_downloader] could not remove %s: %s", path, e)
|
||||
|
||||
# ----- admission -----
|
||||
|
||||
async def pump(self) -> None:
|
||||
async with self._pump_lock:
|
||||
slots = self.max_active - len(self._tasks)
|
||||
if slots <= 0:
|
||||
return
|
||||
now = time.monotonic()
|
||||
candidates = await asyncio.to_thread(queries.list_queued_downloads)
|
||||
for row in candidates:
|
||||
if slots <= 0:
|
||||
break
|
||||
if row.id in self._tasks:
|
||||
continue
|
||||
if self._backoff_until.get(row.id, 0.0) > now:
|
||||
continue
|
||||
self._admit(row)
|
||||
slots -= 1
|
||||
|
||||
def _admit(self, row) -> None:
|
||||
spec = JobSpec(
|
||||
download_id=row.id,
|
||||
url=row.url,
|
||||
model_id=row.model_id,
|
||||
dest_path=row.dest_path,
|
||||
temp_path=row.temp_path,
|
||||
priority=row.priority,
|
||||
credential_id=row.credential_id,
|
||||
expected_sha256=row.expected_sha256,
|
||||
allow_any_extension=row.allow_any_extension,
|
||||
etag=row.etag,
|
||||
attempts=row.attempts,
|
||||
)
|
||||
job = DownloadJob(spec, notify_cb=self._notify_cb)
|
||||
self._jobs[row.id] = job
|
||||
self._tasks[row.id] = asyncio.ensure_future(self._run_job(job))
|
||||
|
||||
async def _run_job(self, job: DownloadJob) -> None:
|
||||
download_id = job.spec.download_id
|
||||
status = DownloadStatus.FAILED
|
||||
try:
|
||||
status = await job.run()
|
||||
except Exception as e: # run() is defensive, but never let a task die silently
|
||||
logging.error("[model_downloader] job %s crashed: %s", download_id, e)
|
||||
queries.update_download(
|
||||
download_id,
|
||||
status=DownloadStatus.FAILED,
|
||||
error=f"internal error: {e}",
|
||||
)
|
||||
if self._notify_cb:
|
||||
self._notify_cb(download_id)
|
||||
finally:
|
||||
self._tasks.pop(download_id, None)
|
||||
self._jobs.pop(download_id, None)
|
||||
|
||||
if status == DownloadStatus.QUEUED:
|
||||
if job.spec.attempts >= _MAX_ATTEMPTS:
|
||||
queries.update_download(
|
||||
download_id,
|
||||
status=DownloadStatus.FAILED,
|
||||
error=f"giving up after {job.spec.attempts} attempts",
|
||||
)
|
||||
if self._notify_cb:
|
||||
self._notify_cb(download_id)
|
||||
else:
|
||||
delay = min(
|
||||
_BACKOFF_CAP, _BACKOFF_BASE ** job.spec.attempts
|
||||
) + random.uniform(0, 1.0)
|
||||
self._backoff_until[download_id] = time.monotonic() + delay
|
||||
asyncio.ensure_future(self._delayed_pump(delay))
|
||||
await self.pump()
|
||||
|
||||
async def _delayed_pump(self, delay: float) -> None:
|
||||
await asyncio.sleep(delay)
|
||||
await self.pump()
|
||||
|
||||
|
||||
SCHEDULER = Scheduler()
|
||||
140
app/model_downloader/security/allowlist.py
Normal file
140
app/model_downloader/security/allowlist.py
Normal file
@ -0,0 +1,140 @@
|
||||
"""URL allowlist for server-side model fetches.
|
||||
|
||||
Default-deny. A URL is downloadable only when its parsed host + scheme are
|
||||
allowlisted AND (unless explicitly relaxed) its final filename ends in a
|
||||
known model extension.
|
||||
|
||||
The built-in host defaults mirror the frontend's ``isModelDownloadable``
|
||||
allowlist so the two flows agree on what is eligible; ``--download-allowed-hosts``
|
||||
extends it for self-hosted mirrors. Matching is done on ``urlparse().hostname``
|
||||
(never a raw string prefix) so userinfo tricks like
|
||||
``http://127.0.0.1@169.254.169.254/x.safetensors`` — whose real host is the
|
||||
metadata IP — cannot slip past.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
# host -> set of allowed schemes. Frontend parity (HuggingFace / Civitai /
|
||||
# localhost). Extra hosts from --download-allowed-hosts are https-only.
|
||||
_DEFAULT_ALLOWED_HOSTS: dict[str, set[str]] = {
|
||||
"huggingface.co": {"https"},
|
||||
"civitai.com": {"https"},
|
||||
"localhost": {"http", "https"},
|
||||
"127.0.0.1": {"http", "https"},
|
||||
}
|
||||
|
||||
# Hosts for which loopback addresses are intentionally permitted (the localhost
|
||||
# "download a local model" feature). Every other host's loopback resolution is
|
||||
# rejected by the SSRF resolver.
|
||||
LOOPBACK_HOSTS = frozenset({"localhost", "127.0.0.1", "::1"})
|
||||
|
||||
# Known model file extensions (frontend parity). Checked on the final filename.
|
||||
ALLOWED_MODEL_EXTENSIONS = (
|
||||
".safetensors",
|
||||
".sft",
|
||||
".ckpt",
|
||||
".pth",
|
||||
".pt",
|
||||
".gguf",
|
||||
".bin",
|
||||
)
|
||||
|
||||
|
||||
def _allowed_hosts() -> dict[str, set[str]]:
|
||||
hosts = {h: set(s) for h, s in _DEFAULT_ALLOWED_HOSTS.items()}
|
||||
for extra in getattr(args, "download_allowed_hosts", []) or []:
|
||||
host = extra.strip().lower()
|
||||
if host:
|
||||
hosts.setdefault(host, set()).add("https")
|
||||
return hosts
|
||||
|
||||
|
||||
def is_host_allowed(host: str | None, scheme: str | None) -> bool:
|
||||
"""True iff ``host`` is allowlisted for ``scheme``.
|
||||
|
||||
Used both for the initial URL and re-checked on every redirect hop,
|
||||
so a whitelisted URL cannot 30x into an off-list host.
|
||||
"""
|
||||
if not host or not scheme:
|
||||
return False
|
||||
allowed = _allowed_hosts().get(host.lower())
|
||||
return allowed is not None and scheme.lower() in allowed
|
||||
|
||||
|
||||
def has_allowed_extension(path: str, allow_any_extension: bool = False) -> bool:
|
||||
if allow_any_extension:
|
||||
return True
|
||||
return path.lower().endswith(ALLOWED_MODEL_EXTENSIONS)
|
||||
|
||||
|
||||
def filename_extension(name: str) -> str:
|
||||
"""Lowercased extension (including the leading dot) of a bare filename.
|
||||
|
||||
Returns ``""`` when there is no extension. A leading-dot name
|
||||
(``.safetensors``) is treated as having no extension (all stem), matching
|
||||
``os.path.splitext`` semantics so dotfiles aren't mistaken for typed files.
|
||||
"""
|
||||
base = name.replace("\\", "/").rsplit("/", 1)[-1]
|
||||
dot = base.rfind(".")
|
||||
if dot <= 0:
|
||||
return ""
|
||||
return base[dot:].lower()
|
||||
|
||||
|
||||
def is_allowed_extension_name(name: str) -> bool:
|
||||
"""True iff ``name`` ends in one of the known model extensions."""
|
||||
return name.lower().endswith(ALLOWED_MODEL_EXTENSIONS)
|
||||
|
||||
|
||||
def is_host_allowed_url(url: str) -> bool:
|
||||
"""True iff ``url`` parses and its host+scheme are allowlisted."""
|
||||
if not isinstance(url, str) or not url:
|
||||
return False
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
except ValueError:
|
||||
return False
|
||||
return is_host_allowed(parsed.hostname, parsed.scheme)
|
||||
|
||||
|
||||
def url_path_extension(url: str) -> str:
|
||||
"""Extension of the URL *path* basename (query ignored), or ``""``."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
except ValueError:
|
||||
return ""
|
||||
return filename_extension(parsed.path)
|
||||
|
||||
|
||||
def is_url_downloadable(url: str) -> bool:
|
||||
"""Coarse enqueue gate: host/scheme allowed and extension not disallowed.
|
||||
|
||||
Unlike :func:`is_url_allowed` (which demands a known extension *in the URL*),
|
||||
this also admits URLs whose path carries no extension at all — e.g. a Civitai
|
||||
``/api/download/models/<id>`` endpoint whose real filename only shows up in
|
||||
the redirect target / ``Content-Disposition``. The true extension is then
|
||||
resolved from the network and re-validated before the download is admitted.
|
||||
A path bearing an explicit *non-model* extension (``.zip``, ``.html``, ...)
|
||||
is still rejected here.
|
||||
"""
|
||||
if not is_host_allowed_url(url):
|
||||
return False
|
||||
ext = url_path_extension(url)
|
||||
return ext == "" or ext in ALLOWED_MODEL_EXTENSIONS
|
||||
|
||||
|
||||
def is_url_allowed(url: str, allow_any_extension: bool = False) -> bool:
|
||||
"""Check whether ``url`` is permitted as a server-side download source."""
|
||||
if not isinstance(url, str) or not url:
|
||||
return False
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
except ValueError:
|
||||
return False
|
||||
if not is_host_allowed(parsed.hostname, parsed.scheme):
|
||||
return False
|
||||
return has_allowed_extension(parsed.path, allow_any_extension)
|
||||
132
app/model_downloader/security/paths.py
Normal file
132
app/model_downloader/security/paths.py
Normal file
@ -0,0 +1,132 @@
|
||||
"""Path resolution + traversal safety for downloads.
|
||||
|
||||
A ``model_id`` is a *relative destination path* of the form
|
||||
``<directory>/<filename>`` (e.g. ``loras/my_lora.safetensors``). This module
|
||||
turns one into an absolute on-disk path under one of ComfyUI's registered
|
||||
model folders, rejecting unknown folders, path traversal, and symlink escape.
|
||||
This is the only thing that composes destination paths, so the engine never
|
||||
touches user-supplied path strings directly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Iterator, Optional
|
||||
|
||||
import folder_paths
|
||||
|
||||
from app.model_downloader.constants import TMP_SUFFIX
|
||||
from app.model_downloader.security.allowlist import ALLOWED_MODEL_EXTENSIONS
|
||||
|
||||
# A model_id component is a single path segment of safe characters — no slashes,
|
||||
# no "..", no leading dots that could escape the target directory.
|
||||
_SEGMENT_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$")
|
||||
|
||||
|
||||
class InvalidModelId(ValueError):
|
||||
"""Raised when a model_id is malformed or names an unknown model folder."""
|
||||
|
||||
|
||||
def parse_model_id(model_id: str, allow_any_extension: bool = False) -> tuple[str, str]:
|
||||
"""Split ``<directory>/<filename>`` and validate both components.
|
||||
|
||||
Returns ``(directory, filename)``. Does not touch the filesystem.
|
||||
"""
|
||||
if not isinstance(model_id, str) or "/" not in model_id:
|
||||
raise InvalidModelId(
|
||||
f"model_id must be '<directory>/<filename>', got {model_id!r}"
|
||||
)
|
||||
directory, _, filename = model_id.partition("/")
|
||||
if "/" in filename or not directory or not filename:
|
||||
raise InvalidModelId(
|
||||
f"model_id must have exactly one '/' separator, got {model_id!r}"
|
||||
)
|
||||
if not _SEGMENT_RE.match(directory):
|
||||
raise InvalidModelId(f"invalid directory segment {directory!r}")
|
||||
if not _SEGMENT_RE.match(filename):
|
||||
raise InvalidModelId(f"invalid filename segment {filename!r}")
|
||||
if not allow_any_extension and not filename.lower().endswith(
|
||||
ALLOWED_MODEL_EXTENSIONS
|
||||
):
|
||||
raise InvalidModelId(
|
||||
f"filename must end with a known model extension "
|
||||
f"{ALLOWED_MODEL_EXTENSIONS}, got {filename!r}"
|
||||
)
|
||||
if directory not in folder_paths.folder_names_and_paths:
|
||||
raise InvalidModelId(f"unknown model folder {directory!r}")
|
||||
return directory, filename
|
||||
|
||||
|
||||
def apply_extension(model_id: str, ext: str) -> str:
|
||||
"""Return ``model_id`` with its filename forced to end in ``ext``.
|
||||
|
||||
``ext`` includes the leading dot (e.g. ``".safetensors"``). If the filename
|
||||
already ends in a *known model extension* it is replaced; otherwise ``ext``
|
||||
is appended (so ``loras/mymodel`` -> ``loras/mymodel.safetensors`` and
|
||||
``loras/mymodel.ckpt`` -> ``loras/mymodel.safetensors``). A filename with a
|
||||
non-model suffix (``my.model.v2``) is treated as an extensionless stem and
|
||||
``ext`` is appended. The directory part is left untouched; validation is
|
||||
still the caller's job via :func:`parse_model_id`.
|
||||
"""
|
||||
directory, sep, filename = model_id.partition("/")
|
||||
if not sep:
|
||||
return model_id # malformed; parse_model_id will reject it
|
||||
low = filename.lower()
|
||||
for known in ALLOWED_MODEL_EXTENSIONS:
|
||||
if low.endswith(known):
|
||||
filename = filename[: -len(known)]
|
||||
break
|
||||
return f"{directory}{sep}{filename}{ext}"
|
||||
|
||||
|
||||
def resolve_existing(model_id: str, allow_any_extension: bool = False) -> Optional[str]:
|
||||
"""Return the absolute path of an installed model, or None if missing.
|
||||
|
||||
Honours ``extra_model_paths.yaml`` transparently via ``get_full_path``.
|
||||
"""
|
||||
directory, filename = parse_model_id(model_id, allow_any_extension)
|
||||
return folder_paths.get_full_path(directory, filename)
|
||||
|
||||
|
||||
def resolve_destination(
|
||||
model_id: str, allow_any_extension: bool = False
|
||||
) -> tuple[str, str]:
|
||||
"""Return ``(final_path, temp_path)`` for a download.
|
||||
|
||||
Downloads land at the first registered path for the model's directory
|
||||
(the "primary" location). ``temp_path`` is a sibling ``.part`` file that
|
||||
is atomically renamed onto ``final_path`` on success. The result is
|
||||
asserted to stay within the registered root (defence in depth on top of
|
||||
the segment regex).
|
||||
"""
|
||||
directory, filename = parse_model_id(model_id, allow_any_extension)
|
||||
roots = folder_paths.get_folder_paths(directory)
|
||||
if not roots:
|
||||
raise InvalidModelId(f"no on-disk path registered for folder {directory!r}")
|
||||
root = os.path.realpath(roots[0])
|
||||
final_path = os.path.realpath(os.path.join(root, filename))
|
||||
if final_path != root and not final_path.startswith(root + os.sep):
|
||||
raise InvalidModelId(f"resolved path escapes model root: {model_id!r}")
|
||||
temp_path = f"{final_path}{TMP_SUFFIX}"
|
||||
return final_path, temp_path
|
||||
|
||||
|
||||
def iter_all_tmp_paths() -> Iterator[str]:
|
||||
"""Yield this subsystem's temp files under every registered model folder.
|
||||
|
||||
Matches only the distinctive ``TMP_SUFFIX`` so the startup orphan sweep
|
||||
can never delete temp files created by other tools.
|
||||
"""
|
||||
seen_roots: set[str] = set()
|
||||
for directory in list(folder_paths.folder_names_and_paths.keys()):
|
||||
for root in folder_paths.get_folder_paths(directory):
|
||||
if root in seen_roots or not os.path.isdir(root):
|
||||
continue
|
||||
seen_roots.add(root)
|
||||
try:
|
||||
for entry in os.scandir(root):
|
||||
if entry.is_file() and entry.name.endswith(TMP_SUFFIX):
|
||||
yield entry.path
|
||||
except OSError:
|
||||
continue
|
||||
163
app/model_downloader/security/ssrf.py
Normal file
163
app/model_downloader/security/ssrf.py
Normal file
@ -0,0 +1,163 @@
|
||||
"""SSRF / exfiltration defenses.
|
||||
|
||||
Two cooperating layers:
|
||||
|
||||
1. :class:`ValidatingResolver` is installed on the shared connector. Every
|
||||
connection — the initial probe and every segment GET, including ones made
|
||||
after a redirect — resolves its host through this resolver, which rejects
|
||||
any address that lands on a private / special-use IP range. Because the
|
||||
resolve and the connect happen together inside the connector, there is no
|
||||
check-then-connect window for DNS rebinding to exploit.
|
||||
|
||||
2. :func:`check_redirect_hop` re-validates every hop. The host allowlist gates
|
||||
only the *initial* user-supplied URL (anti-SSRF for arbitrary input);
|
||||
legitimate downloads from allowlisted origins redirect to presigned CDN
|
||||
hosts that are deliberately NOT on the allowlist (HF ->
|
||||
``cdn-lfs*.huggingface.co``, Civitai -> signed Cloudflare/S3), so hops are
|
||||
instead screened for scheme, embedded credentials, and — via the resolver
|
||||
above — private IPs. Credentials are only ever attached when a hop's host
|
||||
exactly matches a stored credential, so they are dropped on the CDN hop.
|
||||
Loopback (the "download a local model" feature) is exempt from IP filtering
|
||||
only for the initial URL: a *redirect* may never target a loopback host or
|
||||
a blocked IP-literal, which the resolver alone can't enforce (it exempts
|
||||
loopback literals and never sees IP literals through DNS).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from aiohttp.abc import AbstractResolver
|
||||
from aiohttp.resolver import DefaultResolver
|
||||
|
||||
from app.model_downloader.security.allowlist import LOOPBACK_HOSTS
|
||||
|
||||
# Cap the redirect chain length a hop may use.
|
||||
MAX_REDIRECTS = 5
|
||||
|
||||
|
||||
class SSRFError(Exception):
|
||||
"""A hop failed an SSRF / allowlist check."""
|
||||
|
||||
|
||||
def is_scheme_allowed(scheme: str | None, host: str | None) -> bool:
|
||||
"""True iff ``scheme`` is permitted for ``host`` on a download hop.
|
||||
|
||||
https is always allowed; plain http only for loopback/approved dev hosts.
|
||||
"""
|
||||
if not scheme:
|
||||
return False
|
||||
scheme = scheme.lower()
|
||||
if scheme == "https":
|
||||
return True
|
||||
if scheme == "http":
|
||||
return bool(host) and host.lower() in LOOPBACK_HOSTS
|
||||
return False
|
||||
|
||||
|
||||
def is_blocked_ip(ip_str: str) -> bool:
|
||||
"""True for any address we refuse to connect to.
|
||||
|
||||
Covers loopback, link-local (incl. 169.254.169.254 cloud metadata),
|
||||
RFC1918 private ranges, unique-local (ULA), unspecified (0.0.0.0/::),
|
||||
multicast and other reserved ranges.
|
||||
"""
|
||||
try:
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
except ValueError:
|
||||
return True # unparseable -> refuse
|
||||
# On CPython before the gh-113171 fix (backported to 3.12.4/3.11.9/
|
||||
# 3.10.14/3.9.19) the is_* properties don't see through IPv4-mapped IPv6
|
||||
# (e.g. ::ffff:169.254.169.254), so resolve and re-check the embedded IPv4
|
||||
# to keep mapped metadata/private addresses from slipping past the filter.
|
||||
mapped = getattr(ip, "ipv4_mapped", None)
|
||||
if mapped is not None:
|
||||
ip = mapped
|
||||
return (
|
||||
ip.is_private
|
||||
or ip.is_loopback
|
||||
or ip.is_link_local
|
||||
or ip.is_multicast
|
||||
or ip.is_reserved
|
||||
or ip.is_unspecified
|
||||
)
|
||||
|
||||
|
||||
class ValidatingResolver(AbstractResolver):
|
||||
"""Delegating resolver that drops blocked IPs from every resolution.
|
||||
|
||||
If a hostname resolves only to blocked addresses, the connection fails
|
||||
closed with an :class:`OSError`, which aiohttp surfaces as a connection
|
||||
error to the caller.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._inner = DefaultResolver()
|
||||
|
||||
async def resolve(self, host, port=0, family=socket.AF_INET):
|
||||
infos = await self._inner.resolve(host, port, family)
|
||||
# localhost/127.0.0.1 are an explicit, opt-in allowlist feature.
|
||||
if isinstance(host, str) and host.lower() in LOOPBACK_HOSTS:
|
||||
return infos
|
||||
safe = [info for info in infos if not is_blocked_ip(info["host"])]
|
||||
if not safe:
|
||||
raise OSError(
|
||||
f"refusing to connect to {host!r}: resolves only to "
|
||||
f"private/special-use addresses"
|
||||
)
|
||||
return safe
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._inner.close()
|
||||
|
||||
|
||||
def check_redirect_hop(url: str, *, is_initial_url: bool = False) -> str:
|
||||
"""Validate one hop's URL.
|
||||
|
||||
Returns the URL unchanged on success; raises :class:`SSRFError` otherwise.
|
||||
Requires https for external hosts (http only for loopback/approved dev
|
||||
hosts) and forbids credentials-in-URL. The host is NOT re-checked against
|
||||
the allowlist (CDN redirect targets are off-list by design); credential
|
||||
leakage is prevented by exact host matching at attach time, and the landing
|
||||
filename's extension is gated separately by the caller.
|
||||
|
||||
Loopback/blocked-IP screening: the connector's resolver filters resolvable
|
||||
hostnames but exempts literal loopback hosts (``localhost``/``127.0.0.1``/
|
||||
``::1``) and never sees IP literals through DNS. That loopback exemption is
|
||||
legitimate only for the *initial* user-supplied URL (``is_initial_url``);
|
||||
on a redirect hop we reject loopback hosts and any blocked IP-literal here,
|
||||
so a 30x can't steer a server-side GET at loopback/internal services.
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
except ValueError as e:
|
||||
raise SSRFError(f"unparseable redirect URL {url!r}: {e}") from e
|
||||
host = parsed.hostname
|
||||
if not host:
|
||||
raise SSRFError(f"redirect URL has no host: {url!r}")
|
||||
if not is_scheme_allowed(parsed.scheme, host):
|
||||
raise SSRFError(
|
||||
f"redirect to disallowed scheme {parsed.scheme!r} for host "
|
||||
f"{host!r} (https required for external hosts)"
|
||||
)
|
||||
if parsed.username or parsed.password:
|
||||
raise SSRFError("credentials-in-URL are not allowed")
|
||||
host_is_loopback = host.lower() in LOOPBACK_HOSTS
|
||||
if not is_initial_url and host_is_loopback:
|
||||
raise SSRFError(f"redirect to loopback host {host!r} is not allowed")
|
||||
# IP-literal targets never go through DNS, so the connector's resolver can't
|
||||
# screen them — check them directly. The only blocked IP allowed through is
|
||||
# a loopback literal on the initial URL (handled by the exemption above).
|
||||
try:
|
||||
ipaddress.ip_address(host)
|
||||
except ValueError:
|
||||
is_ip_literal = False
|
||||
else:
|
||||
is_ip_literal = True
|
||||
if is_ip_literal and is_blocked_ip(host) and not (
|
||||
is_initial_url and host_is_loopback
|
||||
):
|
||||
raise SSRFError(f"redirect to blocked internal address {host!r}")
|
||||
return url
|
||||
49
app/model_downloader/verify/checksum.py
Normal file
49
app/model_downloader/verify/checksum.py
Normal file
@ -0,0 +1,49 @@
|
||||
"""Hub-checksum verification = SHA256.
|
||||
|
||||
Only used to confirm a download matches a *provided* ``expected_sha256``. It
|
||||
is NOT the dedup key (that is blake3, owned by the assets system). The full
|
||||
sequential read happens at most once, here, only when a checksum was supplied.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from typing import Callable, Optional
|
||||
|
||||
_CHUNK = 8 * 1024 * 1024
|
||||
|
||||
InterruptCheck = Callable[[], bool]
|
||||
|
||||
|
||||
class ChecksumError(Exception):
|
||||
"""The computed SHA256 did not match the expected value."""
|
||||
|
||||
|
||||
def sha256_file(path: str, interrupt_check: Optional[InterruptCheck] = None) -> Optional[str]:
|
||||
"""Stream the file and return its lowercase hex SHA256.
|
||||
|
||||
Returns ``None`` if interrupted via ``interrupt_check``.
|
||||
"""
|
||||
h = hashlib.sha256()
|
||||
with open(path, "rb") as f:
|
||||
while True:
|
||||
if interrupt_check is not None and interrupt_check():
|
||||
return None
|
||||
chunk = f.read(_CHUNK)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
def verify_sha256(
|
||||
path: str, expected: str, interrupt_check: Optional[InterruptCheck] = None
|
||||
) -> None:
|
||||
"""Raise :class:`ChecksumError` unless the file's SHA256 matches ``expected``."""
|
||||
actual = sha256_file(path, interrupt_check)
|
||||
if actual is None:
|
||||
return # interrupted; caller will re-verify on resume
|
||||
if actual.lower() != expected.lower():
|
||||
raise ChecksumError(
|
||||
f"sha256 mismatch: expected {expected.lower()}, got {actual.lower()}"
|
||||
)
|
||||
53
app/model_downloader/verify/dedup.py
Normal file
53
app/model_downloader/verify/dedup.py
Normal file
@ -0,0 +1,53 @@
|
||||
"""Dedup + catalog handoff — reuse the assets system.
|
||||
|
||||
We do NOT build a parallel indexer. "Do I already have it?" is answered by
|
||||
``resolve_existing`` (path) at enqueue time and, where a hash is known, by the
|
||||
assets blake3 catalog. After a completed download we register the file
|
||||
through the assets ingest path so it is cataloged and (eventually) hashed by
|
||||
the existing enrichment worker.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def _register_sync(abs_path: str) -> Optional[str]:
|
||||
"""Register a finished file into the assets catalog. Returns asset hash."""
|
||||
try:
|
||||
from app.assets.services.ingest import register_file_in_place
|
||||
except Exception as e: # assets package import failure — non-fatal
|
||||
logging.debug("[model_downloader] assets ingest unavailable: %s", e)
|
||||
return None
|
||||
try:
|
||||
result = register_file_in_place(abs_path, name=os.path.basename(abs_path), tags=[])
|
||||
return result.asset.hash if result and result.asset else None
|
||||
except Exception as e:
|
||||
# The file is already safely on disk; cataloging is best-effort.
|
||||
logging.warning(
|
||||
"[model_downloader] could not register %s into assets catalog: %s",
|
||||
abs_path, e,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def register_completed(abs_path: str) -> Optional[str]:
|
||||
"""Catalog a completed download via the assets system (off the event loop)."""
|
||||
return await asyncio.to_thread(_register_sync, abs_path)
|
||||
|
||||
|
||||
def _find_by_hash_sync(blake3_hex: str) -> Optional[str]:
|
||||
try:
|
||||
from app.assets.services.asset_management import get_asset_by_hash
|
||||
except Exception:
|
||||
return None
|
||||
asset = get_asset_by_hash("blake3:" + blake3_hex)
|
||||
return asset.hash if asset is not None else None
|
||||
|
||||
|
||||
async def find_existing_by_hash(blake3_hex: str) -> Optional[str]:
|
||||
"""Pure DB lookup — never triggers hashing on the hot path."""
|
||||
return await asyncio.to_thread(_find_by_hash_sync, blake3_hex)
|
||||
86
app/model_downloader/verify/structural.py
Normal file
86
app/model_downloader/verify/structural.py
Normal file
@ -0,0 +1,86 @@
|
||||
"""Cheap structural validation, no full read.
|
||||
|
||||
For ``.safetensors``/``.sft`` we parse the header (first few KB): it carries
|
||||
the tensor table and the byte length of the data region. We assert
|
||||
``file_size == 8 + header_len + data_region_len``. This detects truncation
|
||||
and most corruption for free, before any crypto hashing. Other extensions
|
||||
have no cheap structural check and pass through.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
from typing import Optional
|
||||
|
||||
_SAFETENSORS_EXTS = (".safetensors", ".sft")
|
||||
# A sane upper bound so a corrupt header length can't make us read gigabytes.
|
||||
_MAX_HEADER_BYTES = 100 * 1024 * 1024
|
||||
|
||||
|
||||
class StructuralError(Exception):
|
||||
"""The file failed its structural integrity check."""
|
||||
|
||||
|
||||
def validate(path: str, name_hint: Optional[str] = None) -> None:
|
||||
"""Validate the file at ``path``. Raises :class:`StructuralError` on failure.
|
||||
|
||||
The file format is detected from ``name_hint`` when provided, otherwise from
|
||||
``path``. Callers that download into a temp file with an opaque suffix (e.g.
|
||||
``*.comfy-download.part``) must pass the final destination name as
|
||||
``name_hint`` so the format check is not silently skipped.
|
||||
"""
|
||||
lower = (name_hint or path).lower()
|
||||
if lower.endswith(_SAFETENSORS_EXTS):
|
||||
_validate_safetensors(path)
|
||||
# No structural check for other formats; the size + (optional) checksum
|
||||
# gates in the engine cover those.
|
||||
|
||||
|
||||
def _validate_safetensors(path: str) -> None:
|
||||
file_size = os.path.getsize(path)
|
||||
if file_size < 8:
|
||||
raise StructuralError(f"file too small to be safetensors ({file_size} bytes)")
|
||||
with open(path, "rb") as f:
|
||||
header_len = struct.unpack("<Q", f.read(8))[0]
|
||||
if header_len <= 0 or header_len > _MAX_HEADER_BYTES:
|
||||
raise StructuralError(f"implausible safetensors header length {header_len}")
|
||||
if 8 + header_len > file_size:
|
||||
raise StructuralError("safetensors header extends past end of file")
|
||||
try:
|
||||
header = json.loads(f.read(header_len).decode("utf-8"))
|
||||
except (UnicodeDecodeError, json.JSONDecodeError) as e:
|
||||
raise StructuralError(f"safetensors header is not valid JSON: {e}") from e
|
||||
|
||||
if not isinstance(header, dict):
|
||||
raise StructuralError("safetensors header is not a JSON object")
|
||||
|
||||
data_len = 0
|
||||
for name, entry in header.items():
|
||||
if name == "__metadata__":
|
||||
continue
|
||||
if not isinstance(entry, dict) or "data_offsets" not in entry:
|
||||
raise StructuralError(f"tensor {name!r} missing data_offsets")
|
||||
offsets = entry["data_offsets"]
|
||||
if not (isinstance(offsets, list) and len(offsets) == 2):
|
||||
raise StructuralError(f"tensor {name!r} has malformed data_offsets")
|
||||
begin, end = offsets
|
||||
# bool is an int subclass; reject it explicitly to avoid True/False offsets.
|
||||
if (
|
||||
not isinstance(begin, int)
|
||||
or not isinstance(end, int)
|
||||
or isinstance(begin, bool)
|
||||
or isinstance(end, bool)
|
||||
or begin < 0
|
||||
or end < begin
|
||||
):
|
||||
raise StructuralError(f"tensor {name!r} has malformed data_offsets")
|
||||
data_len = max(data_len, end)
|
||||
|
||||
expected = 8 + header_len + data_len
|
||||
if file_size != expected:
|
||||
raise StructuralError(
|
||||
f"size mismatch: file is {file_size} bytes, header implies {expected} "
|
||||
f"(8 + {header_len} header + {data_len} data)"
|
||||
)
|
||||
@ -33,6 +33,28 @@ class EnumAction(argparse.Action):
|
||||
setattr(namespace, self.dest, value)
|
||||
|
||||
|
||||
def _positive_int(value: str) -> int:
|
||||
"""argparse type that rejects zero and negative integers."""
|
||||
try:
|
||||
ivalue = int(value)
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(f"{value!r} is not an integer")
|
||||
if ivalue <= 0:
|
||||
raise argparse.ArgumentTypeError(f"{value!r} must be a positive integer (> 0)")
|
||||
return ivalue
|
||||
|
||||
|
||||
def _non_negative_int(value: str) -> int:
|
||||
"""argparse type that rejects negatives but allows zero (a disable sentinel)."""
|
||||
try:
|
||||
ivalue = int(value)
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(f"{value!r} is not an integer")
|
||||
if ivalue < 0:
|
||||
raise argparse.ArgumentTypeError(f"{value!r} must be a non-negative integer (>= 0)")
|
||||
return ivalue
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
|
||||
@ -243,6 +265,15 @@ parser.add_argument("--enable-assets", action="store_true", help="Enable the ass
|
||||
parser.add_argument("--feature-flag", type=str, action='append', default=[], metavar="KEY[=VALUE]", help="Set a server feature flag. Use KEY=VALUE to set an explicit value, or bare KEY to set it to true. Can be specified multiple times. Boolean values (true/false) and numbers are auto-converted. Examples: --feature-flag show_signin_button=true or --feature-flag show_signin_button")
|
||||
parser.add_argument("--list-feature-flags", action="store_true", help="Print the registry of known CLI-settable feature flags as JSON and exit.")
|
||||
|
||||
# ----- Model download manager (PRD: docs/prd-download-manager.md) -----
|
||||
parser.add_argument("--download-segments", type=_positive_int, default=8, metavar="N", help="Number of parallel HTTP range segments per file for the model download manager (default: 8).")
|
||||
parser.add_argument("--download-max-active", type=_positive_int, default=3, metavar="N", help="Maximum number of model downloads running concurrently (default: 3).")
|
||||
parser.add_argument("--download-max-connections-per-host", type=_positive_int, default=16, metavar="N", help="Maximum simultaneous connections to a single host for the download manager (default: 16).")
|
||||
parser.add_argument("--download-chunk-size", type=_positive_int, default=4 * 1024 * 1024, metavar="BYTES", help="Read chunk size in bytes for the download manager (default: 4 MiB).")
|
||||
parser.add_argument("--download-max-bytes", type=_non_negative_int, default=1024 * 1024 * 1024 * 1024, metavar="BYTES", help="Maximum size in bytes of a single download; aborts transfers that exceed it (guards against malicious/non-conforming hosts filling the disk). Set to 0 to disable (default: 1 TiB).")
|
||||
parser.add_argument("--download-allowed-hosts", type=str, nargs="*", default=[], metavar="HOST", help="Additional hostnames to add to the download manager allowlist (https only). The built-in defaults always include huggingface.co and civitai.com.")
|
||||
parser.add_argument("--download-allow-any-extension", action="store_true", help="Allow the download manager to fetch files with any extension (default: only known model extensions like .safetensors).")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
else:
|
||||
|
||||
@ -1216,7 +1216,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
bias_dtype=input.dtype,
|
||||
offloadable=True,
|
||||
compute_dtype=compute_dtype,
|
||||
want_requant=want_requant,
|
||||
want_requant=True,
|
||||
)
|
||||
weight = weight.to(dtype=input.dtype)
|
||||
else:
|
||||
|
||||
@ -104,6 +104,7 @@ _CORE_FEATURE_FLAGS: dict[str, Any] = {
|
||||
"extension": {"manager": {"supports_v4": True}},
|
||||
"node_replacements": True,
|
||||
"assets": args.enable_assets,
|
||||
"server_side_model_downloads": True,
|
||||
}
|
||||
|
||||
# CLI-provided flags cannot overwrite core flags
|
||||
|
||||
@ -1261,158 +1261,6 @@ class DynamicSlot(ComfyTypeI):
|
||||
out_dict[input_type][finalized_id] = value
|
||||
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
|
||||
|
||||
@comfytype(io_type="COMFY_DYNAMICGROUP_V3")
|
||||
class DynamicGroup(ComfyTypeI):
|
||||
"""A repeatable group of widget inputs (e.g. lora_name + strength stacked into N rows).
|
||||
|
||||
At execution time the node receives a ``list[dict]`` where each element is a row.
|
||||
|
||||
Example::
|
||||
|
||||
io.DynamicGroup.Input(
|
||||
"loras",
|
||||
template=[
|
||||
io.Combo.Input("lora_name", options=folder_paths.get_filename_list("loras")),
|
||||
io.Float.Input("strength", default=1.0, min=-100, max=100, step=0.01),
|
||||
],
|
||||
min=0,
|
||||
max=50,
|
||||
)
|
||||
# execute receives: loras: list[dict] = [{"lora_name": "x.safetensors", "strength": 1.0}, ...]
|
||||
"""
|
||||
|
||||
Type = list[dict[str, Any]]
|
||||
_MaxRows = 100
|
||||
|
||||
class Input(DynamicInput):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
template: list["Input"],
|
||||
min: int = 0,
|
||||
max: int = 50,
|
||||
display_name: str = None,
|
||||
optional: bool = False,
|
||||
tooltip: str = None,
|
||||
lazy: bool = None,
|
||||
extra_dict=None,
|
||||
group_name: str = "Group",
|
||||
):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
# Validate template entries: only WidgetInput subclasses, no nesting
|
||||
assert len(template) > 0, "DynamicGroup template must have at least one field."
|
||||
for t in template:
|
||||
assert isinstance(t, WidgetInput), (
|
||||
f"DynamicGroup template field '{t.id}' must be a WidgetInput subclass "
|
||||
f"(Combo, Float, Int, String, Boolean, Color). Got {type(t).__name__}."
|
||||
)
|
||||
assert not isinstance(t, DynamicInput), (
|
||||
f"DynamicGroup template field '{t.id}' must not be a DynamicInput. "
|
||||
"Nesting dynamic inputs inside DynamicGroup is not supported."
|
||||
)
|
||||
# Enforce unique field ids within template
|
||||
field_ids = [t.id for t in template]
|
||||
assert len(field_ids) == len(set(field_ids)), (
|
||||
f"DynamicGroup template field ids must be unique within a row. Got: {field_ids}"
|
||||
)
|
||||
# Reject "." in group id and template field ids: slot_id encoding uses "." as a
|
||||
# delimiter (<group_id>.<row>.<field_id>), so any "." in these names would cause
|
||||
# path.split(".") to produce the wrong number of segments during decoding.
|
||||
assert "." not in id, (
|
||||
f"DynamicGroup id must not contain '.'. Got: '{id}'"
|
||||
)
|
||||
for t in template:
|
||||
assert "." not in t.id, (
|
||||
f"DynamicGroup template field id must not contain '.'. Got: '{t.id}'"
|
||||
)
|
||||
assert min >= 0, "DynamicGroup min must be >= 0."
|
||||
assert max >= 1, "DynamicGroup max must be >= 1."
|
||||
assert max <= DynamicGroup._MaxRows, f"DynamicGroup max must be <= {DynamicGroup._MaxRows}."
|
||||
assert min <= max, "DynamicGroup min must be <= max."
|
||||
self.template = template
|
||||
self.min = min
|
||||
self.max = max
|
||||
self.group_name = group_name
|
||||
|
||||
def get_all(self) -> list["Input"]:
|
||||
return [self] + list(self.template)
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"template": create_input_dict_v1(self.template),
|
||||
"min": self.min,
|
||||
"max": self.max,
|
||||
"group_name": self.group_name,
|
||||
})
|
||||
|
||||
def validate(self):
|
||||
for t in self.template:
|
||||
t.validate()
|
||||
|
||||
@staticmethod
|
||||
def _expand_schema_for_dynamic(
|
||||
out_dict: dict[str, Any],
|
||||
live_inputs: dict[str, Any],
|
||||
value: tuple[str, dict[str, Any]],
|
||||
input_type: str,
|
||||
curr_prefix: list[str] | None,
|
||||
):
|
||||
info = value[1]
|
||||
min_rows: int = info.get("min", 0)
|
||||
max_rows: int = info.get("max", DynamicGroup._MaxRows)
|
||||
template: dict[str, Any] = info.get("template", {})
|
||||
|
||||
# Collect all template field specs across required/optional sections
|
||||
field_specs: list[tuple[str, tuple[str, dict[str, Any]], bool]] = []
|
||||
for field_required_key in ("required", "optional"):
|
||||
section = template.get(field_required_key, {})
|
||||
is_required_field = field_required_key == "required"
|
||||
for field_id, field_value in section.items():
|
||||
field_specs.append((field_id, field_value, is_required_field))
|
||||
|
||||
# Determine how many rows are currently present by scanning live_inputs
|
||||
finalized_prefix = finalize_prefix(curr_prefix)
|
||||
present_rows = 0
|
||||
for live_key in live_inputs:
|
||||
# Keys look like "<prefix>.<row>.<field_id>"
|
||||
if live_key.startswith(finalized_prefix + "."):
|
||||
remainder = live_key[len(finalized_prefix) + 1:]
|
||||
parts = remainder.split(".", 1)
|
||||
if len(parts) >= 1:
|
||||
try:
|
||||
row_idx = int(parts[0])
|
||||
present_rows = max(present_rows, row_idx + 1)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if present_rows > max_rows:
|
||||
raise ValueError(
|
||||
f"DynamicGroup input '{finalized_prefix}' received {present_rows} rows but max is {max_rows}."
|
||||
)
|
||||
row_count = max(min_rows, present_rows)
|
||||
|
||||
for row in range(row_count):
|
||||
for field_id, field_value, is_required_field in field_specs:
|
||||
slot_id = f"{finalized_prefix}.{row}.{field_id}"
|
||||
# The first `min_rows` rows are required if the field itself is required
|
||||
if row < min_rows and is_required_field:
|
||||
out_dict["required"][slot_id] = field_value
|
||||
else:
|
||||
out_dict["optional"][slot_id] = field_value
|
||||
# Register into dynamic_paths so build_nested_inputs places value at the right path
|
||||
out_dict["dynamic_paths"][slot_id] = slot_id
|
||||
|
||||
# Track the list root path so build_nested_inputs can convert the index dict to a list
|
||||
out_dict.setdefault("list_paths", set()).add(finalized_prefix)
|
||||
|
||||
# Handle the empty case (0 rows) – emit an empty-list default for the parent.
|
||||
# This must only fire when there are genuinely no rows; otherwise the parent
|
||||
# path would clobber the per-row dict built from the slot ids above.
|
||||
if row_count == 0:
|
||||
out_dict["dynamic_paths"][finalized_prefix] = finalized_prefix
|
||||
out_dict["dynamic_paths_default_value"][finalized_prefix] = DynamicPathsDefaultValue.EMPTY_LIST
|
||||
|
||||
|
||||
@comfytype(io_type="IMAGECOMPARE")
|
||||
class ImageCompare(ComfyTypeI):
|
||||
Type = dict
|
||||
@ -1570,8 +1418,6 @@ def setup_dynamic_input_funcs():
|
||||
register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic)
|
||||
# DynamicSlot.Input
|
||||
register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic)
|
||||
# DynamicGroup.Input
|
||||
register_dynamic_input_func(DynamicGroup.io_type, DynamicGroup._expand_schema_for_dynamic)
|
||||
|
||||
if len(DYNAMIC_INPUT_LOOKUP) == 0:
|
||||
setup_dynamic_input_funcs()
|
||||
@ -1583,8 +1429,6 @@ class V3Data(TypedDict):
|
||||
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
|
||||
dynamic_paths_default_value: dict[str, Any]
|
||||
'Dictionary where the keys are the input ids and the values are a string from DynamicPathsDefaultValue for the inputs if value is None.'
|
||||
list_paths: set[str]
|
||||
'Set of top-level keys whose index-keyed dict values should be converted to a sorted list[dict] after build_nested_inputs runs.'
|
||||
create_dynamic_tuple: bool
|
||||
'When True, the value of the dynamic input will be in the format (value, path_key).'
|
||||
|
||||
@ -1926,7 +1770,6 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
|
||||
"optional": {},
|
||||
"dynamic_paths": {},
|
||||
"dynamic_paths_default_value": {},
|
||||
"list_paths": set(),
|
||||
}
|
||||
d = d.copy()
|
||||
# ignore hidden for parsing
|
||||
@ -1942,10 +1785,6 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
|
||||
dynamic_paths_default_value = out_dict.pop("dynamic_paths_default_value", None)
|
||||
if dynamic_paths_default_value is not None and len(dynamic_paths_default_value) > 0:
|
||||
v3_data["dynamic_paths_default_value"] = dynamic_paths_default_value
|
||||
# list_paths: keys whose nested dict should be post-converted to a sorted list[dict]
|
||||
list_paths = out_dict.pop("list_paths", None)
|
||||
if list_paths:
|
||||
v3_data["list_paths"] = list_paths
|
||||
return out_dict, hidden, v3_data
|
||||
|
||||
def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
|
||||
@ -1981,12 +1820,10 @@ def add_to_dict_v1(i: Input, d: dict):
|
||||
|
||||
class DynamicPathsDefaultValue:
|
||||
EMPTY_DICT = "empty_dict"
|
||||
EMPTY_LIST = "empty_list"
|
||||
|
||||
def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
paths = v3_data.get("dynamic_paths", None)
|
||||
default_value_dict = v3_data.get("dynamic_paths_default_value", {})
|
||||
list_paths: set[str] = v3_data.get("list_paths", set()) or set()
|
||||
if paths is None:
|
||||
return values
|
||||
values = values.copy()
|
||||
@ -2009,8 +1846,6 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
default_option = default_value_dict.get(key, None)
|
||||
if default_option == DynamicPathsDefaultValue.EMPTY_DICT:
|
||||
value = {}
|
||||
elif default_option == DynamicPathsDefaultValue.EMPTY_LIST:
|
||||
value = []
|
||||
if create_tuple:
|
||||
value = (value, key)
|
||||
current[p] = value
|
||||
@ -2018,34 +1853,6 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
current = current.setdefault(p, {})
|
||||
|
||||
values.update(result)
|
||||
|
||||
# Post-pass: convert index-keyed dicts to sorted lists for io.DynamicGroup fields
|
||||
for list_path in list_paths:
|
||||
parts = list_path.split(".")
|
||||
# Navigate to the parent container, then convert the leaf
|
||||
container = values
|
||||
for part in parts[:-1]:
|
||||
if not isinstance(container, dict) or part not in container:
|
||||
container = None
|
||||
break
|
||||
container = container[part]
|
||||
if container is None:
|
||||
continue
|
||||
leaf_key = parts[-1]
|
||||
leaf = container.get(leaf_key, None)
|
||||
if isinstance(leaf, dict):
|
||||
try:
|
||||
sorted_rows = [leaf[k] for k in sorted(leaf.keys(), key=int)]
|
||||
container[leaf_key] = sorted_rows
|
||||
except (ValueError, TypeError):
|
||||
# Keys are not all integers; leave as-is
|
||||
pass
|
||||
elif isinstance(leaf, list):
|
||||
# Already a list (e.g. the EMPTY_LIST default was applied above)
|
||||
pass
|
||||
elif leaf is None:
|
||||
container[leaf_key] = []
|
||||
|
||||
return values
|
||||
|
||||
|
||||
@ -2610,9 +2417,7 @@ __all__ = [
|
||||
# Dynamic Types
|
||||
"MatchType",
|
||||
"DynamicCombo",
|
||||
"DynamicSlot",
|
||||
"Autogrow",
|
||||
"DynamicGroup",
|
||||
# Other classes
|
||||
"HiddenHolder",
|
||||
"Hidden",
|
||||
|
||||
@ -249,18 +249,22 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
|
||||
input_tokens_price = 2
|
||||
output_text_tokens_price = 12.0
|
||||
output_image_tokens_price = 0.0
|
||||
elif response.modelVersion == "gemini-3.1-flash-lite-preview":
|
||||
elif response.modelVersion in ("gemini-3.1-flash-lite-preview", "gemini-3.1-flash-lite"):
|
||||
input_tokens_price = 0.25
|
||||
output_text_tokens_price = 1.50
|
||||
output_image_tokens_price = 0.0
|
||||
elif response.modelVersion == "gemini-3-pro-image-preview":
|
||||
elif response.modelVersion in ("gemini-3-pro-image-preview", "gemini-3-pro-image"):
|
||||
input_tokens_price = 2
|
||||
output_text_tokens_price = 12.0
|
||||
output_image_tokens_price = 120.0
|
||||
elif response.modelVersion == "gemini-3.1-flash-image-preview":
|
||||
elif response.modelVersion in ("gemini-3.1-flash-image-preview", "gemini-3.1-flash-image"):
|
||||
input_tokens_price = 0.5
|
||||
output_text_tokens_price = 3.0
|
||||
output_image_tokens_price = 60.0
|
||||
elif response.modelVersion == "gemini-3.1-flash-lite-image":
|
||||
input_tokens_price = 0.25
|
||||
output_text_tokens_price = 1.50
|
||||
output_image_tokens_price = 30.0
|
||||
else:
|
||||
return None
|
||||
final_price = response.usageMetadata.promptTokenCount * input_tokens_price
|
||||
@ -1302,7 +1306,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
def _nano_banana_2_v2_model_inputs():
|
||||
def _nano_banana_2_v2_model_inputs(resolutions: list[str]):
|
||||
return [
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
@ -1329,8 +1333,8 @@ def _nano_banana_2_v2_model_inputs():
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["1K", "2K", "4K"],
|
||||
tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
|
||||
options=resolutions,
|
||||
tooltip="Target output resolution.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"thinking_level",
|
||||
@ -1376,7 +1380,11 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"Nano Banana 2 (Gemini 3.1 Flash Image)",
|
||||
_nano_banana_2_v2_model_inputs(),
|
||||
_nano_banana_2_v2_model_inputs(resolutions=["1K", "2K", "4K"]),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Nano Banana 2 Lite",
|
||||
_nano_banana_2_v2_model_inputs(resolutions=["1K"]),
|
||||
),
|
||||
],
|
||||
),
|
||||
@ -1445,9 +1453,13 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$r := $lookup(widgets, "model.resolution");
|
||||
$prices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
|
||||
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
|
||||
$contains(widgets.model, "lite")
|
||||
? {"type":"usd","usd": 0.034, "format":{"suffix":"/Image","approximate":true}}
|
||||
: (
|
||||
$r := $lookup(widgets, "model.resolution");
|
||||
$prices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
|
||||
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
|
||||
)
|
||||
)
|
||||
""",
|
||||
),
|
||||
@ -1468,6 +1480,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
model_choice = model["model"]
|
||||
if model_choice == "Nano Banana 2 (Gemini 3.1 Flash Image)":
|
||||
model_id = "gemini-3.1-flash-image-preview"
|
||||
elif model_choice == "Nano Banana 2 Lite":
|
||||
model_id = "gemini-3.1-flash-lite-image"
|
||||
else:
|
||||
model_id = model_choice
|
||||
|
||||
|
||||
25
nodes.py
25
nodes.py
@ -159,6 +159,29 @@ class ConditioningConcat:
|
||||
|
||||
return (out, )
|
||||
|
||||
class ConditioningMultiply:
|
||||
SEARCH_ALIASES = ["scale conditioning", "scale prompt", "multiply conditioning", "multiply prompt"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {"conditioning": ("CONDITIONING", ),
|
||||
"multiplier": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01})
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "multiply"
|
||||
CATEGORY = "model/conditioning/transform"
|
||||
|
||||
def multiply(self, conditioning, multiplier):
|
||||
c = []
|
||||
for t in conditioning:
|
||||
values = {}
|
||||
pooled_output = t[1].get("pooled_output", None)
|
||||
if pooled_output is not None:
|
||||
values["pooled_output"] = pooled_output * multiplier
|
||||
scaled = node_helpers.conditioning_set_values([[t[0] * multiplier, t[1]]], values)[0]
|
||||
c.append(scaled)
|
||||
return (c,)
|
||||
|
||||
class ConditioningSetArea:
|
||||
SEARCH_ALIASES = ["regional prompt", "area prompt", "spatial conditioning", "localized prompt"]
|
||||
|
||||
@ -2050,6 +2073,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ConditioningAverage": ConditioningAverage,
|
||||
"ConditioningCombine": ConditioningCombine,
|
||||
"ConditioningConcat": ConditioningConcat,
|
||||
"ConditioningMultiply": ConditioningMultiply,
|
||||
"ConditioningSetArea": ConditioningSetArea,
|
||||
"ConditioningSetAreaPercentage": ConditioningSetAreaPercentage,
|
||||
"ConditioningSetAreaStrength": ConditioningSetAreaStrength,
|
||||
@ -2121,6 +2145,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ConditioningAverage ": "Conditioning (Average)",
|
||||
"ConditioningAverage": "Conditioning (Average)",
|
||||
"ConditioningConcat": "Conditioning (Concat)",
|
||||
"ConditioningMultiply": "Conditioning (Multiply)",
|
||||
"ConditioningSetArea": "Conditioning (Set Area)",
|
||||
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
|
||||
"ConditioningSetAreaStrength": "Conditioning (Set Area Strength)",
|
||||
|
||||
546
openapi.yaml
546
openapi.yaml
@ -230,6 +230,93 @@ components:
|
||||
- base_version
|
||||
- workflow_json
|
||||
type: object
|
||||
DownloadEnqueueRequest:
|
||||
description: Request body for enqueuing a server-side model download.
|
||||
properties:
|
||||
allow_any_extension:
|
||||
default: false
|
||||
description: Permit a non-model file extension (default only allows known model extensions).
|
||||
type: boolean
|
||||
credential_id:
|
||||
description: Explicit per-host credential to use; otherwise auto-resolved by host. Still subject to the per-hop host match.
|
||||
nullable: true
|
||||
type: string
|
||||
expected_sha256:
|
||||
description: Optional hub-provided SHA256 to verify the completed file against (fail-closed).
|
||||
nullable: true
|
||||
type: string
|
||||
model_id:
|
||||
description: Destination as "<directory>/<filename>", resolving to a registered model folder (e.g. "loras/my_lora.safetensors").
|
||||
type: string
|
||||
priority:
|
||||
default: 0
|
||||
description: Scheduling priority; higher is admitted first.
|
||||
type: integer
|
||||
url:
|
||||
description: Source URL; must be on the allowlist (host + scheme + extension).
|
||||
type: string
|
||||
required:
|
||||
- url
|
||||
- model_id
|
||||
type: object
|
||||
DownloadStatus:
|
||||
description: Current state and live progress of a single download.
|
||||
properties:
|
||||
bytes_done:
|
||||
type: integer
|
||||
created_at:
|
||||
type: integer
|
||||
download_id:
|
||||
format: uuid
|
||||
type: string
|
||||
error:
|
||||
nullable: true
|
||||
type: string
|
||||
eta_seconds:
|
||||
nullable: true
|
||||
type: number
|
||||
model_id:
|
||||
type: string
|
||||
priority:
|
||||
type: integer
|
||||
progress:
|
||||
description: Fraction in [0,1]; null until total size is known.
|
||||
nullable: true
|
||||
type: number
|
||||
segments:
|
||||
description: Per-segment progress (segmented downloads only).
|
||||
items:
|
||||
properties:
|
||||
bytes_done:
|
||||
type: integer
|
||||
idx:
|
||||
type: integer
|
||||
length:
|
||||
type: integer
|
||||
type: object
|
||||
nullable: true
|
||||
type: array
|
||||
speed_bps:
|
||||
nullable: true
|
||||
type: number
|
||||
status:
|
||||
enum:
|
||||
- queued
|
||||
- active
|
||||
- paused
|
||||
- verifying
|
||||
- completed
|
||||
- failed
|
||||
- cancelled
|
||||
type: string
|
||||
total_bytes:
|
||||
nullable: true
|
||||
type: integer
|
||||
updated_at:
|
||||
type: integer
|
||||
url:
|
||||
type: string
|
||||
type: object
|
||||
ErrorResponse:
|
||||
description: Standard error response with a machine-readable code and human-readable message.
|
||||
properties:
|
||||
@ -511,6 +598,78 @@ components:
|
||||
required:
|
||||
- history
|
||||
type: object
|
||||
HostCredentialUpsert:
|
||||
description: Request body for upserting a per-host credential. The secret is write-only.
|
||||
properties:
|
||||
auth_scheme:
|
||||
default: bearer
|
||||
description: How the secret is attached to requests.
|
||||
enum:
|
||||
- bearer
|
||||
- header
|
||||
- query
|
||||
type: string
|
||||
enabled:
|
||||
default: true
|
||||
type: boolean
|
||||
header_name:
|
||||
description: Header name when auth_scheme=header (defaults to Authorization).
|
||||
nullable: true
|
||||
type: string
|
||||
host:
|
||||
description: Normalized hostname the key applies to (e.g. "civitai.com").
|
||||
type: string
|
||||
label:
|
||||
description: User-friendly name for display.
|
||||
nullable: true
|
||||
type: string
|
||||
match_subdomains:
|
||||
default: false
|
||||
description: Also match label-boundary subdomains of host (off by default; unsafe for hub CDNs).
|
||||
type: boolean
|
||||
query_param:
|
||||
description: Query parameter name when auth_scheme=query.
|
||||
nullable: true
|
||||
type: string
|
||||
secret:
|
||||
description: The API key. Write-only — never returned by any endpoint.
|
||||
type: string
|
||||
required:
|
||||
- host
|
||||
- secret
|
||||
type: object
|
||||
HostCredentialView:
|
||||
description: Masked, API-safe view of a stored credential. Never includes the secret.
|
||||
properties:
|
||||
auth_scheme:
|
||||
type: string
|
||||
created_at:
|
||||
type: integer
|
||||
enabled:
|
||||
type: boolean
|
||||
header_name:
|
||||
nullable: true
|
||||
type: string
|
||||
host:
|
||||
type: string
|
||||
id:
|
||||
format: uuid
|
||||
type: string
|
||||
label:
|
||||
nullable: true
|
||||
type: string
|
||||
match_subdomains:
|
||||
type: boolean
|
||||
query_param:
|
||||
nullable: true
|
||||
type: string
|
||||
secret_last4:
|
||||
description: Last 4 characters of the secret, for masked display only.
|
||||
nullable: true
|
||||
type: string
|
||||
updated_at:
|
||||
type: integer
|
||||
type: object
|
||||
JobCancelResponse:
|
||||
description: Response for POST /api/jobs/{job_id}/cancel. Returned on both fresh cancels and idempotent no-ops.
|
||||
properties:
|
||||
@ -2350,6 +2509,391 @@ paths:
|
||||
summary: Get tag histogram for filtered assets
|
||||
tags:
|
||||
- file
|
||||
/api/download:
|
||||
get:
|
||||
description: List all known downloads (queued, active, paused, and terminal) with live progress.
|
||||
operationId: listDownloads
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
downloads:
|
||||
items:
|
||||
$ref: '#/components/schemas/DownloadStatus'
|
||||
type: array
|
||||
type: object
|
||||
description: List of downloads
|
||||
summary: List downloads
|
||||
tags:
|
||||
- download
|
||||
/api/download/availability:
|
||||
post:
|
||||
description: |
|
||||
Bulk per-id availability for a set of model_ids declared in a workflow.
|
||||
Returns whether each model is available on disk, currently downloading
|
||||
(with progress), or missing, plus whether its URL is on the allowlist.
|
||||
operationId: getModelsAvailability
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
models:
|
||||
additionalProperties:
|
||||
type: string
|
||||
description: Map of "<directory>/<filename>" model_id to its declared source URL.
|
||||
type: object
|
||||
type: object
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
models:
|
||||
additionalProperties: true
|
||||
type: object
|
||||
type: object
|
||||
description: Per-id availability map
|
||||
summary: Bulk model availability + status
|
||||
tags:
|
||||
- download
|
||||
/api/download/clear:
|
||||
post:
|
||||
description: |
|
||||
Delete all terminal downloads (completed, failed, cancelled) from history
|
||||
in one transaction, so the cleared history persists across reloads. Live
|
||||
downloads (queued, active, paused, verifying) are skipped. Finished model
|
||||
files on disk are never removed; only leftover .part temp files are cleaned up.
|
||||
operationId: clearDownloads
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
deleted:
|
||||
description: Number of history rows removed.
|
||||
type: integer
|
||||
type: object
|
||||
description: History cleared
|
||||
summary: Clear terminal downloads from history
|
||||
tags:
|
||||
- download
|
||||
/api/download/credentials:
|
||||
get:
|
||||
description: List stored per-host credentials. Secrets are never returned; only masked metadata (last 4 chars, scheme, label).
|
||||
operationId: listDownloadCredentials
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
credentials:
|
||||
items:
|
||||
$ref: '#/components/schemas/HostCredentialView'
|
||||
type: array
|
||||
type: object
|
||||
description: Masked credential list
|
||||
summary: List host credentials (masked)
|
||||
tags:
|
||||
- download
|
||||
post:
|
||||
description: |
|
||||
Upsert (by host) a per-host API key used to authenticate downloads.
|
||||
The secret is write-only: it is stored once here and never returned by any endpoint.
|
||||
operationId: upsertDownloadCredential
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HostCredentialUpsert'
|
||||
responses:
|
||||
"201":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HostCredentialView'
|
||||
description: Credential stored (masked view returned)
|
||||
"400":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Invalid credential
|
||||
summary: Upsert a host credential
|
||||
tags:
|
||||
- download
|
||||
/api/download/credentials/{id}:
|
||||
delete:
|
||||
description: Delete a stored host credential.
|
||||
operationId: deleteDownloadCredential
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
deleted:
|
||||
type: boolean
|
||||
type: object
|
||||
description: Deleted
|
||||
"404":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: No such credential
|
||||
summary: Delete a host credential
|
||||
tags:
|
||||
- download
|
||||
get:
|
||||
description: Get a single host credential (masked; never includes the secret).
|
||||
operationId: getDownloadCredential
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HostCredentialView'
|
||||
description: Masked credential
|
||||
"404":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: No such credential
|
||||
summary: Get a host credential (masked)
|
||||
tags:
|
||||
- download
|
||||
/api/download/enqueue:
|
||||
post:
|
||||
description: |
|
||||
Enqueue a server-side model download. The URL must be on the allowlist
|
||||
(host + scheme + extension) and the model_id must be "<directory>/<filename>"
|
||||
resolving to a registered model folder. Returns immediately; track progress
|
||||
via GET /api/download/{id} or the "download_progress" websocket event.
|
||||
operationId: enqueueDownload
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/DownloadEnqueueRequest'
|
||||
responses:
|
||||
"202":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
accepted:
|
||||
type: boolean
|
||||
download_id:
|
||||
format: uuid
|
||||
type: string
|
||||
type: object
|
||||
description: Download accepted and queued
|
||||
"400":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Invalid request (bad URL, model_id, or not allowlisted)
|
||||
"409":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Already on disk or already downloading
|
||||
summary: Enqueue a model download
|
||||
tags:
|
||||
- download
|
||||
/api/download/{id}:
|
||||
delete:
|
||||
description: |
|
||||
Delete a single terminal download from history so it stays gone across
|
||||
reloads. Refuses (409) to delete a live download (queued, active, paused,
|
||||
verifying) — cancel it first. The finished model file on disk is never
|
||||
removed; only a leftover .part temp file is cleaned up.
|
||||
operationId: deleteDownload
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
format: uuid
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
deleted:
|
||||
type: boolean
|
||||
type: object
|
||||
description: Deleted
|
||||
"404":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: No such download
|
||||
"409":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Download is still in progress
|
||||
summary: Delete a download from history
|
||||
tags:
|
||||
- download
|
||||
get:
|
||||
description: Get the current status + progress of a single download.
|
||||
operationId: getDownloadStatus
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
format: uuid
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/DownloadStatus'
|
||||
description: Download status
|
||||
"404":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: No such download
|
||||
summary: Get download status
|
||||
tags:
|
||||
- download
|
||||
/api/download/{id}/cancel:
|
||||
post:
|
||||
description: Cancel a download. The partial file is removed.
|
||||
operationId: cancelDownload
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
format: uuid
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
ok:
|
||||
type: boolean
|
||||
type: object
|
||||
description: Cancelled
|
||||
summary: Cancel a download
|
||||
tags:
|
||||
- download
|
||||
/api/download/{id}/pause:
|
||||
post:
|
||||
description: Pause a download. The partial file and per-segment offsets are retained for resume.
|
||||
operationId: pauseDownload
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
format: uuid
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
ok:
|
||||
type: boolean
|
||||
type: object
|
||||
description: Paused
|
||||
summary: Pause a download
|
||||
tags:
|
||||
- download
|
||||
/api/download/{id}/priority:
|
||||
post:
|
||||
description: Set a download's scheduling priority. Higher priority is admitted first when a slot frees.
|
||||
operationId: setDownloadPriority
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
format: uuid
|
||||
type: string
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
priority:
|
||||
type: integer
|
||||
required:
|
||||
- priority
|
||||
type: object
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
ok:
|
||||
type: boolean
|
||||
type: object
|
||||
description: Priority updated
|
||||
summary: Set download priority
|
||||
tags:
|
||||
- download
|
||||
/api/download/{id}/resume:
|
||||
post:
|
||||
description: Resume a paused (or failed) download from its persisted offsets.
|
||||
operationId: resumeDownload
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
format: uuid
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
ok:
|
||||
type: boolean
|
||||
type: object
|
||||
description: Resumed
|
||||
summary: Resume a download
|
||||
tags:
|
||||
- download
|
||||
/api/embeddings:
|
||||
get:
|
||||
description: Returns the list of text-encoder embeddings available on disk.
|
||||
@ -5103,3 +5647,5 @@ tags:
|
||||
name: queue
|
||||
- description: Job lifecycle queries
|
||||
name: job
|
||||
- description: Model download management
|
||||
name: download
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
comfyui-frontend-package==1.45.19
|
||||
comfyui-frontend-package==1.45.20
|
||||
comfyui-workflow-templates==0.10.7
|
||||
comfyui-embedded-docs==0.5.5
|
||||
comfyui-embedded-docs==0.5.6
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
@ -22,7 +22,7 @@ alembic
|
||||
SQLAlchemy>=2.0.0
|
||||
filelock
|
||||
av>=16.0.0
|
||||
comfy-kitchen==0.2.13
|
||||
comfy-kitchen==0.2.15
|
||||
comfy-aimdo==0.4.10
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
|
||||
26
server.py
26
server.py
@ -45,6 +45,8 @@ from app.frontend_management import FrontendManager, parse_version
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
from app.assets.seeder import asset_seeder
|
||||
from app.assets.api.routes import register_assets_routes
|
||||
from app.model_downloader.api.routes import register_routes as register_model_downloader_routes
|
||||
from app.model_downloader.manager import DOWNLOAD_MANAGER
|
||||
from app.assets.services.ingest import register_file_in_place
|
||||
from app.assets.services.asset_management import resolve_hash_to_path
|
||||
|
||||
@ -256,6 +258,7 @@ class PromptServer():
|
||||
else:
|
||||
register_assets_routes(self.app)
|
||||
asset_seeder.disable()
|
||||
register_model_downloader_routes(self.app)
|
||||
routes = web.RouteTableDef()
|
||||
self.routes = routes
|
||||
self.last_node_id = None
|
||||
@ -1182,6 +1185,29 @@ class PromptServer():
|
||||
async def setup(self):
|
||||
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
||||
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
||||
await self._setup_model_downloader()
|
||||
|
||||
async def _setup_model_downloader(self):
|
||||
"""Start the download manager: push progress over the websocket and
|
||||
resume any downloads interrupted by a previous run."""
|
||||
def _notify(download_id: str) -> None:
|
||||
try:
|
||||
view = DOWNLOAD_MANAGER.status_sync(download_id)
|
||||
if view is not None:
|
||||
# Drop the url field before broadcasting: the redacted URL
|
||||
# (scheme + host + path) should not leak to every connected
|
||||
# websocket client. download_id / model_id are sufficient to
|
||||
# correlate progress on the frontend.
|
||||
broadcast = {k: v for k, v in view.items() if k != "url"}
|
||||
self.send_sync("download_progress", broadcast)
|
||||
except Exception:
|
||||
logging.debug("download progress notify failed", exc_info=True)
|
||||
|
||||
DOWNLOAD_MANAGER.set_notify(_notify)
|
||||
try:
|
||||
await DOWNLOAD_MANAGER.start()
|
||||
except Exception as e:
|
||||
logging.warning("Failed to start model download manager: %s", e)
|
||||
|
||||
def add_routes(self):
|
||||
self.user_manager.add_routes(self.routes)
|
||||
|
||||
@ -1,204 +0,0 @@
|
||||
"""Unit tests for io.DynamicGroup: expansion/reconstruction (0-row and N-row cases)."""
|
||||
import sys
|
||||
import types
|
||||
import pytest
|
||||
|
||||
# Stub torch (type-hint only in _io.py; real torch not available in unit-test env)
|
||||
if "torch" not in sys.modules:
|
||||
_torch_stub = types.ModuleType("torch")
|
||||
_torch_stub.Tensor = object # type: ignore[attr-defined]
|
||||
sys.modules["torch"] = _torch_stub
|
||||
|
||||
from comfy_api.latest._io import ( # noqa: E402
|
||||
DynamicGroup,
|
||||
Float,
|
||||
Int,
|
||||
String,
|
||||
Boolean,
|
||||
get_finalized_class_inputs,
|
||||
build_nested_inputs,
|
||||
create_input_dict_v1,
|
||||
setup_dynamic_input_funcs,
|
||||
)
|
||||
|
||||
# Make sure dynamic input funcs are registered (may already be done at import time)
|
||||
setup_dynamic_input_funcs()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_class_inputs(group_input: DynamicGroup.Input) -> dict:
|
||||
"""Wrap a DynamicGroup.Input into the required/optional dict structure."""
|
||||
return create_input_dict_v1([group_input])
|
||||
|
||||
|
||||
def _run(group_input: DynamicGroup.Input, live_values: dict) -> dict:
|
||||
"""End-to-end helper: expand schema + reconstruct values.
|
||||
|
||||
Mirrors the production split in execution.py:
|
||||
1. get_finalized_class_inputs (schema expansion, line 162)
|
||||
2. build_nested_inputs (value reconstruction, line 281)
|
||||
|
||||
The two steps are separate in production because the engine resolves
|
||||
linked node outputs between them, but in tests we supply values directly.
|
||||
"""
|
||||
class_inputs = _make_class_inputs(group_input)
|
||||
_, _, v3_data = get_finalized_class_inputs(class_inputs, live_values)
|
||||
return build_nested_inputs(dict(live_values), v3_data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema construction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDynamicGroupInputConstruction:
|
||||
def test_basic_construction(self):
|
||||
inp = DynamicGroup.Input(
|
||||
"loras",
|
||||
template=[
|
||||
Float.Input("strength", default=1.0),
|
||||
String.Input("name"),
|
||||
],
|
||||
min=0,
|
||||
max=10,
|
||||
)
|
||||
assert inp.id == "loras"
|
||||
assert inp.min == 0
|
||||
assert inp.max == 10
|
||||
assert len(inp.template) == 2
|
||||
|
||||
def test_get_all_includes_self_and_template(self):
|
||||
inp = DynamicGroup.Input(
|
||||
"items",
|
||||
template=[Float.Input("value")],
|
||||
)
|
||||
all_inputs = inp.get_all()
|
||||
assert all_inputs[0] is inp
|
||||
assert all_inputs[1].id == "value"
|
||||
|
||||
def test_as_dict_has_template_min_max(self):
|
||||
inp = DynamicGroup.Input(
|
||||
"items",
|
||||
template=[Float.Input("val", default=0.5)],
|
||||
min=1,
|
||||
max=5,
|
||||
)
|
||||
d = inp.as_dict()
|
||||
assert "template" in d
|
||||
assert d["min"] == 1
|
||||
assert d["max"] == 5
|
||||
|
||||
def test_duplicate_field_ids_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input(
|
||||
"bad",
|
||||
template=[Float.Input("x"), Float.Input("x")],
|
||||
)
|
||||
|
||||
def test_empty_template_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input("bad", template=[])
|
||||
|
||||
def test_min_gt_max_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input("bad", template=[Float.Input("x")], min=5, max=3)
|
||||
|
||||
def test_max_exceeds_limit_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input("bad", template=[Float.Input("x")], max=101)
|
||||
|
||||
def test_dynamic_input_in_template_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input(
|
||||
"bad",
|
||||
template=[DynamicGroup.Input("nested", template=[Float.Input("x")])],
|
||||
)
|
||||
|
||||
def test_validate_calls_through(self):
|
||||
inp = DynamicGroup.Input("items", template=[Float.Input("val", min=-1.0, max=1.0)])
|
||||
inp.validate() # should not raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 0-row case
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestZeroRows:
|
||||
def test_empty_live_inputs_produces_empty_list(self):
|
||||
"""With min=0 and no live values, the result should be an empty list."""
|
||||
inp = DynamicGroup.Input("loras", template=[Float.Input("strength", default=1.0)], min=0, max=10)
|
||||
assert _run(inp, {}).get("loras") == []
|
||||
|
||||
def test_min_zero_with_values(self):
|
||||
"""min=0 but 2 rows of live data."""
|
||||
inp = DynamicGroup.Input("loras", template=[Float.Input("strength", default=1.0)], min=0, max=10)
|
||||
result = _run(inp, {"loras.0.strength": 0.8, "loras.1.strength": 0.5})
|
||||
assert result["loras"] == [{"strength": 0.8}, {"strength": 0.5}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# N-row case
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNRows:
|
||||
def test_two_rows_two_fields(self):
|
||||
"""Two rows with two fields each produce a list[dict]."""
|
||||
inp = DynamicGroup.Input(
|
||||
"loras",
|
||||
template=[String.Input("lora_name"), Float.Input("strength", default=1.0)],
|
||||
min=0, max=50,
|
||||
)
|
||||
result = _run(inp, {
|
||||
"loras.0.lora_name": "model_a.safetensors", "loras.0.strength": 0.9,
|
||||
"loras.1.lora_name": "model_b.safetensors", "loras.1.strength": 0.4,
|
||||
})
|
||||
assert result["loras"] == [
|
||||
{"lora_name": "model_a.safetensors", "strength": 0.9},
|
||||
{"lora_name": "model_b.safetensors", "strength": 0.4},
|
||||
]
|
||||
|
||||
def test_rows_are_sorted_by_index(self):
|
||||
"""Rows must be in ascending index order even if dict iteration is unordered."""
|
||||
inp = DynamicGroup.Input("items", template=[Int.Input("v", default=0)], min=0, max=10)
|
||||
result = _run(inp, {"items.0.v": 10, "items.2.v": 30, "items.1.v": 20})
|
||||
assert [row["v"] for row in result["items"]] == [10, 20, 30]
|
||||
|
||||
def test_min_rows_schema_slots(self):
|
||||
"""With min=2 and no live data, 2 slots must appear in the expanded schema."""
|
||||
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
|
||||
out, _, _ = get_finalized_class_inputs(_make_class_inputs(inp), {})
|
||||
all_slots = {**out.get("required", {}), **out.get("optional", {})}
|
||||
assert "items.0.val" in all_slots
|
||||
assert "items.1.val" in all_slots
|
||||
|
||||
def test_min_rows_reconstructs_when_no_values(self):
|
||||
"""min=2 with NO live values must still yield a 2-element list,
|
||||
not collapse to [] (regression: parent-path clobber)."""
|
||||
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
|
||||
result = _run(inp, {})
|
||||
assert len(result["items"]) == 2
|
||||
assert all("val" in row for row in result["items"])
|
||||
|
||||
def test_min_rows_reconstructs_with_partial_values(self):
|
||||
"""min=2 with only the first row's value present still yields 2 rows."""
|
||||
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
|
||||
result = _run(inp, {"items.0.val": 0.7})
|
||||
assert len(result["items"]) == 2
|
||||
assert result["items"][0]["val"] == 0.7
|
||||
assert result["items"][1]["val"] is None
|
||||
|
||||
def test_list_paths_in_v3_data(self):
|
||||
"""list_paths must contain the group id so build_nested_inputs knows to convert."""
|
||||
inp = DynamicGroup.Input("things", template=[Boolean.Input("flag")], min=0, max=5)
|
||||
_, _, v3_data = get_finalized_class_inputs(_make_class_inputs(inp), {})
|
||||
assert "things" in v3_data.get("list_paths", set())
|
||||
|
||||
def test_no_leftover_flat_keys(self):
|
||||
"""Flat keys must be consumed; only the reconstructed list remains."""
|
||||
inp = DynamicGroup.Input("rows", template=[Float.Input("x", default=0.0)], min=0, max=5)
|
||||
result = _run(inp, {"rows.0.x": 1.0, "rows.1.x": 2.0})
|
||||
assert "rows.0.x" not in result
|
||||
assert "rows.1.x" not in result
|
||||
assert isinstance(result["rows"], list)
|
||||
90
tests-unit/model_downloader_test/conftest.py
Normal file
90
tests-unit/model_downloader_test/conftest.py
Normal file
@ -0,0 +1,90 @@
|
||||
"""Shared fixtures for the model download manager tests.
|
||||
|
||||
These run in-process (no ComfyUI subprocess): a file-backed SQLite DB is
|
||||
initialized once, a temp model folder is registered with ``folder_paths``, and
|
||||
the shared aiohttp session is reset between tests so each async test gets a
|
||||
session bound to its own event loop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _drain_scheduler_tasks(scheduler) -> None:
|
||||
"""Cancel and await live scheduler tasks so none outlive the test.
|
||||
|
||||
Uses the actual task handles rather than only clearing ``_tasks``: each
|
||||
per-test event loop is created by ``asyncio.run``, so a task left behind by
|
||||
a crashed/aborted test would otherwise keep its coroutine alive. We cancel
|
||||
every live task and, when its loop is still usable, run it to completion to
|
||||
let the cancellation propagate before dropping the reference.
|
||||
"""
|
||||
for task in list(scheduler._tasks.values()):
|
||||
if task is None:
|
||||
continue
|
||||
loop = task.get_loop()
|
||||
if task.done() or loop.is_closed():
|
||||
continue
|
||||
task.cancel()
|
||||
if not loop.is_running():
|
||||
try:
|
||||
loop.run_until_complete(asyncio.gather(task, return_exceptions=True))
|
||||
except Exception:
|
||||
pass
|
||||
scheduler._tasks.clear()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def _init_db():
|
||||
import app.database.db as db
|
||||
from comfy.cli_args import args
|
||||
|
||||
fd, db_path = tempfile.mkstemp(suffix="-dlmgr-test.sqlite3")
|
||||
os.close(fd)
|
||||
args.database_url = f"sqlite:///{db_path}"
|
||||
db.init_db()
|
||||
yield
|
||||
try:
|
||||
os.remove(db_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_runtime():
|
||||
"""Reset module singletons that hold event-loop-bound or cross-test state."""
|
||||
import app.model_downloader.net.session as ns
|
||||
from app.model_downloader.scheduler import SCHEDULER
|
||||
|
||||
ns._session = None
|
||||
_drain_scheduler_tasks(SCHEDULER)
|
||||
SCHEDULER._jobs.clear()
|
||||
SCHEDULER._backoff_until.clear()
|
||||
SCHEDULER._started = False
|
||||
yield
|
||||
_drain_scheduler_tasks(SCHEDULER)
|
||||
ns._session = None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_root(tmp_path):
|
||||
"""Register a temp 'loras' model folder and return its absolute path."""
|
||||
import folder_paths
|
||||
|
||||
root = tmp_path / "loras"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
saved = folder_paths.folder_names_and_paths.get("loras")
|
||||
folder_paths.folder_names_and_paths["loras"] = (
|
||||
[str(root)],
|
||||
{".safetensors", ".sft", ".ckpt", ".pt", ".pth"},
|
||||
)
|
||||
yield str(root)
|
||||
if saved is not None:
|
||||
folder_paths.folder_names_and_paths["loras"] = saved
|
||||
else:
|
||||
folder_paths.folder_names_and_paths.pop("loras", None)
|
||||
166
tests-unit/model_downloader_test/test_credentials.py
Normal file
166
tests-unit/model_downloader_test/test_credentials.py
Normal file
@ -0,0 +1,166 @@
|
||||
"""Unit tests for the credential store and the per-hop credential resolver.
|
||||
|
||||
Covers the critical rule: a secret is only ever attached when the current
|
||||
hop's host matches a stored credential, and never over a non-https hop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from app.model_downloader.credentials import resolver
|
||||
from app.model_downloader.credentials.store import (
|
||||
CREDENTIAL_STORE,
|
||||
CredentialValidationError,
|
||||
normalize_host,
|
||||
)
|
||||
from app.model_downloader.database.models import HostCredential
|
||||
|
||||
|
||||
# ----- pure host normalization + matching -----
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"raw,expected",
|
||||
[
|
||||
("Civitai.com", "civitai.com"),
|
||||
("HuggingFace.co:443", "huggingface.co"),
|
||||
(" Example.COM ", "example.com"),
|
||||
],
|
||||
)
|
||||
def test_normalize_host(raw, expected):
|
||||
assert normalize_host(raw) == expected
|
||||
|
||||
|
||||
def _cred(**kw) -> HostCredential:
|
||||
base = dict(
|
||||
id="x", host="civitai.com", match_subdomains=False, auth_scheme="bearer",
|
||||
secret="SECRET", enabled=True,
|
||||
)
|
||||
base.update(kw)
|
||||
return HostCredential(**base)
|
||||
|
||||
|
||||
def test_matches_exact_only_by_default():
|
||||
c = _cred(host="civitai.com")
|
||||
assert resolver._matches(c, "civitai.com") is True
|
||||
assert resolver._matches(c, "api.civitai.com") is False
|
||||
assert resolver._matches(c, "evil-civitai.com") is False
|
||||
|
||||
|
||||
def test_matches_subdomain_label_boundary():
|
||||
c = _cred(host="example.com", match_subdomains=True)
|
||||
assert resolver._matches(c, "api.example.com") is True
|
||||
assert resolver._matches(c, "example.com") is True
|
||||
# not a label boundary -> no match
|
||||
assert resolver._matches(c, "evil-example.com") is False
|
||||
|
||||
|
||||
def test_build_auth_shapes():
|
||||
assert resolver._build_auth(_cred(auth_scheme="bearer")).headers == {
|
||||
"Authorization": "Bearer SECRET"
|
||||
}
|
||||
assert resolver._build_auth(
|
||||
_cred(auth_scheme="header", header_name="X-Api-Key")
|
||||
).headers == {"X-Api-Key": "SECRET"}
|
||||
q = resolver._build_auth(_cred(auth_scheme="query", query_param="token"))
|
||||
assert q.query == {"token": "SECRET"}
|
||||
assert q.apply_to_url("https://civitai.com/x") == "https://civitai.com/x?token=SECRET"
|
||||
|
||||
|
||||
# ----- DB-backed store + resolver -----
|
||||
|
||||
|
||||
def test_store_upsert_is_write_only_and_masked():
|
||||
async def _run():
|
||||
view = await CREDENTIAL_STORE.upsert("civitai.com", "abcd1234", label="my key")
|
||||
# The view never carries the secret, only the last 4.
|
||||
assert not hasattr(view, "secret")
|
||||
assert view.secret_last4 == "1234"
|
||||
assert view.host == "civitai.com"
|
||||
listed = await CREDENTIAL_STORE.list()
|
||||
assert any(v.host == "civitai.com" for v in listed)
|
||||
await CREDENTIAL_STORE.delete(view.id)
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_query_scheme_requires_param():
|
||||
async def _run():
|
||||
with pytest.raises(CredentialValidationError):
|
||||
await CREDENTIAL_STORE.upsert("civitai.com", "k", auth_scheme="query")
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_resolver_never_crosses_host_boundary():
|
||||
async def _run():
|
||||
view = await CREDENTIAL_STORE.upsert("huggingface.co", "hf_secret_key")
|
||||
try:
|
||||
# matching host over https -> attached
|
||||
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
|
||||
assert auth is not None
|
||||
assert auth.headers["Authorization"] == "Bearer hf_secret_key"
|
||||
# CDN redirect host -> dropped
|
||||
assert await resolver.resolve_auth_for_hop("cdn-lfs.huggingface.co", "https") is None
|
||||
# non-https hop -> never attached
|
||||
assert await resolver.resolve_auth_for_hop("huggingface.co", "http") is None
|
||||
finally:
|
||||
await CREDENTIAL_STORE.delete(view.id)
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
# ----- env-based HF token fallback -----
|
||||
|
||||
|
||||
def test_env_token_fallback_attaches_when_no_db_credential(monkeypatch):
|
||||
monkeypatch.setenv("HF_TOKEN", "env_hf_token")
|
||||
|
||||
async def _run():
|
||||
# exact host over https -> env token attached
|
||||
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
|
||||
assert auth is not None
|
||||
assert auth.headers["Authorization"] == "Bearer env_hf_token"
|
||||
# non-https hop -> never attached
|
||||
assert await resolver.resolve_auth_for_hop("huggingface.co", "http") is None
|
||||
# CDN redirect host -> dropped (exact-host only)
|
||||
assert await resolver.resolve_auth_for_hop("cdn-lfs.huggingface.co", "https") is None
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_env_token_secondary_var_is_honored(monkeypatch):
|
||||
monkeypatch.delenv("HF_TOKEN", raising=False)
|
||||
monkeypatch.setenv("HUGGING_FACE_HUB_TOKEN", "env_hub_token")
|
||||
|
||||
async def _run():
|
||||
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
|
||||
assert auth is not None
|
||||
assert auth.headers["Authorization"] == "Bearer env_hub_token"
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_db_credential_takes_precedence_over_env(monkeypatch):
|
||||
monkeypatch.setenv("HF_TOKEN", "env_hf_token")
|
||||
|
||||
async def _run():
|
||||
view = await CREDENTIAL_STORE.upsert("huggingface.co", "db_secret_key")
|
||||
try:
|
||||
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
|
||||
assert auth is not None
|
||||
assert auth.headers["Authorization"] == "Bearer db_secret_key"
|
||||
finally:
|
||||
await CREDENTIAL_STORE.delete(view.id)
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_env_token_does_not_leak_into_explicit_path(monkeypatch):
|
||||
monkeypatch.setenv("HF_TOKEN", "env_hf_token")
|
||||
|
||||
async def _run():
|
||||
# An explicit credential id that doesn't resolve must stay None; the env
|
||||
# fallback only applies to the auto-resolve branch.
|
||||
auth = await resolver.resolve_auth_for_hop(
|
||||
"huggingface.co", "https", explicit_credential_id="does-not-exist"
|
||||
)
|
||||
assert auth is None
|
||||
asyncio.run(_run())
|
||||
136
tests-unit/model_downloader_test/test_delete.py
Normal file
136
tests-unit/model_downloader_test/test_delete.py
Normal file
@ -0,0 +1,136 @@
|
||||
"""Unit tests for ``DownloadManager.delete`` and ``DownloadManager.clear``.
|
||||
|
||||
Deleting a terminal row must remove it from history for good (so it does not
|
||||
reappear on the next ``list``), leave live rows untouched, and clean up any
|
||||
leftover ``.part`` temp file without touching the finished model file.
|
||||
|
||||
``clear()`` is the bulk variant: it removes all terminal rows atomically, skips
|
||||
live ones, and returns the count of rows deleted.
|
||||
|
||||
Async methods are driven via ``asyncio.run`` so no pytest-asyncio plugin is
|
||||
required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from app.model_downloader.constants import DownloadStatus
|
||||
from app.model_downloader.database import queries
|
||||
from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError
|
||||
|
||||
|
||||
def _insert(download_id: str, status: str, *, temp_path: str = "/tmp/none.part") -> None:
|
||||
queries.insert_download(
|
||||
{
|
||||
"id": download_id,
|
||||
"url": "https://huggingface.co/org/model.safetensors",
|
||||
"model_id": "loras/model.safetensors",
|
||||
"dest_path": "/tmp/model.safetensors",
|
||||
"temp_path": temp_path,
|
||||
"status": status,
|
||||
"priority": 0,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_delete_removes_terminal_row_from_history():
|
||||
_insert("done", DownloadStatus.COMPLETED)
|
||||
|
||||
asyncio.run(DOWNLOAD_MANAGER.delete("done"))
|
||||
|
||||
assert queries.get_download("done") is None
|
||||
|
||||
|
||||
def test_delete_refuses_live_row():
|
||||
_insert("live", DownloadStatus.QUEUED)
|
||||
|
||||
with pytest.raises(DownloadError) as excinfo:
|
||||
asyncio.run(DOWNLOAD_MANAGER.delete("live"))
|
||||
|
||||
assert excinfo.value.code == "DOWNLOAD_ACTIVE"
|
||||
assert queries.get_download("live") is not None
|
||||
|
||||
|
||||
def test_delete_missing_row_raises_not_found():
|
||||
with pytest.raises(DownloadError) as excinfo:
|
||||
asyncio.run(DOWNLOAD_MANAGER.delete("nope"))
|
||||
|
||||
assert excinfo.value.code == "NOT_FOUND"
|
||||
|
||||
|
||||
def test_delete_removes_leftover_temp_file(tmp_path):
|
||||
partial = tmp_path / "model.safetensors.part"
|
||||
partial.write_bytes(b"partial")
|
||||
_insert("failed", DownloadStatus.FAILED, temp_path=str(partial))
|
||||
|
||||
asyncio.run(DOWNLOAD_MANAGER.delete("failed"))
|
||||
|
||||
assert not os.path.exists(partial)
|
||||
assert queries.get_download("failed") is None
|
||||
|
||||
|
||||
# ----- clear -----
|
||||
|
||||
|
||||
def test_clear_removes_all_terminal_rows():
|
||||
_insert("c-done", DownloadStatus.COMPLETED)
|
||||
_insert("c-fail", DownloadStatus.FAILED)
|
||||
_insert("c-canc", DownloadStatus.CANCELLED)
|
||||
|
||||
deleted = asyncio.run(DOWNLOAD_MANAGER.clear())
|
||||
|
||||
assert deleted == 3
|
||||
assert queries.get_download("c-done") is None
|
||||
assert queries.get_download("c-fail") is None
|
||||
assert queries.get_download("c-canc") is None
|
||||
|
||||
|
||||
def test_clear_skips_live_rows():
|
||||
_insert("cl-queued", DownloadStatus.QUEUED)
|
||||
_insert("cl-paused", DownloadStatus.PAUSED)
|
||||
_insert("cl-done", DownloadStatus.COMPLETED)
|
||||
|
||||
deleted = asyncio.run(DOWNLOAD_MANAGER.clear())
|
||||
|
||||
assert deleted == 1
|
||||
assert queries.get_download("cl-queued") is not None
|
||||
assert queries.get_download("cl-paused") is not None
|
||||
assert queries.get_download("cl-done") is None
|
||||
|
||||
|
||||
def test_clear_returns_zero_when_nothing_to_delete():
|
||||
_insert("cl-only-live", DownloadStatus.QUEUED)
|
||||
|
||||
deleted = asyncio.run(DOWNLOAD_MANAGER.clear())
|
||||
|
||||
assert deleted == 0
|
||||
assert queries.get_download("cl-only-live") is not None
|
||||
|
||||
|
||||
def test_clear_removes_leftover_temp_files(tmp_path):
|
||||
partial = tmp_path / "clear_partial.part"
|
||||
partial.write_bytes(b"partial data")
|
||||
finished = tmp_path / "finished.safetensors"
|
||||
finished.write_bytes(b"real model weights")
|
||||
|
||||
_insert("cl-part", DownloadStatus.FAILED, temp_path=str(partial))
|
||||
# The finished file is not the temp_path; temp_path for a completed download
|
||||
# no longer exists (already renamed), so use a non-existent path here to
|
||||
# verify clear() tolerates a missing temp file without raising.
|
||||
_insert("cl-comp", DownloadStatus.COMPLETED, temp_path=str(tmp_path / "gone.part"))
|
||||
|
||||
asyncio.run(DOWNLOAD_MANAGER.clear())
|
||||
|
||||
# Leftover .part from the failed download is cleaned up.
|
||||
assert not partial.exists()
|
||||
# Finished model file is never touched.
|
||||
assert finished.exists()
|
||||
|
||||
|
||||
def test_clear_empty_db_returns_zero():
|
||||
deleted = asyncio.run(DOWNLOAD_MANAGER.clear())
|
||||
assert deleted == 0
|
||||
637
tests-unit/model_downloader_test/test_engine_integration.py
Normal file
637
tests-unit/model_downloader_test/test_engine_integration.py
Normal file
@ -0,0 +1,637 @@
|
||||
"""Integration tests for the download engine against a local aiohttp server.
|
||||
|
||||
Covers single-stream and segmented transfers, deterministic resume from a
|
||||
partial file, and cancel rollback. Async tests are driven via ``asyncio.run``
|
||||
so no pytest-asyncio plugin is required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from comfy.cli_args import args
|
||||
from app.model_downloader.constants import DownloadStatus
|
||||
from app.model_downloader.database import queries
|
||||
from app.model_downloader.engine.job import DownloadJob, JobSpec
|
||||
from app.model_downloader.net.session import close_session
|
||||
from app.model_downloader.security import paths
|
||||
|
||||
PAYLOAD_ETAG = '"v1"'
|
||||
|
||||
|
||||
def _payload(n: int) -> bytes:
|
||||
return bytes((i * 37 + 11) % 256 for i in range(n))
|
||||
|
||||
|
||||
def _safetensors_payload(total: int) -> bytes:
|
||||
"""A structurally valid ``.safetensors`` blob of exactly ``total`` bytes.
|
||||
|
||||
Success-path tests download to ``.safetensors`` destinations, which the
|
||||
engine now structurally validates before the atomic rename, so their
|
||||
payloads must parse as real safetensors (header length + JSON header +
|
||||
data region whose size matches the declared ``data_offsets``).
|
||||
"""
|
||||
def _header(data_len: int) -> bytes:
|
||||
return json.dumps(
|
||||
{"w": {"dtype": "U8", "shape": [data_len], "data_offsets": [0, data_len]}}
|
||||
).encode("utf-8")
|
||||
|
||||
# The header's byte length depends on the digit count of ``data_len``, so
|
||||
# iterate until ``total == 8 + len(header) + data_len`` is self-consistent.
|
||||
data_len = total - 8 - len(_header(total))
|
||||
for _ in range(8):
|
||||
header = _header(data_len)
|
||||
new_data_len = total - 8 - len(header)
|
||||
if new_data_len == data_len:
|
||||
break
|
||||
data_len = new_data_len
|
||||
assert data_len >= 0, "total too small for a safetensors payload"
|
||||
header = _header(data_len)
|
||||
body = bytes((i * 37 + 11) % 256 for i in range(data_len))
|
||||
return struct.pack("<Q", len(header)) + header + body
|
||||
|
||||
|
||||
def _range_handler(payload: bytes):
|
||||
async def handler(request: web.Request) -> web.Response:
|
||||
rng = request.headers.get("Range")
|
||||
if rng:
|
||||
spec = rng.split("=", 1)[1]
|
||||
s, _, e = spec.partition("-")
|
||||
start = int(s)
|
||||
end = int(e) if e else len(payload) - 1
|
||||
chunk = payload[start : end + 1]
|
||||
return web.Response(
|
||||
status=206,
|
||||
body=chunk,
|
||||
headers={
|
||||
"Content-Range": f"bytes {start}-{end}/{len(payload)}",
|
||||
"Accept-Ranges": "bytes",
|
||||
"ETag": PAYLOAD_ETAG,
|
||||
},
|
||||
)
|
||||
return web.Response(
|
||||
status=200, body=payload, headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG}
|
||||
)
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def _content_disposition_handler(payload: bytes, filename: str):
|
||||
"""A range-capable server that only reveals its filename via a header.
|
||||
|
||||
Models a Civitai-style ``/api/download/...`` endpoint: the URL path has no
|
||||
extension, and the real filename (hence extension) lives in the response
|
||||
``Content-Disposition`` header.
|
||||
"""
|
||||
|
||||
async def handler(request: web.Request) -> web.Response:
|
||||
headers = {
|
||||
"Accept-Ranges": "bytes",
|
||||
"ETag": PAYLOAD_ETAG,
|
||||
"Content-Disposition": f'attachment; filename="{filename}"',
|
||||
}
|
||||
rng = request.headers.get("Range")
|
||||
if rng:
|
||||
spec = rng.split("=", 1)[1]
|
||||
s, _, e = spec.partition("-")
|
||||
start = int(s)
|
||||
end = int(e) if e else len(payload) - 1
|
||||
chunk = payload[start : end + 1]
|
||||
return web.Response(
|
||||
status=206,
|
||||
body=chunk,
|
||||
headers={**headers, "Content-Range": f"bytes {start}-{end}/{len(payload)}"},
|
||||
)
|
||||
return web.Response(status=200, body=payload, headers=headers)
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def _noranges_handler(payload: bytes):
|
||||
async def handler(request: web.Request) -> web.Response:
|
||||
# Always full body, never advertises Accept-Ranges -> single-stream.
|
||||
return web.Response(status=200, body=payload)
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def _slow_handler(payload: bytes, chunk: int = 16384, delay: float = 0.01):
|
||||
async def handler(request: web.Request) -> web.StreamResponse:
|
||||
resp = web.StreamResponse(
|
||||
status=200, headers={"Content-Length": str(len(payload))}
|
||||
)
|
||||
await resp.prepare(request)
|
||||
for i in range(0, len(payload), chunk):
|
||||
await resp.write(payload[i : i + chunk])
|
||||
await asyncio.sleep(delay)
|
||||
await resp.write_eof()
|
||||
return resp
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def _overflow_range_handler(payload: bytes, extra: int = 256 * 1024):
|
||||
"""A non-conforming 206 server that returns MORE than the requested range."""
|
||||
|
||||
async def handler(request: web.Request) -> web.Response:
|
||||
rng = request.headers.get("Range")
|
||||
if rng:
|
||||
spec = rng.split("=", 1)[1]
|
||||
s, _, e = spec.partition("-")
|
||||
start = int(s)
|
||||
end = int(e) if e else len(payload) - 1
|
||||
# Maliciously overrun: append extra bytes past the requested end.
|
||||
body = payload[start : end + 1] + bytes(extra)
|
||||
return web.Response(
|
||||
status=206,
|
||||
body=body,
|
||||
headers={
|
||||
"Content-Range": f"bytes {start}-{end}/{len(payload)}",
|
||||
"Accept-Ranges": "bytes",
|
||||
"ETag": PAYLOAD_ETAG,
|
||||
},
|
||||
)
|
||||
return web.Response(
|
||||
status=200, body=payload, headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG}
|
||||
)
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def _short_range_handler(payload: bytes, drop: int = 64 * 1024):
|
||||
"""A 206 server that returns fewer bytes than requested for later segments.
|
||||
|
||||
Simulates a server cleanly closing a range connection early. The response
|
||||
is internally consistent (Content-Length matches the short body), so the
|
||||
client sees no error and the segment just ends short, leaving a zero-filled
|
||||
hole in the preallocated file.
|
||||
"""
|
||||
|
||||
async def handler(request: web.Request) -> web.Response:
|
||||
rng = request.headers.get("Range")
|
||||
if rng:
|
||||
spec = rng.split("=", 1)[1]
|
||||
s, _, e = spec.partition("-")
|
||||
start = int(s)
|
||||
end = int(e) if e else len(payload) - 1
|
||||
chunk = payload[start : end + 1]
|
||||
if start > 0 and len(chunk) > drop:
|
||||
chunk = chunk[:-drop] # truncate a non-first segment
|
||||
return web.Response(
|
||||
status=206,
|
||||
body=chunk,
|
||||
headers={
|
||||
"Content-Range": f"bytes {start}-{end}/{len(payload)}",
|
||||
"Accept-Ranges": "bytes",
|
||||
"ETag": PAYLOAD_ETAG,
|
||||
},
|
||||
)
|
||||
return web.Response(
|
||||
status=200, body=payload, headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG}
|
||||
)
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def _unbounded_handler(total: int, chunk: int = 16384):
|
||||
"""A 200 stream with no Content-Length / Accept-Ranges (unknown length)."""
|
||||
|
||||
async def handler(request: web.Request) -> web.StreamResponse:
|
||||
resp = web.StreamResponse(status=200)
|
||||
await resp.prepare(request)
|
||||
sent = 0
|
||||
while sent < total:
|
||||
await resp.write(bytes(min(chunk, total - sent)))
|
||||
sent += chunk
|
||||
await resp.write_eof()
|
||||
return resp
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
async def _serve(handler):
|
||||
app = web.Application()
|
||||
app.router.add_route("*", "/{name:.*}", handler)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "127.0.0.1", 0)
|
||||
await site.start()
|
||||
port = site._server.sockets[0].getsockname()[1]
|
||||
return runner, port
|
||||
|
||||
|
||||
def _insert(model_id: str, url: str, status: str = DownloadStatus.QUEUED) -> tuple[str, str, str]:
|
||||
final_path, temp_path = paths.resolve_destination(model_id)
|
||||
download_id = str(uuid.uuid4())
|
||||
queries.insert_download(
|
||||
{
|
||||
"id": download_id,
|
||||
"url": url,
|
||||
"model_id": model_id,
|
||||
"dest_path": final_path,
|
||||
"temp_path": temp_path,
|
||||
"status": status,
|
||||
}
|
||||
)
|
||||
return download_id, final_path, temp_path
|
||||
|
||||
|
||||
# ----- single-stream -----
|
||||
|
||||
|
||||
def test_single_stream_download(model_root):
|
||||
payload = _safetensors_payload(300_000)
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_noranges_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, _temp = _insert("loras/single.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/single.safetensors",
|
||||
dest_path=final_path, temp_path=_temp,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.COMPLETED, queries.get_download(did).error
|
||||
assert os.path.exists(final_path)
|
||||
assert open(final_path, "rb").read() == payload
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
# ----- segmented -----
|
||||
|
||||
|
||||
def test_segmented_download(model_root):
|
||||
payload = _safetensors_payload(4 * 1024 * 1024) # 4 MiB -> multiple segments
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_range_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/seg.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/seg.safetensors",
|
||||
dest_path=final_path, temp_path=temp,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.COMPLETED, queries.get_download(did).error
|
||||
assert open(final_path, "rb").read() == payload
|
||||
# More than one segment row was planned.
|
||||
assert len(queries.list_segments(did)) > 1
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
# ----- deterministic resume from a partial file -----
|
||||
|
||||
|
||||
def test_resume_from_partial(model_root):
|
||||
payload = _safetensors_payload(512 * 1024) # < 1 MiB -> single segment
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_range_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/resume.safetensors", url)
|
||||
# Simulate a prior partial: first 200 KiB already written, offset persisted.
|
||||
prefix = 200 * 1024
|
||||
os.makedirs(os.path.dirname(temp), exist_ok=True)
|
||||
with open(temp, "wb") as f:
|
||||
f.write(payload[:prefix])
|
||||
queries.update_download(did, bytes_done=prefix, etag=PAYLOAD_ETAG)
|
||||
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/resume.safetensors",
|
||||
dest_path=final_path, temp_path=temp, etag=PAYLOAD_ETAG,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.COMPLETED, queries.get_download(did).error
|
||||
assert open(final_path, "rb").read() == payload
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
# ----- cancel rollback -----
|
||||
|
||||
|
||||
def test_cancel_rollback(model_root, monkeypatch):
|
||||
monkeypatch.setattr(args, "download_chunk_size", 16384, raising=False)
|
||||
payload = _payload(1024 * 1024)
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_slow_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/cancel.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/cancel.safetensors",
|
||||
dest_path=final_path, temp_path=temp,
|
||||
))
|
||||
task = asyncio.ensure_future(job.run())
|
||||
# Wait until some bytes have been written, then cancel.
|
||||
for _ in range(200):
|
||||
await asyncio.sleep(0.01)
|
||||
if job.state.bytes_done > 0:
|
||||
break
|
||||
job.request_cancel()
|
||||
status = await task
|
||||
assert status == DownloadStatus.CANCELLED
|
||||
assert not os.path.exists(temp)
|
||||
assert not os.path.exists(final_path)
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
# ----- size-bound enforcement (malicious / non-conforming hosts) -----
|
||||
|
||||
|
||||
def test_segment_overflow_aborts(model_root):
|
||||
"""A 206 returning more than the requested range must not overrun."""
|
||||
payload = _payload(4 * 1024 * 1024) # large enough to segment
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_overflow_range_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/overflow.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/overflow.safetensors",
|
||||
dest_path=final_path, temp_path=temp,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.FAILED
|
||||
assert not os.path.exists(final_path)
|
||||
assert not os.path.exists(temp)
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_short_segment_fails_closed(model_root):
|
||||
"""A segment that ends short must fail, not be accepted as complete.
|
||||
|
||||
The file is preallocated to total_bytes, so the on-disk size still equals
|
||||
total even with a zero-filled hole; completeness must be judged per-segment.
|
||||
"""
|
||||
payload = _safetensors_payload(4 * 1024 * 1024) # large enough to segment
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_short_range_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/short.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/short.safetensors",
|
||||
dest_path=final_path, temp_path=temp,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.FAILED, queries.get_download(did).error
|
||||
assert "incomplete" in (queries.get_download(did).error or "")
|
||||
assert not os.path.exists(final_path)
|
||||
assert not os.path.exists(temp)
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_structural_validation_rejects_corrupt(model_root):
|
||||
"""A correctly sized but structurally invalid file fails closed (not retried).
|
||||
|
||||
Regression for the dead structural gate: validation must key off the
|
||||
destination extension, not the ``.part`` temp suffix.
|
||||
"""
|
||||
payload = _payload(300_000) # right size, but not a valid safetensors blob
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_noranges_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/corrupt.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/corrupt.safetensors",
|
||||
dest_path=final_path, temp_path=temp,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.FAILED, queries.get_download(did).error
|
||||
assert not os.path.exists(final_path)
|
||||
assert not os.path.exists(temp)
|
||||
# Failed closed at first attempt, not re-queued as retryable.
|
||||
assert queries.get_download(did).attempts == 0
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_rejects_oversized_known_download(model_root, monkeypatch):
|
||||
"""A file whose advertised size exceeds the cap is rejected at probe."""
|
||||
monkeypatch.setattr(args, "download_max_bytes", 100_000, raising=False)
|
||||
payload = _payload(300_000)
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_noranges_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/toobig.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/toobig.safetensors",
|
||||
dest_path=final_path, temp_path=temp,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.FAILED
|
||||
assert not os.path.exists(final_path)
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_unknown_length_capped_by_max_bytes(model_root, monkeypatch):
|
||||
"""An unbounded unknown-length stream is capped by --download-max-bytes."""
|
||||
monkeypatch.setattr(args, "download_max_bytes", 100_000, raising=False)
|
||||
monkeypatch.setattr(args, "download_chunk_size", 16384, raising=False)
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_unbounded_handler(2 * 1024 * 1024))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/unbounded.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/unbounded.safetensors",
|
||||
dest_path=final_path, temp_path=temp,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.FAILED
|
||||
assert not os.path.exists(final_path)
|
||||
assert not os.path.exists(temp)
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
# ----- manager + scheduler end-to-end -----
|
||||
|
||||
|
||||
def test_manager_enqueue_to_completion(model_root):
|
||||
payload = _safetensors_payload(2 * 1024 * 1024)
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
from app.model_downloader.manager import DOWNLOAD_MANAGER
|
||||
|
||||
runner, port = await _serve(_range_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did = await DOWNLOAD_MANAGER.enqueue(url, "loras/e2e.safetensors")
|
||||
# Wait for completion.
|
||||
final_path, _ = paths.resolve_destination("loras/e2e.safetensors")
|
||||
for _ in range(500):
|
||||
await asyncio.sleep(0.02)
|
||||
row = queries.get_download(did)
|
||||
if row.status in DownloadStatus.TERMINAL:
|
||||
break
|
||||
row = queries.get_download(did)
|
||||
assert row.status == DownloadStatus.COMPLETED, row.error
|
||||
assert open(final_path, "rb").read() == payload
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_manager_rejects_disallowed_url(model_root):
|
||||
async def _run():
|
||||
from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError
|
||||
|
||||
with pytest.raises(DownloadError) as ei:
|
||||
await DOWNLOAD_MANAGER.enqueue(
|
||||
"https://evil.example.com/x.safetensors", "loras/bad.safetensors"
|
||||
)
|
||||
assert ei.value.code == "URL_NOT_ALLOWED"
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_manager_resolves_extensionless_url(model_root):
|
||||
"""An allowlisted URL with no extension in its path is resolved from the
|
||||
response, and the stored file adopts the resolved extension."""
|
||||
payload = _safetensors_payload(1 * 1024 * 1024)
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
from app.model_downloader.manager import DOWNLOAD_MANAGER
|
||||
|
||||
runner, port = await _serve(
|
||||
_content_disposition_handler(payload, "RealModel.safetensors")
|
||||
)
|
||||
try:
|
||||
# No extension in the path (Civitai-style) and none in the model_id.
|
||||
url = f"http://127.0.0.1:{port}/api/download/models/12345"
|
||||
did = await DOWNLOAD_MANAGER.enqueue(url, "loras/my_civitai_model")
|
||||
|
||||
row = queries.get_download(did)
|
||||
# The resolved extension was appended to the model_id + destination.
|
||||
assert row.model_id == "loras/my_civitai_model.safetensors"
|
||||
assert row.dest_path.endswith("my_civitai_model.safetensors")
|
||||
|
||||
final_path, _ = paths.resolve_destination(
|
||||
"loras/my_civitai_model.safetensors"
|
||||
)
|
||||
for _ in range(500):
|
||||
await asyncio.sleep(0.02)
|
||||
row = queries.get_download(did)
|
||||
if row.status in DownloadStatus.TERMINAL:
|
||||
break
|
||||
row = queries.get_download(did)
|
||||
assert row.status == DownloadStatus.COMPLETED, row.error
|
||||
assert open(final_path, "rb").read() == payload
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_manager_overrides_extension_from_resolution(model_root):
|
||||
"""A model_id carrying a different known extension is corrected to match
|
||||
the resolved URL's extension."""
|
||||
payload = _safetensors_payload(256 * 1024)
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
from app.model_downloader.manager import DOWNLOAD_MANAGER
|
||||
|
||||
runner, port = await _serve(
|
||||
_content_disposition_handler(payload, "weights.safetensors")
|
||||
)
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/api/download/models/777"
|
||||
# Caller guessed .ckpt; resolution says .safetensors -> corrected.
|
||||
did = await DOWNLOAD_MANAGER.enqueue(url, "loras/guessed.ckpt")
|
||||
row = queries.get_download(did)
|
||||
assert row.model_id == "loras/guessed.safetensors"
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_manager_rejects_non_model_resolution(model_root):
|
||||
"""A URL that resolves to a non-model file is rejected, not downloaded."""
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError
|
||||
|
||||
runner, port = await _serve(
|
||||
_content_disposition_handler(b"not a model", "installer.zip")
|
||||
)
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/api/download/models/999"
|
||||
with pytest.raises(DownloadError) as ei:
|
||||
await DOWNLOAD_MANAGER.enqueue(url, "loras/whatever")
|
||||
assert ei.value.code == "URL_NOT_ALLOWED"
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
81
tests-unit/model_downloader_test/test_planner_structural.py
Normal file
81
tests-unit/model_downloader_test/test_planner_structural.py
Normal file
@ -0,0 +1,81 @@
|
||||
"""Unit tests for the segment planner and structural safetensors validation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import struct
|
||||
|
||||
import pytest
|
||||
|
||||
from app.model_downloader.engine.planner import (
|
||||
effective_segment_count,
|
||||
plan_segments,
|
||||
)
|
||||
from app.model_downloader.verify import structural
|
||||
|
||||
|
||||
# ----- planner -----
|
||||
|
||||
|
||||
def test_plan_segments_covers_full_range_contiguously():
|
||||
total = 1000
|
||||
plans = plan_segments(total, 4)
|
||||
assert len(plans) == 4
|
||||
assert plans[0].start == 0
|
||||
assert plans[-1].end == total - 1
|
||||
# contiguous, no gaps/overlaps
|
||||
for a, b in zip(plans, plans[1:]):
|
||||
assert b.start == a.end + 1
|
||||
assert sum(p.length for p in plans) == total
|
||||
|
||||
|
||||
def test_effective_segment_count_falls_back_to_single():
|
||||
# No range support -> single
|
||||
assert effective_segment_count(10_000_000, False, 8) == 1
|
||||
# Unknown size -> single
|
||||
assert effective_segment_count(None, True, 8) == 1
|
||||
# Tiny file -> fewer segments than configured
|
||||
assert effective_segment_count(1024, True, 8) == 1
|
||||
# Large file with range support -> configured count
|
||||
assert effective_segment_count(1_000_000_000, True, 8) == 8
|
||||
|
||||
|
||||
# ----- structural -----
|
||||
|
||||
|
||||
def _make_safetensors(tensor_data_len: int, *, corrupt_size: bool = False) -> bytes:
|
||||
header = {"t": {"dtype": "F32", "shape": [tensor_data_len], "data_offsets": [0, tensor_data_len]}}
|
||||
header_bytes = json.dumps(header).encode("utf-8")
|
||||
body = b"\x00" * tensor_data_len
|
||||
if corrupt_size:
|
||||
body = body[:-1] # truncate one byte
|
||||
return struct.pack("<Q", len(header_bytes)) + header_bytes + body
|
||||
|
||||
|
||||
def test_structural_valid_safetensors(tmp_path):
|
||||
p = tmp_path / "ok.safetensors"
|
||||
p.write_bytes(_make_safetensors(256))
|
||||
structural.validate(str(p)) # no raise
|
||||
|
||||
|
||||
def test_structural_detects_truncation(tmp_path):
|
||||
p = tmp_path / "bad.safetensors"
|
||||
p.write_bytes(_make_safetensors(256, corrupt_size=True))
|
||||
with pytest.raises(structural.StructuralError):
|
||||
structural.validate(str(p))
|
||||
|
||||
|
||||
def test_structural_skips_unknown_extension(tmp_path):
|
||||
p = tmp_path / "weights.bin"
|
||||
p.write_bytes(b"anything")
|
||||
structural.validate(str(p)) # no structural check, no raise
|
||||
|
||||
|
||||
def test_structural_detects_truncation_via_name_hint(tmp_path):
|
||||
# The downloader validates the opaque temp file (a ``.part`` path) but keys
|
||||
# the format check off the final destination name via ``name_hint``, so
|
||||
# truncation must still be detected instead of silently skipped.
|
||||
p = tmp_path / "bad.comfy-download.part"
|
||||
p.write_bytes(_make_safetensors(256, corrupt_size=True))
|
||||
with pytest.raises(structural.StructuralError):
|
||||
structural.validate(str(p), name_hint="model.safetensors")
|
||||
231
tests-unit/model_downloader_test/test_security.py
Normal file
231
tests-unit/model_downloader_test/test_security.py
Normal file
@ -0,0 +1,231 @@
|
||||
"""Unit tests for the security layer: allowlist, SSRF checks, path safety."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.model_downloader.security import allowlist, paths
|
||||
from app.model_downloader.security.ssrf import (
|
||||
SSRFError,
|
||||
check_redirect_hop,
|
||||
is_blocked_ip,
|
||||
)
|
||||
|
||||
|
||||
# ----- allowlist -----
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"url,allowed",
|
||||
[
|
||||
("https://huggingface.co/org/repo/resolve/main/model.safetensors", True),
|
||||
("https://civitai.com/api/download/x/model.safetensors", True),
|
||||
("http://localhost/model.safetensors", True),
|
||||
# off-list host
|
||||
("https://evil.example.com/model.safetensors", False),
|
||||
# http to a non-loopback allowlisted host is not permitted (https only)
|
||||
("http://huggingface.co/org/repo/resolve/main/model.safetensors", False),
|
||||
# bad extension on an allowed host
|
||||
("https://huggingface.co/org/repo/resolve/main/config.json", False),
|
||||
# userinfo trick: real host is the metadata IP, not 127.0.0.1
|
||||
("http://127.0.0.1@169.254.169.254/x.safetensors", False),
|
||||
],
|
||||
)
|
||||
def test_is_url_allowed(url, allowed):
|
||||
assert allowlist.is_url_allowed(url) is allowed
|
||||
|
||||
|
||||
def test_allow_any_extension_relaxes_extension_only():
|
||||
url = "https://huggingface.co/org/repo/resolve/main/weights.bin"
|
||||
assert allowlist.is_url_allowed(url) is True # .bin is in the known set
|
||||
odd = "https://huggingface.co/org/repo/resolve/main/weights.zip"
|
||||
assert allowlist.is_url_allowed(odd) is False
|
||||
assert allowlist.is_url_allowed(odd, allow_any_extension=True) is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"url,downloadable",
|
||||
[
|
||||
# known model extension in the path -> allowed
|
||||
("https://civitai.com/x/model.safetensors", True),
|
||||
# no extension in the path (Civitai download API) -> allowed, resolved later
|
||||
("https://civitai.com/api/download/models/3031464?fileId=2910346", True),
|
||||
("https://civitai.com/api/download/models/3031464", True),
|
||||
# explicit non-model extension -> rejected even on an allowed host
|
||||
("https://civitai.com/api/download/models/thing.zip", False),
|
||||
("https://huggingface.co/org/repo/resolve/main/config.json", False),
|
||||
# off-list host is never downloadable
|
||||
("https://evil.example.com/api/download/models/1", False),
|
||||
# http to a non-loopback allowlisted host is not permitted
|
||||
("http://civitai.com/api/download/models/1", False),
|
||||
],
|
||||
)
|
||||
def test_is_url_downloadable(url, downloadable):
|
||||
assert allowlist.is_url_downloadable(url) is downloadable
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"name,ext",
|
||||
[
|
||||
("model.safetensors", ".safetensors"),
|
||||
("model.SAFETENSORS", ".safetensors"),
|
||||
("archive.tar.gz", ".gz"),
|
||||
("noext", ""),
|
||||
(".safetensors", ""), # leading-dot dotfile -> no extension
|
||||
("a/b/c/model.ckpt", ".ckpt"),
|
||||
],
|
||||
)
|
||||
def test_filename_extension(name, ext):
|
||||
assert allowlist.filename_extension(name) == ext
|
||||
|
||||
|
||||
# ----- SSRF: blocked IPs -----
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ip,blocked",
|
||||
[
|
||||
("169.254.169.254", True), # cloud metadata / link-local
|
||||
("127.0.0.1", True),
|
||||
("10.0.0.5", True),
|
||||
("192.168.1.1", True),
|
||||
("172.16.0.1", True),
|
||||
("::1", True),
|
||||
("0.0.0.0", True),
|
||||
# IPv4-mapped IPv6: must see through the mapping even on CPython
|
||||
# versions predating the gh-113171 is_* property fix.
|
||||
("::ffff:169.254.169.254", True), # mapped cloud metadata
|
||||
("::ffff:127.0.0.1", True), # mapped loopback
|
||||
("::ffff:10.0.0.1", True), # mapped RFC1918
|
||||
("::ffff:8.8.8.8", False), # mapped public address stays allowed
|
||||
("8.8.8.8", False),
|
||||
("1.1.1.1", False),
|
||||
("not-an-ip", True), # unparseable -> refuse
|
||||
],
|
||||
)
|
||||
def test_is_blocked_ip(ip, blocked):
|
||||
assert is_blocked_ip(ip) is blocked
|
||||
|
||||
|
||||
# ----- SSRF: redirect hop validation -----
|
||||
|
||||
|
||||
def test_check_redirect_hop_rejects_bad_scheme_and_userinfo():
|
||||
with pytest.raises(SSRFError):
|
||||
check_redirect_hop("ftp://huggingface.co/x.safetensors")
|
||||
with pytest.raises(SSRFError):
|
||||
check_redirect_hop("https://user:pass@cdn.example.com/x")
|
||||
# A CDN host that is NOT on the allowlist is allowed as a redirect target
|
||||
# (private-IP protection is the resolver's job; credential leak is prevented
|
||||
# by exact host matching).
|
||||
assert check_redirect_hop("https://cdn-lfs.huggingface.co/abc") is not None
|
||||
|
||||
|
||||
def test_check_redirect_hop_http_only_for_loopback():
|
||||
# Plain http to an external host is rejected (no plaintext downgrade).
|
||||
with pytest.raises(SSRFError):
|
||||
check_redirect_hop("http://cdn-lfs.huggingface.co/abc")
|
||||
# http is honored for loopback only on the initial user-supplied URL (the
|
||||
# "download a local model" feature).
|
||||
assert (
|
||||
check_redirect_hop("http://localhost/x.safetensors", is_initial_url=True)
|
||||
is not None
|
||||
)
|
||||
assert (
|
||||
check_redirect_hop("http://127.0.0.1/x.safetensors", is_initial_url=True)
|
||||
is not None
|
||||
)
|
||||
|
||||
|
||||
def test_check_redirect_hop_blocks_loopback_and_ip_literals_on_redirect():
|
||||
# A redirect (is_initial_url=False, the default) must never reach loopback,
|
||||
# whether by hostname or by IP literal, nor any other internal IP literal.
|
||||
for target in (
|
||||
"http://localhost/x.safetensors",
|
||||
"http://127.0.0.1/x.safetensors",
|
||||
"https://[::1]/x.safetensors",
|
||||
"https://169.254.169.254/x.safetensors", # cloud metadata
|
||||
"https://10.0.0.5/x.safetensors", # RFC1918
|
||||
):
|
||||
with pytest.raises(SSRFError):
|
||||
check_redirect_hop(target)
|
||||
# Off-allowlist public CDN hosts (hostnames) remain valid redirect targets;
|
||||
# their resolved IPs are screened by the connector's resolver.
|
||||
assert check_redirect_hop("https://cdn-lfs.huggingface.co/abc") is not None
|
||||
|
||||
|
||||
# ----- path safety -----
|
||||
|
||||
|
||||
def test_parse_model_id_valid(model_root):
|
||||
directory, filename = paths.parse_model_id("loras/my_lora.safetensors")
|
||||
assert directory == "loras"
|
||||
assert filename == "my_lora.safetensors"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_id",
|
||||
[
|
||||
"loras/../etc/passwd.safetensors", # traversal
|
||||
"loras/sub/dir.safetensors", # nested
|
||||
"unknownfolder/x.safetensors", # unknown folder
|
||||
"loras/model.txt", # bad extension
|
||||
"noslash.safetensors", # missing directory
|
||||
"loras/", # empty filename
|
||||
],
|
||||
)
|
||||
def test_parse_model_id_rejects(model_root, model_id):
|
||||
with pytest.raises(paths.InvalidModelId):
|
||||
paths.parse_model_id(model_id)
|
||||
|
||||
|
||||
def test_resolve_destination_stays_in_root(model_root):
|
||||
final_path, temp_path = paths.resolve_destination("loras/x.safetensors")
|
||||
assert final_path.startswith(model_root)
|
||||
assert temp_path.startswith(model_root)
|
||||
assert temp_path != final_path
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_id,ext,expected",
|
||||
[
|
||||
# no extension -> append the resolved one
|
||||
("loras/my_civitai_model", ".safetensors", "loras/my_civitai_model.safetensors"),
|
||||
# different known extension -> replace it
|
||||
("loras/mymodel.ckpt", ".safetensors", "loras/mymodel.safetensors"),
|
||||
# same extension -> unchanged
|
||||
("loras/mymodel.safetensors", ".safetensors", "loras/mymodel.safetensors"),
|
||||
# non-model suffix is treated as a stem, extension appended
|
||||
("loras/my.model.v2", ".safetensors", "loras/my.model.v2.safetensors"),
|
||||
# malformed (no slash) is returned untouched for parse_model_id to reject
|
||||
("noslash", ".safetensors", "noslash"),
|
||||
],
|
||||
)
|
||||
def test_apply_extension(model_id, ext, expected):
|
||||
assert paths.apply_extension(model_id, ext) == expected
|
||||
|
||||
|
||||
# ----- Content-Disposition filename parsing -----
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"header,expected",
|
||||
[
|
||||
('attachment; filename="model.safetensors"', "model.safetensors"),
|
||||
("attachment; filename=model.ckpt", "model.ckpt"),
|
||||
# RFC 5987 form is preferred and percent-decoded
|
||||
(
|
||||
"attachment; filename=\"fallback.bin\"; filename*=UTF-8''my%20model.safetensors",
|
||||
"my model.safetensors",
|
||||
),
|
||||
# directory components in a hostile header are stripped to the basename
|
||||
('attachment; filename="../../etc/passwd"', "passwd"),
|
||||
('attachment; filename="a\\\\b\\\\model.pt"', "model.pt"),
|
||||
("inline", None),
|
||||
(None, None),
|
||||
],
|
||||
)
|
||||
def test_filename_from_content_disposition(header, expected):
|
||||
from app.model_downloader.net.http import filename_from_content_disposition
|
||||
|
||||
assert filename_from_content_disposition(header) == expected
|
||||
Reference in New Issue
Block a user