mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-28 00:46:52 +08:00
Compare commits
3 Commits
fix/valida
...
model_down
| Author | SHA1 | Date | |
|---|---|---|---|
| 8d18675e75 | |||
| 603d891eaf | |||
| 470ac36a0a |
118
alembic_db/versions/0005_download_manager.py
Normal file
118
alembic_db/versions/0005_download_manager.py
Normal file
@ -0,0 +1,118 @@
|
||||
"""
|
||||
Download manager schema.
|
||||
|
||||
Adds the three tables that back the server-side model download manager
|
||||
(PRD section 7): transient job/queue state (``downloads`` + per-segment
|
||||
``download_segments``) and one-API-key-per-host auth (``host_credentials``).
|
||||
|
||||
The local file catalog / dedup index is intentionally NOT added here — it
|
||||
is owned by the assets system (``assets`` / ``asset_references``).
|
||||
|
||||
Revision ID: 0005_download_manager
|
||||
Revises: 0004_drop_tag_type
|
||||
Create Date: 2026-06-27
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "0005_download_manager"
|
||||
down_revision = "0004_drop_tag_type"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"downloads",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("url", sa.Text(), nullable=False),
|
||||
sa.Column("final_url", sa.Text(), nullable=True),
|
||||
sa.Column("model_id", sa.String(length=1024), nullable=False),
|
||||
sa.Column("dest_path", sa.Text(), nullable=False),
|
||||
sa.Column("temp_path", sa.Text(), nullable=False),
|
||||
sa.Column("status", sa.String(length=16), nullable=False),
|
||||
sa.Column("priority", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("total_bytes", sa.BigInteger(), nullable=True),
|
||||
sa.Column("bytes_done", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.Column("etag", sa.String(length=512), nullable=True),
|
||||
sa.Column("last_modified", sa.String(length=128), nullable=True),
|
||||
sa.Column(
|
||||
"accept_ranges", sa.Boolean(), nullable=False, server_default=sa.text("false")
|
||||
),
|
||||
sa.Column("expected_sha256", sa.String(length=64), nullable=True),
|
||||
sa.Column("credential_id", sa.String(length=36), nullable=True),
|
||||
sa.Column(
|
||||
"allow_any_extension",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
sa.Column("attempts", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("error", sa.Text(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
||||
sa.CheckConstraint("bytes_done >= 0", name="ck_downloads_bytes_done_nonneg"),
|
||||
sa.CheckConstraint(
|
||||
"total_bytes IS NULL OR total_bytes >= 0",
|
||||
name="ck_downloads_total_bytes_nonneg",
|
||||
),
|
||||
)
|
||||
op.create_index("ix_downloads_status", "downloads", ["status"])
|
||||
op.create_index("ix_downloads_priority", "downloads", ["priority"])
|
||||
op.create_index("ix_downloads_model_id", "downloads", ["model_id"])
|
||||
|
||||
op.create_table(
|
||||
"download_segments",
|
||||
sa.Column(
|
||||
"download_id",
|
||||
sa.String(length=36),
|
||||
sa.ForeignKey("downloads.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("idx", sa.Integer(), nullable=False),
|
||||
sa.Column("start_offset", sa.BigInteger(), nullable=False),
|
||||
sa.Column("end_offset", sa.BigInteger(), nullable=False),
|
||||
sa.Column("bytes_done", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.PrimaryKeyConstraint("download_id", "idx", name="pk_download_segments"),
|
||||
sa.CheckConstraint("bytes_done >= 0", name="ck_segments_bytes_done_nonneg"),
|
||||
sa.CheckConstraint("end_offset >= start_offset", name="ck_segments_range"),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"host_credentials",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("host", sa.String(length=255), nullable=False),
|
||||
sa.Column(
|
||||
"match_subdomains",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
sa.Column("label", sa.String(length=255), nullable=True),
|
||||
sa.Column(
|
||||
"auth_scheme", sa.String(length=16), nullable=False, server_default="bearer"
|
||||
),
|
||||
sa.Column("header_name", sa.String(length=255), nullable=True),
|
||||
sa.Column("query_param", sa.String(length=255), nullable=True),
|
||||
sa.Column("secret", sa.Text(), nullable=False),
|
||||
sa.Column("secret_last4", sa.String(length=4), nullable=True),
|
||||
sa.Column("enabled", sa.Boolean(), nullable=False, server_default=sa.text("true")),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
||||
)
|
||||
op.create_index(
|
||||
"uq_host_credentials_host", "host_credentials", ["host"], unique=True
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("uq_host_credentials_host", table_name="host_credentials")
|
||||
op.drop_table("host_credentials")
|
||||
|
||||
op.drop_table("download_segments")
|
||||
|
||||
op.drop_index("ix_downloads_model_id", table_name="downloads")
|
||||
op.drop_index("ix_downloads_priority", table_name="downloads")
|
||||
op.drop_index("ix_downloads_status", table_name="downloads")
|
||||
op.drop_table("downloads")
|
||||
@ -21,6 +21,7 @@ try:
|
||||
|
||||
from app.database.models import Base
|
||||
import app.assets.database.models # noqa: F401 — register models with Base.metadata
|
||||
import app.model_downloader.database.models # noqa: F401 — register models with Base.metadata
|
||||
|
||||
_DB_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
|
||||
203
app/model_downloader/api/routes.py
Normal file
203
app/model_downloader/api/routes.py
Normal file
@ -0,0 +1,203 @@
|
||||
"""aiohttp routes for the download manager.
|
||||
|
||||
Endpoint surface (all under ``/api/download``), mirroring the response
|
||||
envelope used by ``app/assets/api/routes.py``:
|
||||
|
||||
POST /api/download/enqueue
|
||||
GET /api/download
|
||||
POST /api/download/availability
|
||||
POST /api/download/credentials
|
||||
GET /api/download/credentials
|
||||
GET /api/download/credentials/{id}
|
||||
DELETE /api/download/credentials/{id}
|
||||
GET /api/download/{id}
|
||||
POST /api/download/{id}/pause
|
||||
POST /api/download/{id}/resume
|
||||
POST /api/download/{id}/cancel
|
||||
POST /api/download/{id}/priority
|
||||
|
||||
Note on ordering: the static ``credentials`` routes are registered before the
|
||||
dynamic ``/api/download/{id}`` route so a request to ``.../credentials`` is not
|
||||
captured as ``id == "credentials"``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from aiohttp import web
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from app.model_downloader.api import schemas_in, schemas_out
|
||||
from app.model_downloader.credentials.store import (
|
||||
CREDENTIAL_STORE,
|
||||
CredentialValidationError,
|
||||
)
|
||||
from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
|
||||
|
||||
def register_routes(app: web.Application) -> None:
|
||||
"""Wire the download-manager routes into the running aiohttp app."""
|
||||
app.add_routes(ROUTES)
|
||||
|
||||
|
||||
# ----- envelope helpers (same shape as app/assets/api/routes.py) -----
|
||||
|
||||
|
||||
def _error(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
|
||||
return web.json_response(
|
||||
{"error": {"code": code, "message": message, "details": details or {}}},
|
||||
status=status,
|
||||
)
|
||||
|
||||
|
||||
def _ok(payload, status: int = 200) -> web.Response:
|
||||
return web.json_response(payload, status=status)
|
||||
|
||||
|
||||
async def _parse(request: web.Request, model: type[BaseModel]):
|
||||
try:
|
||||
raw = await request.json()
|
||||
except json.JSONDecodeError:
|
||||
return _error(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
try:
|
||||
return model.model_validate(raw)
|
||||
except ValidationError as ve:
|
||||
return _error(400, "INVALID_BODY", "Validation failed.", {"errors": json.loads(ve.json())})
|
||||
|
||||
|
||||
def _from_download_error(e: DownloadError) -> web.Response:
|
||||
return _error(e.http_status, e.code, e.message)
|
||||
|
||||
|
||||
# ----- downloads: collection + enqueue + availability -----
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/enqueue")
|
||||
async def enqueue(request: web.Request) -> web.Response:
|
||||
parsed = await _parse(request, schemas_in.EnqueueRequest)
|
||||
if isinstance(parsed, web.Response):
|
||||
return parsed
|
||||
try:
|
||||
download_id = await DOWNLOAD_MANAGER.enqueue(
|
||||
parsed.url,
|
||||
parsed.model_id,
|
||||
priority=parsed.priority,
|
||||
expected_sha256=parsed.expected_sha256,
|
||||
allow_any_extension=parsed.allow_any_extension,
|
||||
credential_id=parsed.credential_id,
|
||||
)
|
||||
except DownloadError as e:
|
||||
return _from_download_error(e)
|
||||
return _ok({"download_id": download_id, "accepted": True}, status=202)
|
||||
|
||||
|
||||
@ROUTES.get("/api/download")
|
||||
async def list_downloads(request: web.Request) -> web.Response:
|
||||
return _ok({"downloads": await DOWNLOAD_MANAGER.list()})
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/availability")
|
||||
async def availability(request: web.Request) -> web.Response:
|
||||
parsed = await _parse(request, schemas_in.AvailabilityRequest)
|
||||
if isinstance(parsed, web.Response):
|
||||
return parsed
|
||||
return _ok({"models": await DOWNLOAD_MANAGER.availability(parsed.models)})
|
||||
|
||||
|
||||
# ----- credentials (secrets are write-only) — must precede /{id} -----
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/credentials")
|
||||
async def upsert_credential(request: web.Request) -> web.Response:
|
||||
parsed = await _parse(request, schemas_in.CredentialUpsertRequest)
|
||||
if isinstance(parsed, web.Response):
|
||||
return parsed
|
||||
try:
|
||||
view = await CREDENTIAL_STORE.upsert(
|
||||
parsed.host,
|
||||
parsed.secret,
|
||||
auth_scheme=parsed.auth_scheme,
|
||||
header_name=parsed.header_name,
|
||||
query_param=parsed.query_param,
|
||||
label=parsed.label,
|
||||
match_subdomains=parsed.match_subdomains,
|
||||
enabled=parsed.enabled,
|
||||
)
|
||||
except CredentialValidationError as e:
|
||||
return _error(400, "INVALID_CREDENTIAL", str(e))
|
||||
return _ok(schemas_out.credential_to_dict(view), status=201)
|
||||
|
||||
|
||||
@ROUTES.get("/api/download/credentials")
|
||||
async def list_credentials(request: web.Request) -> web.Response:
|
||||
views = await CREDENTIAL_STORE.list()
|
||||
return _ok({"credentials": [schemas_out.credential_to_dict(v) for v in views]})
|
||||
|
||||
|
||||
@ROUTES.get("/api/download/credentials/{id}")
|
||||
async def get_credential(request: web.Request) -> web.Response:
|
||||
view = await CREDENTIAL_STORE.get(request.match_info["id"])
|
||||
if view is None:
|
||||
return _error(404, "NOT_FOUND", "No such credential.")
|
||||
return _ok(schemas_out.credential_to_dict(view))
|
||||
|
||||
|
||||
@ROUTES.delete("/api/download/credentials/{id}")
|
||||
async def delete_credential(request: web.Request) -> web.Response:
|
||||
deleted = await CREDENTIAL_STORE.delete(request.match_info["id"])
|
||||
if not deleted:
|
||||
return _error(404, "NOT_FOUND", "No such credential.")
|
||||
return _ok({"deleted": True})
|
||||
|
||||
|
||||
# ----- single download by id (dynamic; registered last) -----
|
||||
|
||||
|
||||
@ROUTES.get("/api/download/{id}")
|
||||
async def get_download(request: web.Request) -> web.Response:
|
||||
view = await DOWNLOAD_MANAGER.status(request.match_info["id"])
|
||||
if view is None:
|
||||
return _error(404, "NOT_FOUND", "No such download.")
|
||||
return _ok(view)
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/{id}/pause")
|
||||
async def pause(request: web.Request) -> web.Response:
|
||||
try:
|
||||
await DOWNLOAD_MANAGER.pause(request.match_info["id"])
|
||||
except DownloadError as e:
|
||||
return _from_download_error(e)
|
||||
return _ok({"ok": True})
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/{id}/resume")
|
||||
async def resume(request: web.Request) -> web.Response:
|
||||
try:
|
||||
await DOWNLOAD_MANAGER.resume(request.match_info["id"])
|
||||
except DownloadError as e:
|
||||
return _from_download_error(e)
|
||||
return _ok({"ok": True})
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/{id}/cancel")
|
||||
async def cancel(request: web.Request) -> web.Response:
|
||||
try:
|
||||
await DOWNLOAD_MANAGER.cancel(request.match_info["id"])
|
||||
except DownloadError as e:
|
||||
return _from_download_error(e)
|
||||
return _ok({"ok": True})
|
||||
|
||||
|
||||
@ROUTES.post("/api/download/{id}/priority")
|
||||
async def set_priority(request: web.Request) -> web.Response:
|
||||
parsed = await _parse(request, schemas_in.PriorityRequest)
|
||||
if isinstance(parsed, web.Response):
|
||||
return parsed
|
||||
try:
|
||||
await DOWNLOAD_MANAGER.set_priority(request.match_info["id"], parsed.priority)
|
||||
except DownloadError as e:
|
||||
return _from_download_error(e)
|
||||
return _ok({"ok": True})
|
||||
51
app/model_downloader/api/schemas_in.py
Normal file
51
app/model_downloader/api/schemas_in.py
Normal file
@ -0,0 +1,51 @@
|
||||
"""Request schemas for the download manager API.
|
||||
|
||||
Pydantic enforces shape at the boundary; handlers operate only on validated
|
||||
values past that point.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.model_downloader.constants import AUTH_SCHEME_BEARER
|
||||
|
||||
|
||||
class EnqueueRequest(BaseModel):
|
||||
url: str
|
||||
model_id: str
|
||||
priority: int = 0
|
||||
expected_sha256: Optional[str] = None
|
||||
allow_any_extension: bool = False
|
||||
credential_id: Optional[str] = None
|
||||
|
||||
|
||||
class PriorityRequest(BaseModel):
|
||||
priority: int
|
||||
|
||||
|
||||
class AvailabilityRequest(BaseModel):
|
||||
"""``{model_id: url}`` — the URLs declared in the workflow JSON."""
|
||||
|
||||
models: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class CredentialUpsertRequest(BaseModel):
|
||||
host: str
|
||||
secret: str
|
||||
auth_scheme: str = AUTH_SCHEME_BEARER
|
||||
header_name: Optional[str] = None
|
||||
query_param: Optional[str] = None
|
||||
label: Optional[str] = None
|
||||
match_subdomains: bool = False
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
__all__ = [
|
||||
"EnqueueRequest",
|
||||
"PriorityRequest",
|
||||
"AvailabilityRequest",
|
||||
"CredentialUpsertRequest",
|
||||
]
|
||||
26
app/model_downloader/api/schemas_out.py
Normal file
26
app/model_downloader/api/schemas_out.py
Normal file
@ -0,0 +1,26 @@
|
||||
"""Response helpers for the download manager API.
|
||||
|
||||
The download/status read models are plain dicts produced by the manager. This
|
||||
module only needs to mask credentials for output (the secret is never returned).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.model_downloader.credentials.store import CredentialView
|
||||
|
||||
|
||||
def credential_to_dict(view: CredentialView) -> dict:
|
||||
"""API-safe credential representation — never includes the secret."""
|
||||
return {
|
||||
"id": view.id,
|
||||
"host": view.host,
|
||||
"auth_scheme": view.auth_scheme,
|
||||
"header_name": view.header_name,
|
||||
"query_param": view.query_param,
|
||||
"label": view.label,
|
||||
"match_subdomains": view.match_subdomains,
|
||||
"enabled": view.enabled,
|
||||
"secret_last4": view.secret_last4,
|
||||
"created_at": view.created_at,
|
||||
"updated_at": view.updated_at,
|
||||
}
|
||||
38
app/model_downloader/constants.py
Normal file
38
app/model_downloader/constants.py
Normal file
@ -0,0 +1,38 @@
|
||||
"""Shared constants for the download manager.
|
||||
|
||||
Status values are persisted as TEXT in the ``downloads`` table; keep them
|
||||
stable. The lifecycle is (PRD section 6):
|
||||
|
||||
queued -> active -> verifying -> completed
|
||||
| |-> paused -> (resume) -> active
|
||||
| |-> failed (network, retryable) -> queued (backoff)
|
||||
|-> cancelled
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Auth schemes for HostCredential (PRD section 9.4.1).
|
||||
AUTH_SCHEME_BEARER = "bearer"
|
||||
AUTH_SCHEME_HEADER = "header"
|
||||
AUTH_SCHEME_QUERY = "query"
|
||||
AUTH_SCHEMES = (AUTH_SCHEME_BEARER, AUTH_SCHEME_HEADER, AUTH_SCHEME_QUERY)
|
||||
|
||||
|
||||
class DownloadStatus:
|
||||
QUEUED = "queued"
|
||||
ACTIVE = "active"
|
||||
PAUSED = "paused"
|
||||
VERIFYING = "verifying"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
#: States from which a worker is doing (or about to do) network I/O.
|
||||
LIVE = (QUEUED, ACTIVE, VERIFYING)
|
||||
#: Terminal states — the job will not transition again on its own.
|
||||
TERMINAL = (COMPLETED, FAILED, CANCELLED)
|
||||
|
||||
|
||||
# Default temp-file suffix. Distinctive so the startup orphan sweep only
|
||||
# removes files THIS subsystem created, never unrelated *.tmp files.
|
||||
TMP_SUFFIX = ".comfy-download.part"
|
||||
99
app/model_downloader/credentials/resolver.py
Normal file
99
app/model_downloader/credentials/resolver.py
Normal file
@ -0,0 +1,99 @@
|
||||
"""Turn a stored credential into a per-hop request modifier (PRD section 9.4.2).
|
||||
|
||||
The critical rule: a credential is only ever attached when *the current hop's
|
||||
host* matches a stored credential, and only over https. This is recomputed
|
||||
from scratch on every redirect hop, so a token bound to ``huggingface.co`` is
|
||||
silently dropped when the request is redirected to a presigned CDN host —
|
||||
which is exactly what these hubs expect.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
|
||||
|
||||
from app.model_downloader.constants import (
|
||||
AUTH_SCHEME_BEARER,
|
||||
AUTH_SCHEME_HEADER,
|
||||
AUTH_SCHEME_QUERY,
|
||||
)
|
||||
from app.model_downloader.credentials.store import normalize_host
|
||||
from app.model_downloader.database import queries
|
||||
from app.model_downloader.database.models import HostCredential
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestAuth:
|
||||
"""How to modify a single request to carry a credential."""
|
||||
|
||||
headers: dict[str, str] = field(default_factory=dict)
|
||||
query: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def apply_to_url(self, url: str) -> str:
|
||||
if not self.query:
|
||||
return url
|
||||
parts = urlsplit(url)
|
||||
params = dict(parse_qsl(parts.query, keep_blank_values=True))
|
||||
params.update(self.query)
|
||||
return urlunsplit(parts._replace(query=urlencode(params)))
|
||||
|
||||
|
||||
def _matches(cred: HostCredential, hop_host: str) -> bool:
|
||||
cred_host = cred.host
|
||||
if hop_host == cred_host:
|
||||
return True
|
||||
if cred.match_subdomains:
|
||||
# Label-boundary suffix: api.example.com matches example.com, but
|
||||
# evil-example.com does NOT.
|
||||
return hop_host.endswith("." + cred_host)
|
||||
return False
|
||||
|
||||
|
||||
def _build_auth(cred: HostCredential) -> RequestAuth:
|
||||
if cred.auth_scheme == AUTH_SCHEME_BEARER:
|
||||
return RequestAuth(headers={"Authorization": f"Bearer {cred.secret}"})
|
||||
if cred.auth_scheme == AUTH_SCHEME_HEADER:
|
||||
name = cred.header_name or "Authorization"
|
||||
return RequestAuth(headers={name: cred.secret})
|
||||
if cred.auth_scheme == AUTH_SCHEME_QUERY and cred.query_param:
|
||||
return RequestAuth(query={cred.query_param: cred.secret})
|
||||
return RequestAuth()
|
||||
|
||||
|
||||
def _resolve_sync(
|
||||
host: str, scheme: str, explicit_credential_id: Optional[str]
|
||||
) -> Optional[RequestAuth]:
|
||||
# Never attach a secret over a non-https hop (PRD section 9.4.2).
|
||||
if scheme.lower() != "https":
|
||||
return None
|
||||
hop_host = normalize_host(host)
|
||||
if not hop_host:
|
||||
return None
|
||||
|
||||
if explicit_credential_id is not None:
|
||||
cred = queries.get_credential(explicit_credential_id)
|
||||
# An explicit credential is still subject to the per-hop host check —
|
||||
# it is not forced onto a non-matching host.
|
||||
if cred is None or not cred.enabled or not _matches(cred, hop_host):
|
||||
return None
|
||||
return _build_auth(cred)
|
||||
|
||||
# Auto-resolve: exact host first, then any subdomain-matching credential.
|
||||
cred = queries.get_credential_by_host(hop_host)
|
||||
if cred is not None and cred.enabled:
|
||||
return _build_auth(cred)
|
||||
for sub in queries.list_subdomain_credentials():
|
||||
if sub.enabled and _matches(sub, hop_host):
|
||||
return _build_auth(sub)
|
||||
return None
|
||||
|
||||
|
||||
async def resolve_auth_for_hop(
|
||||
host: str, scheme: str, *, explicit_credential_id: Optional[str] = None
|
||||
) -> Optional[RequestAuth]:
|
||||
"""Resolve the credential (if any) to attach for one request hop."""
|
||||
return await asyncio.to_thread(
|
||||
_resolve_sync, host, scheme, explicit_credential_id
|
||||
)
|
||||
137
app/model_downloader/credentials/store.py
Normal file
137
app/model_downloader/credentials/store.py
Normal file
@ -0,0 +1,137 @@
|
||||
"""The credential store: one API key per host (PRD section 9.4).
|
||||
|
||||
Secrets are write-only over the API — :class:`CredentialView` carries only
|
||||
masked metadata (``secret_last4`` + scheme + label), never the secret itself.
|
||||
At-rest protection for v1 is filesystem permissions on the shared DB (the DB
|
||||
is the trust boundary); encryption-at-rest is a noted future seam.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from app.model_downloader.constants import (
|
||||
AUTH_SCHEME_BEARER,
|
||||
AUTH_SCHEME_HEADER,
|
||||
AUTH_SCHEME_QUERY,
|
||||
AUTH_SCHEMES,
|
||||
)
|
||||
from app.model_downloader.database import queries
|
||||
from app.model_downloader.database.models import HostCredential
|
||||
|
||||
|
||||
def normalize_host(host: str) -> str:
|
||||
"""Lowercase, strip port, IDNA-encode (PRD section 9.4.3)."""
|
||||
if not host:
|
||||
return ""
|
||||
host = host.strip().lower()
|
||||
if host.startswith("[") and "]" in host: # bracketed IPv6 literal
|
||||
host = host[1 : host.index("]")]
|
||||
elif host.count(":") == 1: # host:port (not IPv6)
|
||||
host = host.split(":", 1)[0]
|
||||
try:
|
||||
host = host.encode("idna").decode("ascii")
|
||||
except (UnicodeError, ValueError):
|
||||
pass
|
||||
return host
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CredentialView:
|
||||
"""Masked, API-safe view of a credential — never includes the secret."""
|
||||
|
||||
id: str
|
||||
host: str
|
||||
auth_scheme: str
|
||||
header_name: Optional[str]
|
||||
query_param: Optional[str]
|
||||
label: Optional[str]
|
||||
match_subdomains: bool
|
||||
enabled: bool
|
||||
secret_last4: Optional[str]
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
|
||||
def _to_view(row: HostCredential) -> CredentialView:
|
||||
return CredentialView(
|
||||
id=row.id,
|
||||
host=row.host,
|
||||
auth_scheme=row.auth_scheme,
|
||||
header_name=row.header_name,
|
||||
query_param=row.query_param,
|
||||
label=row.label,
|
||||
match_subdomains=row.match_subdomains,
|
||||
enabled=row.enabled,
|
||||
secret_last4=row.secret_last4,
|
||||
created_at=row.created_at,
|
||||
updated_at=row.updated_at,
|
||||
)
|
||||
|
||||
|
||||
class CredentialValidationError(ValueError):
|
||||
"""A credential upsert had inconsistent fields."""
|
||||
|
||||
|
||||
class CredentialStore:
|
||||
"""Async facade over the ``host_credentials`` table.
|
||||
|
||||
DB access is synchronous (SQLite) and offloaded via ``asyncio.to_thread``.
|
||||
"""
|
||||
|
||||
async def upsert(
|
||||
self,
|
||||
host: str,
|
||||
secret: str,
|
||||
*,
|
||||
auth_scheme: str = AUTH_SCHEME_BEARER,
|
||||
header_name: Optional[str] = None,
|
||||
query_param: Optional[str] = None,
|
||||
label: Optional[str] = None,
|
||||
match_subdomains: bool = False,
|
||||
enabled: bool = True,
|
||||
) -> CredentialView:
|
||||
host = normalize_host(host)
|
||||
if not host:
|
||||
raise CredentialValidationError("host is required")
|
||||
if not secret:
|
||||
raise CredentialValidationError("secret is required")
|
||||
if auth_scheme not in AUTH_SCHEMES:
|
||||
raise CredentialValidationError(
|
||||
f"auth_scheme must be one of {AUTH_SCHEMES}, got {auth_scheme!r}"
|
||||
)
|
||||
if auth_scheme == AUTH_SCHEME_HEADER and not header_name:
|
||||
header_name = "Authorization"
|
||||
if auth_scheme == AUTH_SCHEME_QUERY and not query_param:
|
||||
raise CredentialValidationError(
|
||||
"query_param is required when auth_scheme='query'"
|
||||
)
|
||||
values = {
|
||||
"host": host,
|
||||
"secret": secret,
|
||||
"secret_last4": secret[-4:] if len(secret) >= 4 else secret,
|
||||
"auth_scheme": auth_scheme,
|
||||
"header_name": header_name,
|
||||
"query_param": query_param,
|
||||
"label": label,
|
||||
"match_subdomains": match_subdomains,
|
||||
"enabled": enabled,
|
||||
}
|
||||
row = await asyncio.to_thread(queries.upsert_credential, values)
|
||||
return _to_view(row)
|
||||
|
||||
async def list(self) -> list[CredentialView]:
|
||||
rows = await asyncio.to_thread(queries.list_credentials)
|
||||
return [_to_view(r) for r in rows]
|
||||
|
||||
async def get(self, credential_id: str) -> Optional[CredentialView]:
|
||||
row = await asyncio.to_thread(queries.get_credential, credential_id)
|
||||
return _to_view(row) if row is not None else None
|
||||
|
||||
async def delete(self, credential_id: str) -> bool:
|
||||
return await asyncio.to_thread(queries.delete_credential, credential_id)
|
||||
|
||||
|
||||
CREDENTIAL_STORE = CredentialStore()
|
||||
162
app/model_downloader/database/models.py
Normal file
162
app/model_downloader/database/models.py
Normal file
@ -0,0 +1,162 @@
|
||||
"""SQLAlchemy models for the download manager.
|
||||
|
||||
Three tables (PRD section 7):
|
||||
|
||||
- ``downloads`` one row per requested file (job + queue state).
|
||||
- ``download_segments`` per-segment byte progress, for segmented resume.
|
||||
- ``host_credentials`` one API key per host, reused across downloads.
|
||||
|
||||
The local file catalog / dedup index is NOT here — that is owned by the
|
||||
assets system (``assets`` / ``asset_references``). On completion a finished
|
||||
file is registered into the assets catalog; ``downloads`` is kept only as
|
||||
job history.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Boolean,
|
||||
CheckConstraint,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database.models import Base
|
||||
|
||||
|
||||
def _uuid() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def _now() -> int:
|
||||
return int(time.time())
|
||||
|
||||
|
||||
class Download(Base):
|
||||
__tablename__ = "downloads"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
|
||||
# Original requested URL and the final URL after validated redirects.
|
||||
url: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
final_url: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
# Canonical "<directory>/<filename>" identifier (resolved via folder_paths).
|
||||
model_id: Mapped[str] = mapped_column(String(1024), nullable=False)
|
||||
# Final on-disk location and the .part write target.
|
||||
dest_path: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
temp_path: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
|
||||
status: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
priority: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
total_bytes: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||
bytes_done: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
|
||||
etag: Mapped[str | None] = mapped_column(String(512), nullable=True)
|
||||
last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
accept_ranges: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
# Optional hub-provided checksum to verify against (NOT the dedup key).
|
||||
expected_sha256: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
|
||||
# Explicit credential override; otherwise auto-resolved by host.
|
||||
credential_id: Mapped[str | None] = mapped_column(String(36), nullable=True)
|
||||
allow_any_extension: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
# How many retryable failures we have seen (for backoff capping).
|
||||
attempts: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
error: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[int] = mapped_column(BigInteger, nullable=False, default=_now)
|
||||
updated_at: Mapped[int] = mapped_column(
|
||||
BigInteger, nullable=False, default=_now, onupdate=_now
|
||||
)
|
||||
|
||||
segments: Mapped[list[DownloadSegment]] = relationship(
|
||||
"DownloadSegment",
|
||||
back_populates="download",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
order_by="DownloadSegment.idx",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_downloads_status", "status"),
|
||||
Index("ix_downloads_priority", "priority"),
|
||||
Index("ix_downloads_model_id", "model_id"),
|
||||
CheckConstraint("bytes_done >= 0", name="ck_downloads_bytes_done_nonneg"),
|
||||
CheckConstraint(
|
||||
"total_bytes IS NULL OR total_bytes >= 0",
|
||||
name="ck_downloads_total_bytes_nonneg",
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Download id={self.id} model_id={self.model_id!r} status={self.status}>"
|
||||
|
||||
|
||||
class DownloadSegment(Base):
|
||||
__tablename__ = "download_segments"
|
||||
|
||||
download_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("downloads.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
idx: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
start_offset: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
end_offset: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
bytes_done: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
|
||||
download: Mapped[Download] = relationship("Download", back_populates="segments")
|
||||
|
||||
__table_args__ = (
|
||||
CheckConstraint("bytes_done >= 0", name="ck_segments_bytes_done_nonneg"),
|
||||
CheckConstraint("end_offset >= start_offset", name="ck_segments_range"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<DownloadSegment {self.download_id}#{self.idx} "
|
||||
f"{self.start_offset}-{self.end_offset} done={self.bytes_done}>"
|
||||
)
|
||||
|
||||
|
||||
class HostCredential(Base):
|
||||
__tablename__ = "host_credentials"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
|
||||
# Normalized lowercase hostname, e.g. "civitai.com".
|
||||
host: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
match_subdomains: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
label: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
auth_scheme: Mapped[str] = mapped_column(
|
||||
String(16), nullable=False, default="bearer"
|
||||
)
|
||||
header_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
query_param: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
# The API key itself. Write-only over the API; never returned. See PRD 9.4.4.
|
||||
secret: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
secret_last4: Mapped[str | None] = mapped_column(String(4), nullable=True)
|
||||
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
created_at: Mapped[int] = mapped_column(BigInteger, nullable=False, default=_now)
|
||||
updated_at: Mapped[int] = mapped_column(
|
||||
BigInteger, nullable=False, default=_now, onupdate=_now
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("uq_host_credentials_host", "host", unique=True),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<HostCredential id={self.id} host={self.host!r} scheme={self.auth_scheme}>"
|
||||
235
app/model_downloader/database/queries.py
Normal file
235
app/model_downloader/database/queries.py
Normal file
@ -0,0 +1,235 @@
|
||||
"""Synchronous DB access for the download manager.
|
||||
|
||||
All functions open their own short-lived session via ``create_session`` and
|
||||
commit before returning, mirroring ``app/assets`` usage. They are blocking
|
||||
(SQLite) and should be called from async code through ``asyncio.to_thread``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.database.db import create_session
|
||||
from app.model_downloader.constants import DownloadStatus
|
||||
from app.model_downloader.database.models import (
|
||||
Download,
|
||||
DownloadSegment,
|
||||
HostCredential,
|
||||
)
|
||||
|
||||
|
||||
# ----- downloads -----
|
||||
|
||||
|
||||
def insert_download(values: dict) -> None:
|
||||
with create_session() as session:
|
||||
session.add(Download(**values))
|
||||
session.commit()
|
||||
|
||||
|
||||
def get_download(download_id: str) -> Optional[Download]:
|
||||
with create_session() as session:
|
||||
row = session.get(Download, download_id)
|
||||
if row is not None:
|
||||
session.expunge_all()
|
||||
return row
|
||||
|
||||
|
||||
def list_downloads() -> list[Download]:
|
||||
with create_session() as session:
|
||||
rows = list(
|
||||
session.execute(
|
||||
select(Download).order_by(Download.created_at.desc())
|
||||
).scalars()
|
||||
)
|
||||
session.expunge_all()
|
||||
return rows
|
||||
|
||||
|
||||
def list_segments(download_id: str) -> list[DownloadSegment]:
|
||||
with create_session() as session:
|
||||
rows = list(
|
||||
session.execute(
|
||||
select(DownloadSegment)
|
||||
.where(DownloadSegment.download_id == download_id)
|
||||
.order_by(DownloadSegment.idx)
|
||||
).scalars()
|
||||
)
|
||||
session.expunge_all()
|
||||
return rows
|
||||
|
||||
|
||||
def update_download(download_id: str, **fields) -> None:
|
||||
if not fields:
|
||||
return
|
||||
fields.setdefault("updated_at", int(time.time()))
|
||||
with create_session() as session:
|
||||
row = session.get(Download, download_id)
|
||||
if row is None:
|
||||
return
|
||||
for key, value in fields.items():
|
||||
setattr(row, key, value)
|
||||
session.commit()
|
||||
|
||||
|
||||
def delete_download(download_id: str) -> None:
|
||||
with create_session() as session:
|
||||
row = session.get(Download, download_id)
|
||||
if row is not None:
|
||||
session.delete(row)
|
||||
session.commit()
|
||||
|
||||
|
||||
def replace_segments(download_id: str, segments: list[dict]) -> None:
|
||||
"""Atomically replace the segment plan for a download."""
|
||||
with create_session() as session:
|
||||
session.query(DownloadSegment).filter(
|
||||
DownloadSegment.download_id == download_id
|
||||
).delete()
|
||||
for seg in segments:
|
||||
session.add(DownloadSegment(download_id=download_id, **seg))
|
||||
session.commit()
|
||||
|
||||
|
||||
def update_segment_progress(download_id: str, idx: int, bytes_done: int) -> None:
|
||||
with create_session() as session:
|
||||
row = session.get(DownloadSegment, {"download_id": download_id, "idx": idx})
|
||||
if row is None:
|
||||
return
|
||||
row.bytes_done = bytes_done
|
||||
session.commit()
|
||||
|
||||
|
||||
def list_queued_downloads() -> list[Download]:
|
||||
"""Queued rows ordered for admission (priority desc, then FIFO)."""
|
||||
with create_session() as session:
|
||||
rows = list(
|
||||
session.execute(
|
||||
select(Download)
|
||||
.where(Download.status == DownloadStatus.QUEUED)
|
||||
.order_by(Download.priority.desc(), Download.created_at.asc())
|
||||
).scalars()
|
||||
)
|
||||
session.expunge_all()
|
||||
return rows
|
||||
|
||||
|
||||
def reconcile_live_downloads() -> list[Download]:
|
||||
"""Reset any ``active``/``verifying`` rows left by a previous run.
|
||||
|
||||
On a clean restart there can be no live worker, so anything still marked
|
||||
live is stale. Move it back to ``queued`` (offsets are preserved on the
|
||||
segment rows) so the scheduler re-admits it. Returns the rows that should
|
||||
be re-queued by the scheduler (queued + paused).
|
||||
"""
|
||||
with create_session() as session:
|
||||
stale = list(
|
||||
session.execute(
|
||||
select(Download).where(
|
||||
Download.status.in_([DownloadStatus.ACTIVE, DownloadStatus.VERIFYING])
|
||||
)
|
||||
).scalars()
|
||||
)
|
||||
now = int(time.time())
|
||||
for row in stale:
|
||||
row.status = DownloadStatus.QUEUED
|
||||
row.updated_at = now
|
||||
session.commit()
|
||||
|
||||
resumable = list(
|
||||
session.execute(
|
||||
select(Download)
|
||||
.where(Download.status == DownloadStatus.QUEUED)
|
||||
.order_by(Download.priority.desc(), Download.created_at.asc())
|
||||
).scalars()
|
||||
)
|
||||
session.expunge_all()
|
||||
return resumable
|
||||
|
||||
|
||||
# ----- host credentials -----
|
||||
|
||||
|
||||
def get_credential(credential_id: str) -> Optional[HostCredential]:
|
||||
with create_session() as session:
|
||||
row = session.get(HostCredential, credential_id)
|
||||
if row is not None:
|
||||
session.expunge_all()
|
||||
return row
|
||||
|
||||
|
||||
def get_credential_by_host(host: str) -> Optional[HostCredential]:
|
||||
with create_session() as session:
|
||||
row = (
|
||||
session.execute(
|
||||
select(HostCredential).where(HostCredential.host == host).limit(1)
|
||||
)
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
if row is not None:
|
||||
session.expunge_all()
|
||||
return row
|
||||
|
||||
|
||||
def list_credentials() -> list[HostCredential]:
|
||||
with create_session() as session:
|
||||
rows = list(
|
||||
session.execute(
|
||||
select(HostCredential).order_by(HostCredential.host)
|
||||
).scalars()
|
||||
)
|
||||
session.expunge_all()
|
||||
return rows
|
||||
|
||||
|
||||
def list_subdomain_credentials() -> list[HostCredential]:
|
||||
"""Credentials that opted into subdomain matching, for suffix checks."""
|
||||
with create_session() as session:
|
||||
rows = list(
|
||||
session.execute(
|
||||
select(HostCredential).where(HostCredential.match_subdomains.is_(True))
|
||||
).scalars()
|
||||
)
|
||||
session.expunge_all()
|
||||
return rows
|
||||
|
||||
|
||||
def upsert_credential(values: dict) -> HostCredential:
|
||||
"""Insert or update a credential keyed by ``host``."""
|
||||
host = values["host"]
|
||||
now = int(time.time())
|
||||
with create_session() as session:
|
||||
row = (
|
||||
session.execute(
|
||||
select(HostCredential).where(HostCredential.host == host).limit(1)
|
||||
)
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
if row is None:
|
||||
row = HostCredential(**values)
|
||||
row.created_at = now
|
||||
row.updated_at = now
|
||||
session.add(row)
|
||||
else:
|
||||
for key, value in values.items():
|
||||
setattr(row, key, value)
|
||||
row.updated_at = now
|
||||
session.commit()
|
||||
session.refresh(row)
|
||||
session.expunge(row)
|
||||
return row
|
||||
|
||||
|
||||
def delete_credential(credential_id: str) -> bool:
|
||||
with create_session() as session:
|
||||
row = session.get(HostCredential, credential_id)
|
||||
if row is None:
|
||||
return False
|
||||
session.delete(row)
|
||||
session.commit()
|
||||
return True
|
||||
443
app/model_downloader/engine/job.py
Normal file
443
app/model_downloader/engine/job.py
Normal file
@ -0,0 +1,443 @@
|
||||
"""The per-download worker (PRD sections 5, 6, 8, 12).
|
||||
|
||||
One :class:`DownloadJob` drives a single file from probe to verified, cataloged
|
||||
completion. It supports cooperative pause / resume / cancel, segmented
|
||||
multi-connection transfer with positioned writes, and a verification gate
|
||||
(size + structural + optional sha256) before the atomic rename into place.
|
||||
|
||||
Control is cooperative: external callers flip ``_control`` via
|
||||
:meth:`request_pause` / :meth:`request_cancel`; segment loops observe it between
|
||||
chunks and raise, which unwinds cleanly and persists resume offsets.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Optional
|
||||
|
||||
from comfy.cli_args import args
|
||||
from app.model_downloader.constants import DownloadStatus
|
||||
from app.model_downloader.database import queries
|
||||
from app.model_downloader.engine.planner import (
|
||||
SegmentPlan,
|
||||
effective_segment_count,
|
||||
plan_segments,
|
||||
)
|
||||
from app.model_downloader.engine.writer import FileWriter
|
||||
from app.model_downloader.net.http import open_validated
|
||||
from app.model_downloader.net.probe import probe
|
||||
from app.model_downloader.verify import checksum, dedup, structural
|
||||
|
||||
_RETRYABLE_STATUSES = {408, 429, 500, 502, 503, 504}
|
||||
_PERSIST_INTERVAL = 2.0 # seconds between throttled progress persists
|
||||
|
||||
|
||||
class Paused(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Cancelled(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class RemoteChanged(Exception):
|
||||
"""The remote file changed under a resume (got 200 where 206 expected)."""
|
||||
|
||||
|
||||
class RetryableError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FatalError(Exception):
|
||||
"""Non-retryable: 4xx, checksum mismatch, structural failure, gated, etc."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class SegmentRuntime:
|
||||
idx: int
|
||||
start: int
|
||||
end: int # inclusive; may be -1 for unknown-size single stream
|
||||
bytes_done: int = 0
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
return self.end - self.start + 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeState:
|
||||
download_id: str
|
||||
model_id: str
|
||||
url: str
|
||||
priority: int
|
||||
status: str
|
||||
total_bytes: Optional[int] = None
|
||||
bytes_done: int = 0
|
||||
error: Optional[str] = None
|
||||
segments: list[SegmentRuntime] = field(default_factory=list)
|
||||
started_at: float = field(default_factory=time.monotonic)
|
||||
_last_bytes: int = 0
|
||||
_last_time: float = field(default_factory=time.monotonic)
|
||||
speed_bps: float = 0.0
|
||||
|
||||
@property
|
||||
def progress(self) -> Optional[float]:
|
||||
if not self.total_bytes:
|
||||
return None
|
||||
return min(1.0, self.bytes_done / self.total_bytes)
|
||||
|
||||
@property
|
||||
def eta_seconds(self) -> Optional[float]:
|
||||
if not self.total_bytes or self.speed_bps <= 0:
|
||||
return None
|
||||
remaining = max(0, self.total_bytes - self.bytes_done)
|
||||
return remaining / self.speed_bps
|
||||
|
||||
|
||||
@dataclass
|
||||
class JobSpec:
|
||||
download_id: str
|
||||
url: str
|
||||
model_id: str
|
||||
dest_path: str
|
||||
temp_path: str
|
||||
priority: int = 0
|
||||
credential_id: Optional[str] = None
|
||||
expected_sha256: Optional[str] = None
|
||||
allow_any_extension: bool = False
|
||||
etag: Optional[str] = None
|
||||
attempts: int = 0
|
||||
|
||||
|
||||
class DownloadJob:
|
||||
def __init__(
|
||||
self, spec: JobSpec, notify_cb: Optional[Callable[[str], None]] = None
|
||||
) -> None:
|
||||
self.spec = spec
|
||||
self._notify = notify_cb
|
||||
self._control = "run" # run | pause | cancel
|
||||
self.state = RuntimeState(
|
||||
download_id=spec.download_id,
|
||||
model_id=spec.model_id,
|
||||
url=spec.url,
|
||||
priority=spec.priority,
|
||||
status=DownloadStatus.QUEUED,
|
||||
)
|
||||
self._writer: Optional[FileWriter] = None
|
||||
self._etag: Optional[str] = spec.etag
|
||||
self._last_persist = 0.0
|
||||
|
||||
# ----- external control -----
|
||||
|
||||
def request_pause(self) -> None:
|
||||
if self._control == "run":
|
||||
self._control = "pause"
|
||||
|
||||
def request_cancel(self) -> None:
|
||||
self._control = "cancel"
|
||||
|
||||
def _check_control(self) -> None:
|
||||
if self._control == "cancel":
|
||||
raise Cancelled()
|
||||
if self._control == "pause":
|
||||
raise Paused()
|
||||
|
||||
# ----- lifecycle -----
|
||||
|
||||
async def run(self) -> str:
|
||||
"""Run to a terminal/paused state; returns the final status string."""
|
||||
self._set_status(DownloadStatus.ACTIVE, error=None)
|
||||
try:
|
||||
pr = await self._probe_and_plan()
|
||||
await self._transfer(pr)
|
||||
await self._finalize()
|
||||
self._set_status(DownloadStatus.COMPLETED)
|
||||
except Paused:
|
||||
await self._persist_progress(force=True)
|
||||
self._set_status(DownloadStatus.PAUSED)
|
||||
except Cancelled:
|
||||
await self._close_writer()
|
||||
self._remove_temp()
|
||||
self._set_status(DownloadStatus.CANCELLED)
|
||||
except RemoteChanged:
|
||||
await self._reset_for_restart()
|
||||
self._set_status(
|
||||
DownloadStatus.QUEUED, error="remote file changed; restarting"
|
||||
)
|
||||
except RetryableError as e:
|
||||
await self._persist_progress(force=True)
|
||||
self._set_status(DownloadStatus.QUEUED, error=str(e))
|
||||
except FatalError as e:
|
||||
await self._close_writer()
|
||||
self._remove_temp()
|
||||
self._set_status(DownloadStatus.FAILED, error=str(e))
|
||||
except Exception as e: # unexpected -> treat as retryable
|
||||
logging.warning(
|
||||
"[model_downloader] %s unexpected error: %s",
|
||||
self.spec.model_id, e, exc_info=True,
|
||||
)
|
||||
await self._persist_progress(force=True)
|
||||
self._set_status(DownloadStatus.QUEUED, error=f"{type(e).__name__}: {e}")
|
||||
finally:
|
||||
await self._close_writer()
|
||||
return self.state.status
|
||||
|
||||
# ----- probe + plan -----
|
||||
|
||||
async def _probe_and_plan(self):
|
||||
pr = await probe(self.spec.url, credential_id=self.spec.credential_id)
|
||||
if not pr.ok:
|
||||
if pr.gated:
|
||||
raise FatalError(
|
||||
f"{self.spec.url} requires authentication. Add an API key for "
|
||||
f"this host at /api/download/credentials and retry."
|
||||
)
|
||||
if pr.status == 0 or pr.status in _RETRYABLE_STATUSES:
|
||||
raise RetryableError(pr.error or "probe failed")
|
||||
raise FatalError(pr.error or f"probe returned HTTP {pr.status}")
|
||||
|
||||
self._etag = pr.etag or self._etag
|
||||
self.state.total_bytes = pr.total_bytes
|
||||
queries.update_download(
|
||||
self.spec.download_id,
|
||||
final_url=pr.final_url,
|
||||
total_bytes=pr.total_bytes,
|
||||
accept_ranges=pr.accept_ranges,
|
||||
etag=pr.etag,
|
||||
last_modified=pr.last_modified,
|
||||
)
|
||||
|
||||
seg_count = effective_segment_count(
|
||||
pr.total_bytes, pr.accept_ranges, max(1, args.download_segments)
|
||||
)
|
||||
existing = queries.list_segments(self.spec.download_id)
|
||||
if (
|
||||
seg_count > 1
|
||||
and existing
|
||||
and pr.total_bytes is not None
|
||||
and existing[-1].end_offset == pr.total_bytes - 1
|
||||
):
|
||||
# Resume an existing segmented plan.
|
||||
self.state.segments = [
|
||||
SegmentRuntime(s.idx, s.start_offset, s.end_offset, s.bytes_done)
|
||||
for s in existing
|
||||
]
|
||||
elif seg_count > 1 and pr.total_bytes is not None:
|
||||
plans = plan_segments(pr.total_bytes, seg_count)
|
||||
queries.replace_segments(
|
||||
self.spec.download_id,
|
||||
[
|
||||
{"idx": p.idx, "start_offset": p.start, "end_offset": p.end, "bytes_done": 0}
|
||||
for p in plans
|
||||
],
|
||||
)
|
||||
self.state.segments = [SegmentRuntime(p.idx, p.start, p.end, 0) for p in plans]
|
||||
else:
|
||||
# Single-stream: one logical segment; bytes_done tracked on the row.
|
||||
row = queries.get_download(self.spec.download_id)
|
||||
resume_from = row.bytes_done if row else 0
|
||||
end = (pr.total_bytes - 1) if pr.total_bytes else -1
|
||||
self.state.segments = [SegmentRuntime(0, 0, end, resume_from)]
|
||||
self._recompute_bytes_done()
|
||||
return pr
|
||||
|
||||
# ----- transfer -----
|
||||
|
||||
async def _transfer(self, pr) -> None:
|
||||
self._writer = FileWriter(self.spec.temp_path)
|
||||
await self._writer.open()
|
||||
|
||||
segmented = len(self.state.segments) > 1
|
||||
if segmented and self.state.total_bytes:
|
||||
await self._writer.preallocate(self.state.total_bytes)
|
||||
await self._run_segmented()
|
||||
else:
|
||||
await self._run_single()
|
||||
|
||||
await self._writer.flush()
|
||||
|
||||
async def _run_segmented(self) -> None:
|
||||
pending = [
|
||||
asyncio.ensure_future(self._run_segment(seg))
|
||||
for seg in self.state.segments
|
||||
if seg.bytes_done < seg.length
|
||||
]
|
||||
if not pending:
|
||||
return
|
||||
done, not_done = await asyncio.wait(
|
||||
pending, return_when=asyncio.FIRST_EXCEPTION
|
||||
)
|
||||
first_exc: Optional[BaseException] = None
|
||||
for task in done:
|
||||
exc = task.exception()
|
||||
if exc is not None and first_exc is None:
|
||||
first_exc = exc
|
||||
if first_exc is not None:
|
||||
for task in not_done:
|
||||
task.cancel()
|
||||
await asyncio.gather(*not_done, return_exceptions=True)
|
||||
raise first_exc
|
||||
|
||||
async def _run_segment(self, seg: SegmentRuntime) -> None:
|
||||
offset = seg.start + seg.bytes_done
|
||||
headers = {
|
||||
"Range": f"bytes={offset}-{seg.end}",
|
||||
"Accept-Encoding": "identity",
|
||||
}
|
||||
if self._etag:
|
||||
headers["If-Range"] = self._etag
|
||||
async with open_validated(
|
||||
"GET", self.spec.url, credential_id=self.spec.credential_id, headers=headers
|
||||
) as (resp, _final):
|
||||
if resp.status == 200:
|
||||
# Server ignored the range -> remote changed / no resume support.
|
||||
raise RemoteChanged()
|
||||
if resp.status not in (206,):
|
||||
self._raise_for_status(resp.status)
|
||||
async for chunk in resp.content.iter_chunked(args.download_chunk_size):
|
||||
self._check_control()
|
||||
await self._writer.write_at(offset, chunk)
|
||||
offset += len(chunk)
|
||||
seg.bytes_done += len(chunk)
|
||||
self._recompute_bytes_done()
|
||||
await self._persist_progress()
|
||||
|
||||
async def _run_single(self) -> None:
|
||||
seg = self.state.segments[0]
|
||||
offset = seg.bytes_done # resume from here for single-stream
|
||||
headers = {"Accept-Encoding": "identity"}
|
||||
if offset > 0:
|
||||
headers["Range"] = f"bytes={offset}-"
|
||||
if self._etag:
|
||||
headers["If-Range"] = self._etag
|
||||
async with open_validated(
|
||||
"GET", self.spec.url, credential_id=self.spec.credential_id, headers=headers
|
||||
) as (resp, _final):
|
||||
if offset > 0 and resp.status == 200:
|
||||
# Resume not honoured -> start over from the beginning.
|
||||
offset = 0
|
||||
seg.bytes_done = 0
|
||||
elif offset > 0 and resp.status != 206:
|
||||
self._raise_for_status(resp.status)
|
||||
elif offset == 0 and resp.status != 200:
|
||||
self._raise_for_status(resp.status)
|
||||
async for chunk in resp.content.iter_chunked(args.download_chunk_size):
|
||||
self._check_control()
|
||||
await self._writer.write_at(offset, chunk)
|
||||
offset += len(chunk)
|
||||
seg.bytes_done = offset
|
||||
self.state.bytes_done = offset
|
||||
await self._persist_progress()
|
||||
|
||||
def _raise_for_status(self, status: int) -> None:
|
||||
if status in (401, 403):
|
||||
raise FatalError(
|
||||
f"{self.spec.url} returned {status}; add/update an API key for "
|
||||
f"this host at /api/download/credentials."
|
||||
)
|
||||
if status in _RETRYABLE_STATUSES:
|
||||
raise RetryableError(f"HTTP {status}")
|
||||
raise FatalError(f"unexpected HTTP {status}")
|
||||
|
||||
# ----- finalize / verify (PRD section 8.4) -----
|
||||
|
||||
async def _finalize(self) -> None:
|
||||
self._check_control()
|
||||
await self._close_writer()
|
||||
self._set_status(DownloadStatus.VERIFYING)
|
||||
|
||||
total = self.state.total_bytes
|
||||
actual_size = os.path.getsize(self.spec.temp_path)
|
||||
if total is not None and actual_size != total:
|
||||
raise FatalError(
|
||||
f"size mismatch: wrote {actual_size} of {total} bytes"
|
||||
)
|
||||
|
||||
# Structural gate (cheap, no full read) then optional sha256 (full read).
|
||||
await asyncio.to_thread(structural.validate, self.spec.temp_path)
|
||||
if self.spec.expected_sha256:
|
||||
await asyncio.to_thread(
|
||||
checksum.verify_sha256, self.spec.temp_path, self.spec.expected_sha256
|
||||
)
|
||||
|
||||
os.makedirs(os.path.dirname(self.spec.dest_path), exist_ok=True)
|
||||
os.replace(self.spec.temp_path, self.spec.dest_path)
|
||||
logging.info(
|
||||
"[model_downloader] completed %s (%d bytes)",
|
||||
self.spec.model_id, actual_size,
|
||||
)
|
||||
# Catalog into the assets system (blake3 dedup identity). Best-effort.
|
||||
await dedup.register_completed(self.spec.dest_path)
|
||||
|
||||
# ----- helpers -----
|
||||
|
||||
def _recompute_bytes_done(self) -> None:
|
||||
self.state.bytes_done = sum(s.bytes_done for s in self.state.segments)
|
||||
now = time.monotonic()
|
||||
dt = now - self.state._last_time
|
||||
if dt >= 0.5:
|
||||
self.state.speed_bps = (self.state.bytes_done - self.state._last_bytes) / dt
|
||||
self.state._last_bytes = self.state.bytes_done
|
||||
self.state._last_time = now
|
||||
|
||||
async def _persist_progress(self, force: bool = False) -> None:
|
||||
now = time.monotonic()
|
||||
if not force and now - self._last_persist < _PERSIST_INTERVAL:
|
||||
if self._notify:
|
||||
self._notify(self.spec.download_id)
|
||||
return
|
||||
self._last_persist = now
|
||||
queries.update_download(self.spec.download_id, bytes_done=self.state.bytes_done)
|
||||
for seg in self.state.segments:
|
||||
if seg.end >= seg.start: # skip unknown-size sentinel
|
||||
queries.update_segment_progress(
|
||||
self.spec.download_id, seg.idx, seg.bytes_done
|
||||
)
|
||||
if self._notify:
|
||||
self._notify(self.spec.download_id)
|
||||
|
||||
async def _reset_for_restart(self) -> None:
|
||||
await self._close_writer()
|
||||
self._remove_temp()
|
||||
for seg in self.state.segments:
|
||||
seg.bytes_done = 0
|
||||
self.state.bytes_done = 0
|
||||
queries.update_download(self.spec.download_id, bytes_done=0)
|
||||
if queries.list_segments(self.spec.download_id):
|
||||
queries.replace_segments(self.spec.download_id, [])
|
||||
|
||||
async def _close_writer(self) -> None:
|
||||
if self._writer is not None:
|
||||
try:
|
||||
await self._writer.close()
|
||||
except Exception:
|
||||
logging.debug("[model_downloader] writer close error", exc_info=True)
|
||||
self._writer = None
|
||||
|
||||
def _remove_temp(self) -> None:
|
||||
try:
|
||||
os.remove(self.spec.temp_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except OSError as e:
|
||||
logging.warning(
|
||||
"[model_downloader] could not remove %s: %s", self.spec.temp_path, e
|
||||
)
|
||||
|
||||
def _set_status(self, status: str, error: Optional[str] = None) -> None:
|
||||
self.state.status = status
|
||||
if error is not None:
|
||||
self.state.error = error
|
||||
fields = {"status": status, "bytes_done": self.state.bytes_done}
|
||||
if error is not None:
|
||||
fields["error"] = error
|
||||
if status == DownloadStatus.QUEUED:
|
||||
fields["attempts"] = self.spec.attempts + 1
|
||||
self.spec.attempts += 1
|
||||
queries.update_download(self.spec.download_id, **fields)
|
||||
if self._notify:
|
||||
self._notify(self.spec.download_id)
|
||||
51
app/model_downloader/engine/planner.py
Normal file
51
app/model_downloader/engine/planner.py
Normal file
@ -0,0 +1,51 @@
|
||||
"""Segment planning (PRD section 5.2).
|
||||
|
||||
Split a known byte range into S roughly-equal segments, each fetched by its
|
||||
own coroutine with ``Range: bytes=start-end``. Falls back to a single segment
|
||||
when the server doesn't support ranges or the size is unknown/too small for
|
||||
segmentation to be worthwhile.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
# Below this size, the per-connection setup cost outweighs any parallelism.
|
||||
_MIN_SEGMENT_BYTES = 1 * 1024 * 1024
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SegmentPlan:
|
||||
idx: int
|
||||
start: int
|
||||
end: int # inclusive
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
return self.end - self.start + 1
|
||||
|
||||
|
||||
def effective_segment_count(
|
||||
total_bytes: int | None, accept_ranges: bool, configured: int
|
||||
) -> int:
|
||||
"""How many segments to actually use for this file."""
|
||||
if not accept_ranges or total_bytes is None or total_bytes <= 0:
|
||||
return 1
|
||||
by_size = max(1, total_bytes // _MIN_SEGMENT_BYTES)
|
||||
return max(1, min(configured, by_size))
|
||||
|
||||
|
||||
def plan_segments(total_bytes: int, num_segments: int) -> list[SegmentPlan]:
|
||||
"""Return ``num_segments`` contiguous, inclusive byte ranges covering [0, total)."""
|
||||
if total_bytes <= 0 or num_segments <= 1:
|
||||
return [SegmentPlan(idx=0, start=0, end=max(0, total_bytes - 1))]
|
||||
base = total_bytes // num_segments
|
||||
plans: list[SegmentPlan] = []
|
||||
start = 0
|
||||
for i in range(num_segments):
|
||||
# Last segment soaks up the remainder.
|
||||
length = base if i < num_segments - 1 else total_bytes - start
|
||||
end = start + length - 1
|
||||
plans.append(SegmentPlan(idx=i, start=start, end=end))
|
||||
start = end + 1
|
||||
return plans
|
||||
61
app/model_downloader/engine/writer.py
Normal file
61
app/model_downloader/engine/writer.py
Normal file
@ -0,0 +1,61 @@
|
||||
"""Positioned, off-loop file writes (PRD section 4 + 5.2).
|
||||
|
||||
Network I/O stays on the event loop; every blocking disk op (preallocate,
|
||||
positioned write, fsync) is run in a bounded thread pool via
|
||||
``run_in_executor`` so downloads never stall inference or the web server.
|
||||
|
||||
A single file descriptor is opened for the whole download. Segments write to
|
||||
their own offsets with ``os.pwrite`` — which is offset-addressed and atomic
|
||||
per call, so concurrent segment writers need no extra locking. Per-chunk
|
||||
fsync is avoided; we fsync once at completion.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
|
||||
# One shared, bounded pool for all download disk I/O.
|
||||
_EXECUTOR = ThreadPoolExecutor(max_workers=8, thread_name_prefix="dl-writer")
|
||||
|
||||
|
||||
class FileWriter:
|
||||
"""Owns the ``.part`` file descriptor for one download."""
|
||||
|
||||
def __init__(self, path: str) -> None:
|
||||
self.path = path
|
||||
self._fd: Optional[int] = None
|
||||
|
||||
def _open(self) -> None:
|
||||
os.makedirs(os.path.dirname(self.path), exist_ok=True)
|
||||
self._fd = os.open(self.path, os.O_RDWR | os.O_CREAT, 0o644)
|
||||
|
||||
async def open(self) -> None:
|
||||
await asyncio.get_running_loop().run_in_executor(_EXECUTOR, self._open)
|
||||
|
||||
async def preallocate(self, size: int) -> None:
|
||||
"""Grow the file to ``size`` so segments write to their offsets."""
|
||||
if self._fd is None or size <= 0:
|
||||
return
|
||||
await asyncio.get_running_loop().run_in_executor(
|
||||
_EXECUTOR, os.ftruncate, self._fd, size
|
||||
)
|
||||
|
||||
async def write_at(self, offset: int, data: bytes) -> None:
|
||||
assert self._fd is not None, "writer not opened"
|
||||
await asyncio.get_running_loop().run_in_executor(
|
||||
_EXECUTOR, os.pwrite, self._fd, data, offset
|
||||
)
|
||||
|
||||
async def flush(self) -> None:
|
||||
if self._fd is None:
|
||||
return
|
||||
await asyncio.get_running_loop().run_in_executor(_EXECUTOR, os.fsync, self._fd)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._fd is None:
|
||||
return
|
||||
fd, self._fd = self._fd, None
|
||||
await asyncio.get_running_loop().run_in_executor(_EXECUTOR, os.close, fd)
|
||||
294
app/model_downloader/manager.py
Normal file
294
app/model_downloader/manager.py
Normal file
@ -0,0 +1,294 @@
|
||||
"""Public facade for the download manager (PRD section 10).
|
||||
|
||||
This is the only object the server imports. It validates requests, owns the
|
||||
:class:`Scheduler`, and exposes a small async API plus read models for status.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Callable, Optional
|
||||
|
||||
from app.model_downloader.constants import DownloadStatus
|
||||
from app.model_downloader.database import queries
|
||||
from app.model_downloader.scheduler import SCHEDULER
|
||||
from app.model_downloader.security import paths
|
||||
from app.model_downloader.security.allowlist import is_url_allowed
|
||||
from app.model_downloader.security.paths import InvalidModelId
|
||||
|
||||
# Non-terminal statuses: an existing row in one of these blocks a re-enqueue.
|
||||
_LIVE_STATUSES = (
|
||||
DownloadStatus.QUEUED,
|
||||
DownloadStatus.ACTIVE,
|
||||
DownloadStatus.PAUSED,
|
||||
DownloadStatus.VERIFYING,
|
||||
)
|
||||
|
||||
|
||||
class DownloadError(Exception):
|
||||
"""A user-facing error with a stable machine-readable code."""
|
||||
|
||||
def __init__(self, code: str, message: str, status: int = 400) -> None:
|
||||
super().__init__(message)
|
||||
self.code = code
|
||||
self.message = message
|
||||
self.http_status = status
|
||||
|
||||
|
||||
class DownloadManager:
|
||||
def __init__(self) -> None:
|
||||
self._scheduler = SCHEDULER
|
||||
self._notify_cb: Optional[Callable[[str], None]] = None
|
||||
|
||||
def set_notify(self, cb: Optional[Callable[[str], None]]) -> None:
|
||||
self._notify_cb = cb
|
||||
self._scheduler.set_notify(cb)
|
||||
|
||||
async def start(self) -> None:
|
||||
await self._scheduler.start()
|
||||
|
||||
# ----- enqueue -----
|
||||
|
||||
async def enqueue(
|
||||
self,
|
||||
url: str,
|
||||
model_id: str,
|
||||
*,
|
||||
priority: int = 0,
|
||||
expected_sha256: Optional[str] = None,
|
||||
allow_any_extension: bool = False,
|
||||
credential_id: Optional[str] = None,
|
||||
) -> str:
|
||||
if not is_url_allowed(url, allow_any_extension):
|
||||
raise DownloadError(
|
||||
"URL_NOT_ALLOWED",
|
||||
"URL is not on the download allowlist (host/scheme/extension).",
|
||||
)
|
||||
try:
|
||||
paths.parse_model_id(model_id, allow_any_extension)
|
||||
dest_path, temp_path = paths.resolve_destination(model_id, allow_any_extension)
|
||||
except InvalidModelId as e:
|
||||
raise DownloadError("INVALID_MODEL_ID", str(e))
|
||||
|
||||
if await asyncio.to_thread(
|
||||
paths.resolve_existing, model_id, allow_any_extension
|
||||
):
|
||||
raise DownloadError(
|
||||
"ALREADY_AVAILABLE",
|
||||
f"Model already exists on disk: {model_id}",
|
||||
status=409,
|
||||
)
|
||||
if await self._has_live_download(model_id):
|
||||
raise DownloadError(
|
||||
"ALREADY_DOWNLOADING",
|
||||
f"A download for {model_id} is already in progress.",
|
||||
status=409,
|
||||
)
|
||||
|
||||
download_id = str(uuid.uuid4())
|
||||
await asyncio.to_thread(
|
||||
queries.insert_download,
|
||||
{
|
||||
"id": download_id,
|
||||
"url": url,
|
||||
"model_id": model_id,
|
||||
"dest_path": dest_path,
|
||||
"temp_path": temp_path,
|
||||
"status": DownloadStatus.QUEUED,
|
||||
"priority": priority,
|
||||
"expected_sha256": expected_sha256,
|
||||
"credential_id": credential_id,
|
||||
"allow_any_extension": allow_any_extension,
|
||||
},
|
||||
)
|
||||
logging.info("[model_downloader] enqueued %s -> %s", url, model_id)
|
||||
await self._scheduler.pump()
|
||||
return download_id
|
||||
|
||||
async def _has_live_download(self, model_id: str) -> bool:
|
||||
rows = await asyncio.to_thread(queries.list_downloads)
|
||||
return any(
|
||||
r.model_id == model_id and r.status in _LIVE_STATUSES for r in rows
|
||||
)
|
||||
|
||||
# ----- control -----
|
||||
|
||||
async def pause(self, download_id: str) -> None:
|
||||
job = self._scheduler.get_job(download_id)
|
||||
if job is not None:
|
||||
job.request_pause()
|
||||
return
|
||||
row = await asyncio.to_thread(queries.get_download, download_id)
|
||||
if row is None:
|
||||
raise DownloadError("NOT_FOUND", "No such download.", status=404)
|
||||
if row.status == DownloadStatus.QUEUED:
|
||||
await asyncio.to_thread(
|
||||
queries.update_download, download_id, status=DownloadStatus.PAUSED
|
||||
)
|
||||
|
||||
async def resume(self, download_id: str) -> None:
|
||||
row = await asyncio.to_thread(queries.get_download, download_id)
|
||||
if row is None:
|
||||
raise DownloadError("NOT_FOUND", "No such download.", status=404)
|
||||
if row.status in (DownloadStatus.PAUSED, DownloadStatus.FAILED):
|
||||
await asyncio.to_thread(
|
||||
queries.update_download,
|
||||
download_id,
|
||||
status=DownloadStatus.QUEUED,
|
||||
error=None,
|
||||
)
|
||||
await self._scheduler.pump()
|
||||
|
||||
async def cancel(self, download_id: str) -> None:
|
||||
job = self._scheduler.get_job(download_id)
|
||||
if job is not None:
|
||||
job.request_cancel()
|
||||
return
|
||||
row = await asyncio.to_thread(queries.get_download, download_id)
|
||||
if row is None:
|
||||
raise DownloadError("NOT_FOUND", "No such download.", status=404)
|
||||
if row.status in _LIVE_STATUSES:
|
||||
import os
|
||||
|
||||
try:
|
||||
os.remove(row.temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
await asyncio.to_thread(
|
||||
queries.update_download, download_id, status=DownloadStatus.CANCELLED
|
||||
)
|
||||
|
||||
async def set_priority(self, download_id: str, priority: int) -> None:
|
||||
row = await asyncio.to_thread(queries.get_download, download_id)
|
||||
if row is None:
|
||||
raise DownloadError("NOT_FOUND", "No such download.", status=404)
|
||||
await asyncio.to_thread(
|
||||
queries.update_download, download_id, priority=priority
|
||||
)
|
||||
# Admission-order only (PRD section 13 default); a higher priority is
|
||||
# picked up the next time a slot frees. Pump in case a slot is free now.
|
||||
await self._scheduler.pump()
|
||||
|
||||
# ----- read models -----
|
||||
|
||||
def _view(self, row) -> dict:
|
||||
"""Combine the persisted row with live in-memory progress, if running."""
|
||||
job = self._scheduler.get_job(row.id)
|
||||
bytes_done = row.bytes_done
|
||||
total = row.total_bytes
|
||||
speed = None
|
||||
eta = None
|
||||
segments = None
|
||||
if job is not None:
|
||||
st = job.state
|
||||
bytes_done = st.bytes_done
|
||||
total = st.total_bytes if st.total_bytes is not None else total
|
||||
speed = st.speed_bps
|
||||
eta = st.eta_seconds
|
||||
segments = [
|
||||
{"idx": s.idx, "bytes_done": s.bytes_done, "length": s.length}
|
||||
for s in st.segments
|
||||
if s.end >= s.start
|
||||
]
|
||||
progress = (bytes_done / total) if total else None
|
||||
return {
|
||||
"download_id": row.id,
|
||||
"model_id": row.model_id,
|
||||
"url": row.url,
|
||||
"status": row.status,
|
||||
"priority": row.priority,
|
||||
"total_bytes": total,
|
||||
"bytes_done": bytes_done,
|
||||
"progress": progress,
|
||||
"speed_bps": speed,
|
||||
"eta_seconds": eta,
|
||||
"segments": segments,
|
||||
"error": row.error,
|
||||
"created_at": row.created_at,
|
||||
"updated_at": row.updated_at,
|
||||
}
|
||||
|
||||
def _view_from_state(self, job) -> dict:
|
||||
"""Build a view purely from the live in-memory job state (no DB)."""
|
||||
st = job.state
|
||||
return {
|
||||
"download_id": st.download_id,
|
||||
"model_id": st.model_id,
|
||||
"url": st.url,
|
||||
"status": st.status,
|
||||
"priority": st.priority,
|
||||
"total_bytes": st.total_bytes,
|
||||
"bytes_done": st.bytes_done,
|
||||
"progress": st.progress,
|
||||
"speed_bps": st.speed_bps,
|
||||
"eta_seconds": st.eta_seconds,
|
||||
"segments": [
|
||||
{"idx": s.idx, "bytes_done": s.bytes_done, "length": s.length}
|
||||
for s in st.segments
|
||||
if s.end >= s.start
|
||||
],
|
||||
"error": st.error,
|
||||
}
|
||||
|
||||
def status_sync(self, download_id: str) -> Optional[dict]:
|
||||
"""Synchronous status read for the websocket notify path.
|
||||
|
||||
Uses live in-memory state when the job is running (no DB round-trip on
|
||||
the hot path); falls back to a quick DB read otherwise.
|
||||
"""
|
||||
job = self._scheduler.get_job(download_id)
|
||||
if job is not None:
|
||||
return self._view_from_state(job)
|
||||
row = queries.get_download(download_id)
|
||||
return self._view(row) if row is not None else None
|
||||
|
||||
async def status(self, download_id: str) -> Optional[dict]:
|
||||
row = await asyncio.to_thread(queries.get_download, download_id)
|
||||
return self._view(row) if row is not None else None
|
||||
|
||||
async def list(self) -> list[dict]:
|
||||
rows = await asyncio.to_thread(queries.list_downloads)
|
||||
return [self._view(r) for r in rows]
|
||||
|
||||
async def availability(self, models: dict[str, str]) -> dict[str, dict]:
|
||||
"""Bulk per-id ``{state, progress, ...}`` for the frontend poll.
|
||||
|
||||
``state`` is ``available`` (on disk), ``downloading`` (live row), or
|
||||
``missing``. Cheap: a path lookup plus an in-memory/DB status check.
|
||||
"""
|
||||
rows = await asyncio.to_thread(queries.list_downloads)
|
||||
by_model: dict[str, object] = {}
|
||||
for r in rows:
|
||||
if r.status in _LIVE_STATUSES or r.model_id not in by_model:
|
||||
by_model[r.model_id] = r
|
||||
|
||||
out: dict[str, dict] = {}
|
||||
for model_id, url in models.items():
|
||||
try:
|
||||
exists = await asyncio.to_thread(paths.resolve_existing, model_id)
|
||||
except InvalidModelId:
|
||||
out[model_id] = {"state": "missing", "url_allowed": is_url_allowed(url)}
|
||||
continue
|
||||
if exists:
|
||||
out[model_id] = {"state": "available", "url_allowed": is_url_allowed(url)}
|
||||
continue
|
||||
row = by_model.get(model_id)
|
||||
if row is not None and row.status in _LIVE_STATUSES:
|
||||
view = self._view(row)
|
||||
out[model_id] = {
|
||||
"state": "downloading",
|
||||
"url_allowed": is_url_allowed(url),
|
||||
"download_id": view["download_id"],
|
||||
"progress": view["progress"],
|
||||
"bytes_done": view["bytes_done"],
|
||||
"total_bytes": view["total_bytes"],
|
||||
"speed_bps": view["speed_bps"],
|
||||
}
|
||||
else:
|
||||
out[model_id] = {"state": "missing", "url_allowed": is_url_allowed(url)}
|
||||
return out
|
||||
|
||||
|
||||
DOWNLOAD_MANAGER = DownloadManager()
|
||||
110
app/model_downloader/net/http.py
Normal file
110
app/model_downloader/net/http.py
Normal file
@ -0,0 +1,110 @@
|
||||
"""Manual, validated redirect-following request opener.
|
||||
|
||||
Automatic redirects are disabled (PRD section 9.2): we follow hops ourselves
|
||||
so that on *every* hop we (a) re-validate scheme + reject credentials-in-URL,
|
||||
(b) recompute which stored credential — if any — applies to that hop's host,
|
||||
and (c) let the connector's resolver screen the IP. This is the single place
|
||||
that attaches credentials, so a token can never ride a redirect to a CDN host.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncIterator, Optional
|
||||
from urllib.parse import urljoin, urlsplit, urlunsplit
|
||||
|
||||
import aiohttp
|
||||
|
||||
from app.model_downloader.credentials.resolver import resolve_auth_for_hop
|
||||
from app.model_downloader.net.session import get_session
|
||||
from app.model_downloader.security.ssrf import (
|
||||
MAX_REDIRECTS,
|
||||
SSRFError,
|
||||
check_redirect_hop,
|
||||
)
|
||||
|
||||
_REDIRECT_CODES = {301, 302, 303, 307, 308}
|
||||
DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=None, sock_connect=30, sock_read=120)
|
||||
|
||||
|
||||
def redact_url(url: str) -> str:
|
||||
"""Drop the query string so a query-scheme secret is never logged/stored."""
|
||||
try:
|
||||
parts = urlsplit(url)
|
||||
except ValueError:
|
||||
return "<unparseable-url>"
|
||||
return urlunsplit(parts._replace(query=""))
|
||||
|
||||
|
||||
async def _resolve_final_response(
|
||||
method: str,
|
||||
url: str,
|
||||
credential_id: Optional[str],
|
||||
base_headers: dict[str, str],
|
||||
timeout: aiohttp.ClientTimeout,
|
||||
) -> tuple[aiohttp.ClientResponse, str]:
|
||||
"""Follow redirects manually until a non-redirect response.
|
||||
|
||||
Each intermediate redirect response is released before the next hop.
|
||||
Returns the final ``(response, final_url)``; the caller owns releasing it.
|
||||
"""
|
||||
session = await get_session()
|
||||
current = url
|
||||
hops = 0
|
||||
while True:
|
||||
check_redirect_hop(current)
|
||||
parts = urlsplit(current)
|
||||
auth = await resolve_auth_for_hop(
|
||||
parts.hostname or "", parts.scheme, explicit_credential_id=credential_id
|
||||
)
|
||||
req_headers = dict(base_headers)
|
||||
req_url = current
|
||||
if auth is not None:
|
||||
req_headers.update(auth.headers)
|
||||
req_url = auth.apply_to_url(current)
|
||||
|
||||
resp = await session.request(
|
||||
method,
|
||||
req_url,
|
||||
allow_redirects=False,
|
||||
headers=req_headers,
|
||||
timeout=timeout,
|
||||
)
|
||||
if resp.status in _REDIRECT_CODES and resp.headers.get("Location"):
|
||||
next_url = urljoin(str(resp.url), resp.headers["Location"])
|
||||
await resp.release()
|
||||
hops += 1
|
||||
if hops > MAX_REDIRECTS:
|
||||
raise SSRFError(
|
||||
f"too many redirects (> {MAX_REDIRECTS}) for {redact_url(url)}"
|
||||
)
|
||||
current = next_url
|
||||
continue
|
||||
return resp, redact_url(str(resp.url))
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def open_validated(
|
||||
method: str,
|
||||
url: str,
|
||||
*,
|
||||
credential_id: Optional[str] = None,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
timeout: aiohttp.ClientTimeout = DEFAULT_TIMEOUT,
|
||||
) -> AsyncIterator[tuple[aiohttp.ClientResponse, str]]:
|
||||
"""Open ``method url`` following redirects manually and validated.
|
||||
|
||||
Yields ``(response, final_url)`` where ``final_url`` is redacted of any
|
||||
query string. The response is released automatically on exit.
|
||||
"""
|
||||
resp, final_url = await _resolve_final_response(
|
||||
method, url, credential_id, dict(headers or {}), timeout
|
||||
)
|
||||
try:
|
||||
yield resp, final_url
|
||||
finally:
|
||||
try:
|
||||
await resp.release()
|
||||
except Exception: # pragma: no cover - best-effort cleanup
|
||||
logging.debug("[model_downloader] response release error", exc_info=True)
|
||||
90
app/model_downloader/net/probe.py
Normal file
90
app/model_downloader/net/probe.py
Normal file
@ -0,0 +1,90 @@
|
||||
"""Pre-download probe (PRD section 5.1).
|
||||
|
||||
Issues a tiny ranged GET (``Range: bytes=0-0``) — which doubles as a
|
||||
range-support test — to discover ``Content-Length``, ``Accept-Ranges``,
|
||||
``ETag``/``Last-Modified``, and the final post-redirect URL. For HuggingFace
|
||||
LFS files the true size also appears in the non-standard ``X-Linked-Size``
|
||||
header, which we read as a fallback.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from app.model_downloader.net.http import open_validated
|
||||
from app.model_downloader.net.session import parse_int_header
|
||||
|
||||
_PROBE_TIMEOUT = aiohttp.ClientTimeout(total=60, sock_connect=30, sock_read=30)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProbeResult:
|
||||
ok: bool
|
||||
status: int
|
||||
final_url: Optional[str] = None
|
||||
total_bytes: Optional[int] = None
|
||||
accept_ranges: bool = False
|
||||
etag: Optional[str] = None
|
||||
last_modified: Optional[str] = None
|
||||
gated: bool = False # 401/403 — needs (or has wrong) credentials
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
def _total_from_content_range(value: Optional[str]) -> Optional[int]:
|
||||
# "bytes 0-0/12345" -> 12345 ; "bytes 0-0/*" -> None
|
||||
if not value or "/" not in value:
|
||||
return None
|
||||
total = value.rsplit("/", 1)[1].strip()
|
||||
return parse_int_header(total)
|
||||
|
||||
|
||||
async def probe(url: str, *, credential_id: Optional[str] = None) -> ProbeResult:
|
||||
"""Probe ``url`` and return discovered metadata, failing soft."""
|
||||
try:
|
||||
async with open_validated(
|
||||
"GET",
|
||||
url,
|
||||
credential_id=credential_id,
|
||||
headers={"Range": "bytes=0-0", "Accept-Encoding": "identity"},
|
||||
timeout=_PROBE_TIMEOUT,
|
||||
) as (resp, final_url):
|
||||
if resp.status in (401, 403):
|
||||
return ProbeResult(
|
||||
ok=False, status=resp.status, final_url=final_url, gated=True,
|
||||
error=f"host returned {resp.status} (authentication required)",
|
||||
)
|
||||
if resp.status not in (200, 206):
|
||||
return ProbeResult(
|
||||
ok=False, status=resp.status, final_url=final_url,
|
||||
error=f"probe returned HTTP {resp.status}",
|
||||
)
|
||||
|
||||
headers = resp.headers
|
||||
accept_ranges = False
|
||||
total: Optional[int] = None
|
||||
if resp.status == 206:
|
||||
accept_ranges = True
|
||||
total = _total_from_content_range(headers.get("Content-Range"))
|
||||
else: # 200: server ignored the range
|
||||
accept_ranges = headers.get("Accept-Ranges", "").lower() == "bytes"
|
||||
total = parse_int_header(headers.get("Content-Length"))
|
||||
|
||||
if total is None:
|
||||
total = parse_int_header(headers.get("X-Linked-Size"))
|
||||
|
||||
return ProbeResult(
|
||||
ok=True,
|
||||
status=resp.status,
|
||||
final_url=final_url,
|
||||
total_bytes=total,
|
||||
accept_ranges=accept_ranges,
|
||||
etag=headers.get("ETag"),
|
||||
last_modified=headers.get("Last-Modified"),
|
||||
)
|
||||
except Exception as e: # network / SSRF / timeout
|
||||
logging.debug("[model_downloader] probe failed for %s: %s", url, e)
|
||||
return ProbeResult(ok=False, status=0, error=f"{type(e).__name__}: {e}")
|
||||
72
app/model_downloader/net/session.py
Normal file
72
app/model_downloader/net/session.py
Normal file
@ -0,0 +1,72 @@
|
||||
"""Lazily-created shared :class:`aiohttp.ClientSession`.
|
||||
|
||||
A single session reuses TLS handshakes and TCP connections across the probe
|
||||
and the many segment GETs to the same host (HuggingFace is the dominant
|
||||
case), which is a large speedup on cold connections and exactly the
|
||||
connection-reuse strategy that lets us match aria2c (PRD section 5.2).
|
||||
|
||||
The connector uses :class:`ValidatingResolver` so every connection — initial
|
||||
or post-redirect — is screened for private/special-use IPs at connect time.
|
||||
TLS is pinned to certifi's CA bundle because the OS trust store is not wired
|
||||
up on some Python installs (python.org macOS, slim containers).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import ssl
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
try:
|
||||
import certifi
|
||||
_CA_FILE = certifi.where()
|
||||
except Exception: # pragma: no cover - certifi is a transitive dep of aiohttp
|
||||
_CA_FILE = None
|
||||
|
||||
from comfy.cli_args import args
|
||||
from app.model_downloader.security.ssrf import ValidatingResolver
|
||||
|
||||
_session: Optional[aiohttp.ClientSession] = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def ssl_context() -> ssl.SSLContext:
|
||||
if _CA_FILE is not None:
|
||||
return ssl.create_default_context(cafile=_CA_FILE)
|
||||
return ssl.create_default_context()
|
||||
|
||||
|
||||
async def get_session() -> aiohttp.ClientSession:
|
||||
"""Return the shared session, creating it on first use."""
|
||||
global _session
|
||||
if _session is not None and not _session.closed:
|
||||
return _session
|
||||
async with _lock:
|
||||
if _session is None or _session.closed:
|
||||
connector = aiohttp.TCPConnector(
|
||||
limit_per_host=max(1, getattr(args, "download_max_connections_per_host", 16)),
|
||||
ssl=ssl_context(),
|
||||
resolver=ValidatingResolver(),
|
||||
)
|
||||
_session = aiohttp.ClientSession(connector=connector)
|
||||
return _session
|
||||
|
||||
|
||||
async def close_session() -> None:
|
||||
global _session
|
||||
if _session is not None and not _session.closed:
|
||||
await _session.close()
|
||||
_session = None
|
||||
|
||||
|
||||
def parse_int_header(value: Optional[str]) -> Optional[int]:
|
||||
"""Parse a non-negative integer header value, or None if bad/absent."""
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
n = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
return n if n >= 0 else None
|
||||
160
app/model_downloader/scheduler.py
Normal file
160
app/model_downloader/scheduler.py
Normal file
@ -0,0 +1,160 @@
|
||||
"""Priority scheduler + lifecycle (PRD sections 4, 6, 12).
|
||||
|
||||
Owns the set of running jobs and admits queued downloads up to a global
|
||||
concurrency limit (K), highest priority first, FIFO within a priority. Runs
|
||||
entirely on the existing ComfyUI asyncio loop; blocking work (disk, hashing,
|
||||
DB) is offloaded by the job/writer layers.
|
||||
|
||||
On startup it reconciles DB vs. disk: ``active``/``verifying`` rows left by a
|
||||
previous run are reset to ``queued`` and resumed from persisted offsets, and
|
||||
orphaned ``.part`` files with no live download row are swept.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
|
||||
from comfy.cli_args import args
|
||||
from app.model_downloader.constants import DownloadStatus
|
||||
from app.model_downloader.database import queries
|
||||
from app.model_downloader.engine.job import DownloadJob, JobSpec
|
||||
from app.model_downloader.security import paths
|
||||
|
||||
# Backoff for retryable failures (PRD section 12).
|
||||
_BACKOFF_BASE = 2.0
|
||||
_BACKOFF_CAP = 300.0
|
||||
_MAX_ATTEMPTS = 6
|
||||
|
||||
|
||||
class Scheduler:
|
||||
def __init__(self) -> None:
|
||||
self._jobs: dict[str, DownloadJob] = {}
|
||||
self._tasks: dict[str, asyncio.Task] = {}
|
||||
self._backoff_until: dict[str, float] = {}
|
||||
self._pump_lock = asyncio.Lock()
|
||||
self._notify_cb: Optional[Callable[[str], None]] = None
|
||||
self._started = False
|
||||
|
||||
@property
|
||||
def max_active(self) -> int:
|
||||
return max(1, getattr(args, "download_max_active", 3))
|
||||
|
||||
def set_notify(self, cb: Optional[Callable[[str], None]]) -> None:
|
||||
self._notify_cb = cb
|
||||
|
||||
def get_job(self, download_id: str) -> Optional[DownloadJob]:
|
||||
return self._jobs.get(download_id)
|
||||
|
||||
def is_active(self, download_id: str) -> bool:
|
||||
return download_id in self._tasks
|
||||
|
||||
# ----- startup -----
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._started:
|
||||
return
|
||||
self._started = True
|
||||
try:
|
||||
await asyncio.to_thread(queries.reconcile_live_downloads)
|
||||
await asyncio.to_thread(self._sweep_orphan_temp_files)
|
||||
except Exception as e:
|
||||
logging.warning("[model_downloader] startup reconcile failed: %s", e)
|
||||
await self.pump()
|
||||
|
||||
@staticmethod
|
||||
def _sweep_orphan_temp_files() -> None:
|
||||
"""Remove ``.part`` files not referenced by a resumable download row.
|
||||
|
||||
Resumable partials (queued/paused rows) are preserved; only truly
|
||||
orphaned temp files from crashed runs are deleted.
|
||||
"""
|
||||
live = {
|
||||
row.temp_path
|
||||
for row in queries.list_downloads()
|
||||
if row.status in (DownloadStatus.QUEUED, DownloadStatus.PAUSED)
|
||||
}
|
||||
for path in paths.iter_all_tmp_paths():
|
||||
if path in live:
|
||||
continue
|
||||
try:
|
||||
os.remove(path)
|
||||
logging.info("[model_downloader] removed orphan temp file: %s", path)
|
||||
except OSError as e:
|
||||
logging.warning("[model_downloader] could not remove %s: %s", path, e)
|
||||
|
||||
# ----- admission -----
|
||||
|
||||
async def pump(self) -> None:
|
||||
async with self._pump_lock:
|
||||
slots = self.max_active - len(self._tasks)
|
||||
if slots <= 0:
|
||||
return
|
||||
now = time.monotonic()
|
||||
candidates = await asyncio.to_thread(queries.list_queued_downloads)
|
||||
for row in candidates:
|
||||
if slots <= 0:
|
||||
break
|
||||
if row.id in self._tasks:
|
||||
continue
|
||||
if self._backoff_until.get(row.id, 0.0) > now:
|
||||
continue
|
||||
self._admit(row)
|
||||
slots -= 1
|
||||
|
||||
def _admit(self, row) -> None:
|
||||
spec = JobSpec(
|
||||
download_id=row.id,
|
||||
url=row.url,
|
||||
model_id=row.model_id,
|
||||
dest_path=row.dest_path,
|
||||
temp_path=row.temp_path,
|
||||
priority=row.priority,
|
||||
credential_id=row.credential_id,
|
||||
expected_sha256=row.expected_sha256,
|
||||
allow_any_extension=row.allow_any_extension,
|
||||
etag=row.etag,
|
||||
attempts=row.attempts,
|
||||
)
|
||||
job = DownloadJob(spec, notify_cb=self._notify_cb)
|
||||
self._jobs[row.id] = job
|
||||
self._tasks[row.id] = asyncio.ensure_future(self._run_job(job))
|
||||
|
||||
async def _run_job(self, job: DownloadJob) -> None:
|
||||
download_id = job.spec.download_id
|
||||
status = DownloadStatus.FAILED
|
||||
try:
|
||||
status = await job.run()
|
||||
except Exception as e: # run() is defensive, but never let a task die silently
|
||||
logging.error("[model_downloader] job %s crashed: %s", download_id, e)
|
||||
finally:
|
||||
self._tasks.pop(download_id, None)
|
||||
self._jobs.pop(download_id, None)
|
||||
|
||||
if status == DownloadStatus.QUEUED:
|
||||
if job.spec.attempts >= _MAX_ATTEMPTS:
|
||||
queries.update_download(
|
||||
download_id,
|
||||
status=DownloadStatus.FAILED,
|
||||
error=f"giving up after {job.spec.attempts} attempts",
|
||||
)
|
||||
if self._notify_cb:
|
||||
self._notify_cb(download_id)
|
||||
else:
|
||||
delay = min(
|
||||
_BACKOFF_CAP, _BACKOFF_BASE ** job.spec.attempts
|
||||
) + random.uniform(0, 1.0)
|
||||
self._backoff_until[download_id] = time.monotonic() + delay
|
||||
asyncio.ensure_future(self._delayed_pump(delay))
|
||||
await self.pump()
|
||||
|
||||
async def _delayed_pump(self, delay: float) -> None:
|
||||
await asyncio.sleep(delay)
|
||||
await self.pump()
|
||||
|
||||
|
||||
SCHEDULER = Scheduler()
|
||||
84
app/model_downloader/security/allowlist.py
Normal file
84
app/model_downloader/security/allowlist.py
Normal file
@ -0,0 +1,84 @@
|
||||
"""URL allowlist for server-side model fetches (PRD section 9.1).
|
||||
|
||||
Default-deny. A URL is downloadable only when its parsed host + scheme are
|
||||
allowlisted AND (unless explicitly relaxed) its final filename ends in a
|
||||
known model extension.
|
||||
|
||||
The built-in host defaults mirror the frontend's ``isModelDownloadable``
|
||||
allowlist so the two flows agree on what is eligible; ``--download-allowed-hosts``
|
||||
extends it for self-hosted mirrors. Matching is done on ``urlparse().hostname``
|
||||
(never a raw string prefix) so userinfo tricks like
|
||||
``http://127.0.0.1@169.254.169.254/x.safetensors`` — whose real host is the
|
||||
metadata IP — cannot slip past.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
# host -> set of allowed schemes. Frontend parity (HuggingFace / Civitai /
|
||||
# localhost). Extra hosts from --download-allowed-hosts are https-only.
|
||||
_DEFAULT_ALLOWED_HOSTS: dict[str, set[str]] = {
|
||||
"huggingface.co": {"https"},
|
||||
"civitai.com": {"https"},
|
||||
"localhost": {"http", "https"},
|
||||
"127.0.0.1": {"http", "https"},
|
||||
}
|
||||
|
||||
# Hosts for which loopback addresses are intentionally permitted (the localhost
|
||||
# "download a local model" feature). Every other host's loopback resolution is
|
||||
# rejected by the SSRF resolver.
|
||||
LOOPBACK_HOSTS = frozenset({"localhost", "127.0.0.1", "::1"})
|
||||
|
||||
# Known model file extensions (frontend parity). Checked on the final filename.
|
||||
ALLOWED_MODEL_EXTENSIONS = (
|
||||
".safetensors",
|
||||
".sft",
|
||||
".ckpt",
|
||||
".pth",
|
||||
".pt",
|
||||
".gguf",
|
||||
".bin",
|
||||
)
|
||||
|
||||
|
||||
def _allowed_hosts() -> dict[str, set[str]]:
|
||||
hosts = {h: set(s) for h, s in _DEFAULT_ALLOWED_HOSTS.items()}
|
||||
for extra in getattr(args, "download_allowed_hosts", []) or []:
|
||||
host = extra.strip().lower()
|
||||
if host:
|
||||
hosts.setdefault(host, set()).add("https")
|
||||
return hosts
|
||||
|
||||
|
||||
def is_host_allowed(host: str | None, scheme: str | None) -> bool:
|
||||
"""True iff ``host`` is allowlisted for ``scheme``.
|
||||
|
||||
Used both for the initial URL and re-checked on every redirect hop
|
||||
(PRD section 9.2), so a whitelisted URL cannot 30x into an off-list host.
|
||||
"""
|
||||
if not host or not scheme:
|
||||
return False
|
||||
allowed = _allowed_hosts().get(host.lower())
|
||||
return allowed is not None and scheme.lower() in allowed
|
||||
|
||||
|
||||
def has_allowed_extension(path: str, allow_any_extension: bool = False) -> bool:
|
||||
if allow_any_extension:
|
||||
return True
|
||||
return path.lower().endswith(ALLOWED_MODEL_EXTENSIONS)
|
||||
|
||||
|
||||
def is_url_allowed(url: str, allow_any_extension: bool = False) -> bool:
|
||||
"""Check whether ``url`` is permitted as a server-side download source."""
|
||||
if not isinstance(url, str) or not url:
|
||||
return False
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
except ValueError:
|
||||
return False
|
||||
if not is_host_allowed(parsed.hostname, parsed.scheme):
|
||||
return False
|
||||
return has_allowed_extension(parsed.path, allow_any_extension)
|
||||
110
app/model_downloader/security/paths.py
Normal file
110
app/model_downloader/security/paths.py
Normal file
@ -0,0 +1,110 @@
|
||||
"""Path resolution + traversal safety for downloads (PRD section 9.3).
|
||||
|
||||
A ``model_id`` is a *relative destination path* of the form
|
||||
``<directory>/<filename>`` (e.g. ``loras/my_lora.safetensors``). This module
|
||||
turns one into an absolute on-disk path under one of ComfyUI's registered
|
||||
model folders, rejecting unknown folders, path traversal, and symlink escape.
|
||||
This is the only thing that composes destination paths, so the engine never
|
||||
touches user-supplied path strings directly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Iterator, Optional
|
||||
|
||||
import folder_paths
|
||||
|
||||
from app.model_downloader.constants import TMP_SUFFIX
|
||||
from app.model_downloader.security.allowlist import ALLOWED_MODEL_EXTENSIONS
|
||||
|
||||
# A model_id component is a single path segment of safe characters — no slashes,
|
||||
# no "..", no leading dots that could escape the target directory.
|
||||
_SEGMENT_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$")
|
||||
|
||||
|
||||
class InvalidModelId(ValueError):
|
||||
"""Raised when a model_id is malformed or names an unknown model folder."""
|
||||
|
||||
|
||||
def parse_model_id(model_id: str, allow_any_extension: bool = False) -> tuple[str, str]:
|
||||
"""Split ``<directory>/<filename>`` and validate both components.
|
||||
|
||||
Returns ``(directory, filename)``. Does not touch the filesystem.
|
||||
"""
|
||||
if not isinstance(model_id, str) or "/" not in model_id:
|
||||
raise InvalidModelId(
|
||||
f"model_id must be '<directory>/<filename>', got {model_id!r}"
|
||||
)
|
||||
directory, _, filename = model_id.partition("/")
|
||||
if "/" in filename or not directory or not filename:
|
||||
raise InvalidModelId(
|
||||
f"model_id must have exactly one '/' separator, got {model_id!r}"
|
||||
)
|
||||
if not _SEGMENT_RE.match(directory):
|
||||
raise InvalidModelId(f"invalid directory segment {directory!r}")
|
||||
if not _SEGMENT_RE.match(filename):
|
||||
raise InvalidModelId(f"invalid filename segment {filename!r}")
|
||||
if not allow_any_extension and not filename.lower().endswith(
|
||||
ALLOWED_MODEL_EXTENSIONS
|
||||
):
|
||||
raise InvalidModelId(
|
||||
f"filename must end with a known model extension "
|
||||
f"{ALLOWED_MODEL_EXTENSIONS}, got {filename!r}"
|
||||
)
|
||||
if directory not in folder_paths.folder_names_and_paths:
|
||||
raise InvalidModelId(f"unknown model folder {directory!r}")
|
||||
return directory, filename
|
||||
|
||||
|
||||
def resolve_existing(model_id: str, allow_any_extension: bool = False) -> Optional[str]:
|
||||
"""Return the absolute path of an installed model, or None if missing.
|
||||
|
||||
Honours ``extra_model_paths.yaml`` transparently via ``get_full_path``.
|
||||
"""
|
||||
directory, filename = parse_model_id(model_id, allow_any_extension)
|
||||
return folder_paths.get_full_path(directory, filename)
|
||||
|
||||
|
||||
def resolve_destination(
|
||||
model_id: str, allow_any_extension: bool = False
|
||||
) -> tuple[str, str]:
|
||||
"""Return ``(final_path, temp_path)`` for a download.
|
||||
|
||||
Downloads land at the first registered path for the model's directory
|
||||
(the "primary" location). ``temp_path`` is a sibling ``.part`` file that
|
||||
is atomically renamed onto ``final_path`` on success. The result is
|
||||
asserted to stay within the registered root (defence in depth on top of
|
||||
the segment regex).
|
||||
"""
|
||||
directory, filename = parse_model_id(model_id, allow_any_extension)
|
||||
roots = folder_paths.get_folder_paths(directory)
|
||||
if not roots:
|
||||
raise InvalidModelId(f"no on-disk path registered for folder {directory!r}")
|
||||
root = os.path.realpath(roots[0])
|
||||
final_path = os.path.realpath(os.path.join(root, filename))
|
||||
if final_path != root and not final_path.startswith(root + os.sep):
|
||||
raise InvalidModelId(f"resolved path escapes model root: {model_id!r}")
|
||||
temp_path = f"{final_path}{TMP_SUFFIX}"
|
||||
return final_path, temp_path
|
||||
|
||||
|
||||
def iter_all_tmp_paths() -> Iterator[str]:
|
||||
"""Yield this subsystem's temp files under every registered model folder.
|
||||
|
||||
Matches only the distinctive ``TMP_SUFFIX`` so the startup orphan sweep
|
||||
can never delete temp files created by other tools.
|
||||
"""
|
||||
seen_roots: set[str] = set()
|
||||
for directory in list(folder_paths.folder_names_and_paths.keys()):
|
||||
for root in folder_paths.get_folder_paths(directory):
|
||||
if root in seen_roots or not os.path.isdir(root):
|
||||
continue
|
||||
seen_roots.add(root)
|
||||
try:
|
||||
for entry in os.scandir(root):
|
||||
if entry.is_file() and entry.name.endswith(TMP_SUFFIX):
|
||||
yield entry.path
|
||||
except OSError:
|
||||
continue
|
||||
111
app/model_downloader/security/ssrf.py
Normal file
111
app/model_downloader/security/ssrf.py
Normal file
@ -0,0 +1,111 @@
|
||||
"""SSRF / exfiltration defenses (PRD section 9.2).
|
||||
|
||||
Two cooperating layers:
|
||||
|
||||
1. :class:`ValidatingResolver` is installed on the shared connector. Every
|
||||
connection — the initial probe and every segment GET, including ones made
|
||||
after a redirect — resolves its host through this resolver, which rejects
|
||||
any address that lands on a private / special-use IP range. Because the
|
||||
resolve and the connect happen together inside the connector, there is no
|
||||
check-then-connect window for DNS rebinding to exploit.
|
||||
|
||||
2. :func:`check_redirect_hop` re-validates every redirect hop. The host
|
||||
allowlist gates only the *initial* user-supplied URL (anti-SSRF for
|
||||
arbitrary input); legitimate downloads from allowlisted origins redirect
|
||||
to presigned CDN hosts that are deliberately NOT on the allowlist (HF ->
|
||||
``cdn-lfs*.huggingface.co``, Civitai -> signed Cloudflare/S3), so hops are
|
||||
instead screened for scheme, embedded credentials, and — via the resolver
|
||||
above — private IPs. Credentials are only ever attached when a hop's host
|
||||
exactly matches a stored credential, so they are dropped on the CDN hop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from aiohttp.abc import AbstractResolver
|
||||
from aiohttp.resolver import DefaultResolver
|
||||
|
||||
from app.model_downloader.security.allowlist import LOOPBACK_HOSTS
|
||||
|
||||
# Cap the redirect chain length and the schemes a hop may use.
|
||||
MAX_REDIRECTS = 5
|
||||
ALLOWED_SCHEMES = ("https", "http")
|
||||
|
||||
|
||||
class SSRFError(Exception):
|
||||
"""A hop failed an SSRF / allowlist check."""
|
||||
|
||||
|
||||
def is_blocked_ip(ip_str: str) -> bool:
|
||||
"""True for any address we refuse to connect to.
|
||||
|
||||
Covers loopback, link-local (incl. 169.254.169.254 cloud metadata),
|
||||
RFC1918 private ranges, unique-local (ULA), unspecified (0.0.0.0/::),
|
||||
multicast and other reserved ranges.
|
||||
"""
|
||||
try:
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
except ValueError:
|
||||
return True # unparseable -> refuse
|
||||
return (
|
||||
ip.is_private
|
||||
or ip.is_loopback
|
||||
or ip.is_link_local
|
||||
or ip.is_multicast
|
||||
or ip.is_reserved
|
||||
or ip.is_unspecified
|
||||
)
|
||||
|
||||
|
||||
class ValidatingResolver(AbstractResolver):
|
||||
"""Delegating resolver that drops blocked IPs from every resolution.
|
||||
|
||||
If a hostname resolves only to blocked addresses, the connection fails
|
||||
closed with an :class:`OSError`, which aiohttp surfaces as a connection
|
||||
error to the caller.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._inner = DefaultResolver()
|
||||
|
||||
async def resolve(self, host, port=0, family=socket.AF_INET):
|
||||
infos = await self._inner.resolve(host, port, family)
|
||||
# localhost/127.0.0.1 are an explicit, opt-in allowlist feature.
|
||||
if isinstance(host, str) and host.lower() in LOOPBACK_HOSTS:
|
||||
return infos
|
||||
safe = [info for info in infos if not is_blocked_ip(info["host"])]
|
||||
if not safe:
|
||||
raise OSError(
|
||||
f"refusing to connect to {host!r}: resolves only to "
|
||||
f"private/special-use addresses"
|
||||
)
|
||||
return safe
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._inner.close()
|
||||
|
||||
|
||||
def check_redirect_hop(url: str) -> str:
|
||||
"""Validate one redirect hop's URL.
|
||||
|
||||
Returns the URL unchanged on success; raises :class:`SSRFError` otherwise.
|
||||
Enforces an allowed scheme and forbids credentials-in-URL. The host is NOT
|
||||
re-checked against the allowlist (CDN redirect targets are off-list by
|
||||
design); private-IP protection is provided by the connector's resolver,
|
||||
and credential leakage is prevented by exact host matching at attach time.
|
||||
The landing filename's extension is gated separately by the caller.
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
except ValueError as e:
|
||||
raise SSRFError(f"unparseable redirect URL {url!r}: {e}") from e
|
||||
if parsed.scheme.lower() not in ALLOWED_SCHEMES:
|
||||
raise SSRFError(f"redirect to disallowed scheme {parsed.scheme!r}")
|
||||
if parsed.username or parsed.password:
|
||||
raise SSRFError("credentials-in-URL are not allowed")
|
||||
if not parsed.hostname:
|
||||
raise SSRFError(f"redirect URL has no host: {url!r}")
|
||||
return url
|
||||
49
app/model_downloader/verify/checksum.py
Normal file
49
app/model_downloader/verify/checksum.py
Normal file
@ -0,0 +1,49 @@
|
||||
"""Hub-checksum verification = SHA256 (PRD section 8.1).
|
||||
|
||||
Only used to confirm a download matches a *provided* ``expected_sha256``. It
|
||||
is NOT the dedup key (that is blake3, owned by the assets system). The full
|
||||
sequential read happens at most once, here, only when a checksum was supplied.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from typing import Callable, Optional
|
||||
|
||||
_CHUNK = 8 * 1024 * 1024
|
||||
|
||||
InterruptCheck = Callable[[], bool]
|
||||
|
||||
|
||||
class ChecksumError(Exception):
|
||||
"""The computed SHA256 did not match the expected value."""
|
||||
|
||||
|
||||
def sha256_file(path: str, interrupt_check: Optional[InterruptCheck] = None) -> Optional[str]:
|
||||
"""Stream the file and return its lowercase hex SHA256.
|
||||
|
||||
Returns ``None`` if interrupted via ``interrupt_check``.
|
||||
"""
|
||||
h = hashlib.sha256()
|
||||
with open(path, "rb") as f:
|
||||
while True:
|
||||
if interrupt_check is not None and interrupt_check():
|
||||
return None
|
||||
chunk = f.read(_CHUNK)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
def verify_sha256(
|
||||
path: str, expected: str, interrupt_check: Optional[InterruptCheck] = None
|
||||
) -> None:
|
||||
"""Raise :class:`ChecksumError` unless the file's SHA256 matches ``expected``."""
|
||||
actual = sha256_file(path, interrupt_check)
|
||||
if actual is None:
|
||||
return # interrupted; caller will re-verify on resume
|
||||
if actual.lower() != expected.lower():
|
||||
raise ChecksumError(
|
||||
f"sha256 mismatch: expected {expected.lower()}, got {actual.lower()}"
|
||||
)
|
||||
53
app/model_downloader/verify/dedup.py
Normal file
53
app/model_downloader/verify/dedup.py
Normal file
@ -0,0 +1,53 @@
|
||||
"""Dedup + catalog handoff — reuse the assets system (PRD section 8.5).
|
||||
|
||||
We do NOT build a parallel indexer. "Do I already have it?" is answered by
|
||||
``resolve_existing`` (path) at enqueue time and, where a hash is known, by the
|
||||
assets blake3 catalog. After a completed download we register the file
|
||||
through the assets ingest path so it is cataloged and (eventually) hashed by
|
||||
the existing enrichment worker.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def _register_sync(abs_path: str) -> Optional[str]:
|
||||
"""Register a finished file into the assets catalog. Returns asset hash."""
|
||||
try:
|
||||
from app.assets.services.ingest import register_file_in_place
|
||||
except Exception as e: # assets package import failure — non-fatal
|
||||
logging.debug("[model_downloader] assets ingest unavailable: %s", e)
|
||||
return None
|
||||
try:
|
||||
result = register_file_in_place(abs_path, name=os.path.basename(abs_path), tags=[])
|
||||
return result.asset.hash if result and result.asset else None
|
||||
except Exception as e:
|
||||
# The file is already safely on disk; cataloging is best-effort.
|
||||
logging.warning(
|
||||
"[model_downloader] could not register %s into assets catalog: %s",
|
||||
abs_path, e,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def register_completed(abs_path: str) -> Optional[str]:
|
||||
"""Catalog a completed download via the assets system (off the event loop)."""
|
||||
return await asyncio.to_thread(_register_sync, abs_path)
|
||||
|
||||
|
||||
def _find_by_hash_sync(blake3_hex: str) -> Optional[str]:
|
||||
try:
|
||||
from app.assets.services.asset_management import get_asset_by_hash
|
||||
except Exception:
|
||||
return None
|
||||
asset = get_asset_by_hash("blake3:" + blake3_hex)
|
||||
return asset.hash if asset is not None else None
|
||||
|
||||
|
||||
async def find_existing_by_hash(blake3_hex: str) -> Optional[str]:
|
||||
"""Pure DB lookup — never triggers hashing on the hot path."""
|
||||
return await asyncio.to_thread(_find_by_hash_sync, blake3_hex)
|
||||
65
app/model_downloader/verify/structural.py
Normal file
65
app/model_downloader/verify/structural.py
Normal file
@ -0,0 +1,65 @@
|
||||
"""Cheap structural validation, no full read (PRD section 8.2).
|
||||
|
||||
For ``.safetensors``/``.sft`` we parse the header (first few KB): it carries
|
||||
the tensor table and the byte length of the data region. We assert
|
||||
``file_size == 8 + header_len + data_region_len``. This detects truncation
|
||||
and most corruption for free, before any crypto hashing. Other extensions
|
||||
have no cheap structural check and pass through.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
|
||||
_SAFETENSORS_EXTS = (".safetensors", ".sft")
|
||||
# A sane upper bound so a corrupt header length can't make us read gigabytes.
|
||||
_MAX_HEADER_BYTES = 100 * 1024 * 1024
|
||||
|
||||
|
||||
class StructuralError(Exception):
|
||||
"""The file failed its structural integrity check."""
|
||||
|
||||
|
||||
def validate(path: str) -> None:
|
||||
"""Validate the file at ``path``. Raises :class:`StructuralError` on failure."""
|
||||
lower = path.lower()
|
||||
if lower.endswith(_SAFETENSORS_EXTS):
|
||||
_validate_safetensors(path)
|
||||
# No structural check for other formats; the size + (optional) checksum
|
||||
# gates in the engine cover those.
|
||||
|
||||
|
||||
def _validate_safetensors(path: str) -> None:
|
||||
file_size = os.path.getsize(path)
|
||||
if file_size < 8:
|
||||
raise StructuralError(f"file too small to be safetensors ({file_size} bytes)")
|
||||
with open(path, "rb") as f:
|
||||
header_len = struct.unpack("<Q", f.read(8))[0]
|
||||
if header_len <= 0 or header_len > _MAX_HEADER_BYTES:
|
||||
raise StructuralError(f"implausible safetensors header length {header_len}")
|
||||
if 8 + header_len > file_size:
|
||||
raise StructuralError("safetensors header extends past end of file")
|
||||
try:
|
||||
header = json.loads(f.read(header_len).decode("utf-8"))
|
||||
except (UnicodeDecodeError, json.JSONDecodeError) as e:
|
||||
raise StructuralError(f"safetensors header is not valid JSON: {e}") from e
|
||||
|
||||
data_len = 0
|
||||
for name, entry in header.items():
|
||||
if name == "__metadata__":
|
||||
continue
|
||||
if not isinstance(entry, dict) or "data_offsets" not in entry:
|
||||
raise StructuralError(f"tensor {name!r} missing data_offsets")
|
||||
offsets = entry["data_offsets"]
|
||||
if not (isinstance(offsets, list) and len(offsets) == 2):
|
||||
raise StructuralError(f"tensor {name!r} has malformed data_offsets")
|
||||
data_len = max(data_len, int(offsets[1]))
|
||||
|
||||
expected = 8 + header_len + data_len
|
||||
if file_size != expected:
|
||||
raise StructuralError(
|
||||
f"size mismatch: file is {file_size} bytes, header implies {expected} "
|
||||
f"(8 + {header_len} header + {data_len} data)"
|
||||
)
|
||||
@ -243,6 +243,14 @@ parser.add_argument("--enable-assets", action="store_true", help="Enable the ass
|
||||
parser.add_argument("--feature-flag", type=str, action='append', default=[], metavar="KEY[=VALUE]", help="Set a server feature flag. Use KEY=VALUE to set an explicit value, or bare KEY to set it to true. Can be specified multiple times. Boolean values (true/false) and numbers are auto-converted. Examples: --feature-flag show_signin_button=true or --feature-flag show_signin_button")
|
||||
parser.add_argument("--list-feature-flags", action="store_true", help="Print the registry of known CLI-settable feature flags as JSON and exit.")
|
||||
|
||||
# ----- Model download manager (PRD: docs/prd-download-manager.md) -----
|
||||
parser.add_argument("--download-segments", type=int, default=8, metavar="N", help="Number of parallel HTTP range segments per file for the model download manager (default: 8).")
|
||||
parser.add_argument("--download-max-active", type=int, default=3, metavar="N", help="Maximum number of model downloads running concurrently (default: 3).")
|
||||
parser.add_argument("--download-max-connections-per-host", type=int, default=16, metavar="N", help="Maximum simultaneous connections to a single host for the download manager (default: 16).")
|
||||
parser.add_argument("--download-chunk-size", type=int, default=4 * 1024 * 1024, metavar="BYTES", help="Read chunk size in bytes for the download manager (default: 4 MiB).")
|
||||
parser.add_argument("--download-allowed-hosts", type=str, nargs="*", default=[], metavar="HOST", help="Additional hostnames to add to the download manager allowlist (https only). The built-in defaults always include huggingface.co and civitai.com.")
|
||||
parser.add_argument("--download-allow-any-extension", action="store_true", help="Allow the download manager to fetch files with any extension (default: only known model extensions like .safetensors).")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
else:
|
||||
|
||||
@ -256,7 +256,7 @@ def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, w
|
||||
if (want_requant and len(fns) == 0 or update_weight):
|
||||
seed = comfy.utils.string_to_seed(s.seed_key)
|
||||
if isinstance(orig, QuantizedTensor):
|
||||
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
||||
y = orig.requantize_from_float(x, scale="recalculate", stochastic_rounding=seed)
|
||||
else:
|
||||
y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
|
||||
if want_requant and len(fns) == 0:
|
||||
@ -1306,8 +1306,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||
if getattr(self, 'layout_type', None) is not None:
|
||||
# dtype is now implicit in the layout class
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
|
||||
weight = self.weight.requantize_from_float(weight, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
|
||||
else:
|
||||
weight = weight.to(self.weight.dtype)
|
||||
if return_weight:
|
||||
|
||||
@ -104,6 +104,7 @@ _CORE_FEATURE_FLAGS: dict[str, Any] = {
|
||||
"extension": {"manager": {"supports_v4": True}},
|
||||
"node_replacements": True,
|
||||
"assets": args.enable_assets,
|
||||
"server_side_model_downloads": True,
|
||||
}
|
||||
|
||||
# CLI-provided flags cannot overwrite core flags
|
||||
|
||||
@ -1,85 +1,68 @@
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import ctypes
|
||||
import logging
|
||||
import ctypes.util
|
||||
import importlib.util
|
||||
from typing import TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import nodes
|
||||
import comfy_angle
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
from typing_extensions import override
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _check_opengl_availability():
|
||||
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
|
||||
logger.debug("_check_opengl_availability: starting")
|
||||
missing = []
|
||||
def _preload_angle():
|
||||
egl_path = comfy_angle.get_egl_path()
|
||||
gles_path = comfy_angle.get_glesv2_path()
|
||||
|
||||
# Check Python packages (using find_spec to avoid importing)
|
||||
logger.debug("_check_opengl_availability: checking for glfw package")
|
||||
if importlib.util.find_spec("glfw") is None:
|
||||
missing.append("glfw")
|
||||
if sys.platform == "win32":
|
||||
angle_dir = comfy_angle.get_lib_dir()
|
||||
os.add_dll_directory(angle_dir)
|
||||
os.environ["PATH"] = angle_dir + os.pathsep + os.environ.get("PATH", "")
|
||||
|
||||
logger.debug("_check_opengl_availability: checking for OpenGL package")
|
||||
if importlib.util.find_spec("OpenGL") is None:
|
||||
missing.append("PyOpenGL")
|
||||
|
||||
if missing:
|
||||
raise RuntimeError(
|
||||
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
|
||||
)
|
||||
|
||||
# On Linux without display, check if headless backends are available
|
||||
logger.debug(f"_check_opengl_availability: platform={sys.platform}")
|
||||
if sys.platform.startswith("linux"):
|
||||
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
|
||||
logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
|
||||
if not has_display:
|
||||
# Check for EGL or OSMesa libraries
|
||||
logger.debug("_check_opengl_availability: checking for EGL library")
|
||||
has_egl = ctypes.util.find_library("EGL")
|
||||
logger.debug("_check_opengl_availability: checking for OSMesa library")
|
||||
has_osmesa = ctypes.util.find_library("OSMesa")
|
||||
|
||||
# Error disabled for CI as it fails this check
|
||||
# if not has_egl and not has_osmesa:
|
||||
# raise RuntimeError(
|
||||
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
|
||||
# "See error below for installation instructions."
|
||||
# )
|
||||
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
|
||||
|
||||
logger.debug("_check_opengl_availability: completed")
|
||||
mode = 0 if sys.platform == "win32" else ctypes.RTLD_GLOBAL
|
||||
ctypes.CDLL(str(egl_path), mode=mode)
|
||||
ctypes.CDLL(str(gles_path), mode=mode)
|
||||
|
||||
|
||||
# Run early check at import time
|
||||
logger.debug("nodes_glsl: running _check_opengl_availability at import time")
|
||||
_check_opengl_availability()
|
||||
|
||||
# OpenGL modules - initialized lazily when context is created
|
||||
gl = None
|
||||
glfw = None
|
||||
EGL = None
|
||||
# Pre-load ANGLE *before* any PyOpenGL import so that the EGL platform
|
||||
# plugin picks up ANGLE's libEGL / libGLESv2 instead of system libs.
|
||||
_preload_angle()
|
||||
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
|
||||
|
||||
|
||||
def _import_opengl():
|
||||
"""Import OpenGL module. Called after context is created."""
|
||||
global gl
|
||||
if gl is None:
|
||||
logger.debug("_import_opengl: importing OpenGL.GL")
|
||||
import OpenGL.GL as _gl
|
||||
gl = _gl
|
||||
logger.debug("_import_opengl: import completed")
|
||||
return gl
|
||||
import OpenGL
|
||||
OpenGL.USE_ACCELERATE = False
|
||||
|
||||
|
||||
def _patch_find_library():
|
||||
"""PyOpenGL's EGL platform looks for 'EGL' and 'GLESv2' by short name
|
||||
via ctypes.util.find_library, but ANGLE ships as 'libEGL' and
|
||||
'libGLESv2'. Patch find_library to return the full ANGLE paths so
|
||||
PyOpenGL loads the same libraries we pre-loaded."""
|
||||
if sys.platform == "linux":
|
||||
return
|
||||
import ctypes.util
|
||||
_orig = ctypes.util.find_library
|
||||
def _patched(name):
|
||||
if name == 'EGL':
|
||||
return comfy_angle.get_egl_path()
|
||||
if name == 'GLESv2':
|
||||
return comfy_angle.get_glesv2_path()
|
||||
return _orig(name)
|
||||
ctypes.util.find_library = _patched
|
||||
|
||||
|
||||
_patch_find_library()
|
||||
|
||||
from OpenGL import EGL
|
||||
from OpenGL import GLES3 as gl
|
||||
|
||||
class SizeModeInput(TypedDict):
|
||||
size_mode: str
|
||||
width: int
|
||||
@ -102,7 +85,7 @@ MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
||||
# (-1,-1)---(3,-1)
|
||||
#
|
||||
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
|
||||
VERTEX_SHADER = """#version 330 core
|
||||
VERTEX_SHADER = """#version 300 es
|
||||
out vec2 v_texCoord;
|
||||
void main() {
|
||||
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
|
||||
@ -126,14 +109,99 @@ void main() {
|
||||
"""
|
||||
|
||||
|
||||
def _convert_es_to_desktop(source: str) -> str:
|
||||
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
|
||||
# Remove any existing #version directive
|
||||
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
|
||||
# Remove precision qualifiers (not needed in desktop GLSL)
|
||||
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
|
||||
# Prepend desktop GLSL version
|
||||
return "#version 330 core\n" + source
|
||||
|
||||
def _egl_attribs(*values):
|
||||
"""Build an EGL_NONE-terminated EGLint attribute array."""
|
||||
vals = list(values) + [EGL.EGL_NONE]
|
||||
return (ctypes.c_int32 * len(vals))(*vals)
|
||||
|
||||
|
||||
# EGL platform extension constants
|
||||
EGL_PLATFORM_ANGLE_ANGLE = 0x3202
|
||||
EGL_PLATFORM_ANGLE_TYPE_ANGLE = 0x3203
|
||||
EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE = 0x3450
|
||||
EGL_MESA_PLATFORM_SURFACELESS = 0x31DD
|
||||
|
||||
|
||||
_eglGetPlatformDisplayEXT = None
|
||||
|
||||
def _get_egl_platform_display_ext(platform, native_display, attribs):
|
||||
"""Call eglGetPlatformDisplayEXT via ctypes (extension, not in PyOpenGL)."""
|
||||
global _eglGetPlatformDisplayEXT
|
||||
if _eglGetPlatformDisplayEXT is None:
|
||||
from OpenGL import platform as _plat
|
||||
egl_lib = _plat.PLATFORM.EGL
|
||||
_get_proc = egl_lib.eglGetProcAddress
|
||||
_get_proc.restype = ctypes.c_void_p
|
||||
_get_proc.argtypes = [ctypes.c_char_p]
|
||||
ptr = _get_proc(b"eglGetPlatformDisplayEXT")
|
||||
if not ptr:
|
||||
return None
|
||||
func_type = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_uint32, ctypes.c_void_p, ctypes.c_void_p)
|
||||
_eglGetPlatformDisplayEXT = func_type(ptr)
|
||||
|
||||
raw = _eglGetPlatformDisplayEXT(platform, native_display, attribs)
|
||||
if not raw:
|
||||
return None
|
||||
return ctypes.cast(raw, EGL.EGLDisplay)
|
||||
|
||||
|
||||
def _get_egl_display():
|
||||
"""Get an EGL display, trying the default first then ANGLE's Vulkan
|
||||
platform for headless environments without a display server."""
|
||||
failures = []
|
||||
|
||||
# Try the default display first (works when X11/Wayland is available)
|
||||
display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY)
|
||||
if display:
|
||||
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
|
||||
try:
|
||||
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
|
||||
return display, major.value, minor.value
|
||||
except Exception as e:
|
||||
failures.append(f"default: {e}")
|
||||
|
||||
logger.info("Default EGL display unavailable, trying headless fallbacks")
|
||||
|
||||
# Headless fallback strategies, tried in order:
|
||||
headless_strategies = [
|
||||
("surfaceless", EGL_MESA_PLATFORM_SURFACELESS, None, None),
|
||||
("ANGLE Vulkan", EGL_PLATFORM_ANGLE_ANGLE, None,
|
||||
_egl_attribs(EGL_PLATFORM_ANGLE_TYPE_ANGLE, EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE)),
|
||||
]
|
||||
|
||||
for name, platform, native_display, attribs in headless_strategies:
|
||||
display = _get_egl_platform_display_ext(platform, native_display, attribs)
|
||||
if not display:
|
||||
failures.append(f"{name}: eglGetPlatformDisplayEXT returned no display")
|
||||
continue
|
||||
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
|
||||
try:
|
||||
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
|
||||
logger.info(f"Using EGL {name} platform (headless)")
|
||||
return display, major.value, minor.value
|
||||
failures.append(f"{name}: eglInitialize returned false")
|
||||
except Exception as e:
|
||||
failures.append(f"{name}: {e}")
|
||||
continue
|
||||
|
||||
details = "\n".join(f" - {f}" for f in failures)
|
||||
raise RuntimeError(
|
||||
"Failed to initialize EGL display.\n"
|
||||
"No display server and no headless EGL platform available.\n"
|
||||
f"Tried:\n{details}\n"
|
||||
"Ensure GPU drivers are installed or set DISPLAY for a virtual framebuffer."
|
||||
)
|
||||
|
||||
|
||||
def _gl_str(name):
|
||||
"""Get an OpenGL string parameter."""
|
||||
v = gl.glGetString(name)
|
||||
if not v:
|
||||
return "Unknown"
|
||||
if isinstance(v, bytes):
|
||||
return v.decode(errors="replace")
|
||||
return ctypes.string_at(v).decode(errors="replace")
|
||||
|
||||
|
||||
def _detect_output_count(source: str) -> int:
|
||||
@ -159,163 +227,8 @@ def _detect_pass_count(source: str) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
def _init_glfw():
|
||||
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
|
||||
logger.debug("_init_glfw: starting")
|
||||
# On macOS, glfw.init() must be called from main thread or it hangs forever
|
||||
if sys.platform == "darwin":
|
||||
logger.debug("_init_glfw: skipping on macOS")
|
||||
raise RuntimeError("GLFW backend not supported on macOS")
|
||||
|
||||
logger.debug("_init_glfw: importing glfw module")
|
||||
import glfw as _glfw
|
||||
|
||||
logger.debug("_init_glfw: calling glfw.init()")
|
||||
if not _glfw.init():
|
||||
raise RuntimeError("glfw.init() failed")
|
||||
|
||||
try:
|
||||
logger.debug("_init_glfw: setting window hints")
|
||||
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
|
||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
|
||||
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
|
||||
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
|
||||
|
||||
logger.debug("_init_glfw: calling create_window()")
|
||||
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
|
||||
if not window:
|
||||
raise RuntimeError("glfw.create_window() failed")
|
||||
|
||||
logger.debug("_init_glfw: calling make_context_current()")
|
||||
_glfw.make_context_current(window)
|
||||
logger.debug("_init_glfw: completed successfully")
|
||||
return window, _glfw
|
||||
except Exception:
|
||||
logger.debug("_init_glfw: failed, terminating glfw")
|
||||
_glfw.terminate()
|
||||
raise
|
||||
|
||||
|
||||
def _init_egl():
|
||||
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
|
||||
logger.debug("_init_egl: starting")
|
||||
from OpenGL import EGL as _EGL
|
||||
from OpenGL.EGL import (
|
||||
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
|
||||
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
|
||||
eglTerminate, eglDestroyContext, eglDestroySurface,
|
||||
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
|
||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
||||
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
|
||||
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
|
||||
)
|
||||
logger.debug("_init_egl: imports completed")
|
||||
|
||||
display = None
|
||||
context = None
|
||||
surface = None
|
||||
|
||||
try:
|
||||
logger.debug("_init_egl: calling eglGetDisplay()")
|
||||
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
|
||||
if display == _EGL.EGL_NO_DISPLAY:
|
||||
raise RuntimeError("eglGetDisplay() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglInitialize()")
|
||||
major, minor = _EGL.EGLint(), _EGL.EGLint()
|
||||
if not eglInitialize(display, major, minor):
|
||||
display = None # Not initialized, don't terminate
|
||||
raise RuntimeError("eglInitialize() failed")
|
||||
logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
|
||||
|
||||
config_attribs = [
|
||||
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
|
||||
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
||||
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
|
||||
EGL_DEPTH_SIZE, 0, EGL_NONE
|
||||
]
|
||||
configs = (_EGL.EGLConfig * 1)()
|
||||
num_configs = _EGL.EGLint()
|
||||
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
|
||||
raise RuntimeError("eglChooseConfig() failed")
|
||||
config = configs[0]
|
||||
logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
|
||||
|
||||
if not eglBindAPI(EGL_OPENGL_API):
|
||||
raise RuntimeError("eglBindAPI() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglCreateContext()")
|
||||
context_attribs = [
|
||||
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
|
||||
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
|
||||
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
|
||||
EGL_NONE
|
||||
]
|
||||
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
|
||||
if context == EGL_NO_CONTEXT:
|
||||
raise RuntimeError("eglCreateContext() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglCreatePbufferSurface()")
|
||||
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
|
||||
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
|
||||
if surface == _EGL.EGL_NO_SURFACE:
|
||||
raise RuntimeError("eglCreatePbufferSurface() failed")
|
||||
|
||||
logger.debug("_init_egl: calling eglMakeCurrent()")
|
||||
if not eglMakeCurrent(display, surface, surface, context):
|
||||
raise RuntimeError("eglMakeCurrent() failed")
|
||||
|
||||
logger.debug("_init_egl: completed successfully")
|
||||
return display, context, surface, _EGL
|
||||
|
||||
except Exception:
|
||||
logger.debug("_init_egl: failed, cleaning up")
|
||||
# Clean up any resources on failure
|
||||
if surface is not None:
|
||||
eglDestroySurface(display, surface)
|
||||
if context is not None:
|
||||
eglDestroyContext(display, context)
|
||||
if display is not None:
|
||||
eglTerminate(display)
|
||||
raise
|
||||
|
||||
|
||||
def _init_osmesa():
|
||||
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
|
||||
import ctypes
|
||||
|
||||
logger.debug("_init_osmesa: starting")
|
||||
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
|
||||
|
||||
logger.debug("_init_osmesa: importing OpenGL.osmesa")
|
||||
from OpenGL import GL as _gl
|
||||
from OpenGL.osmesa import (
|
||||
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
|
||||
OSMESA_RGBA,
|
||||
)
|
||||
logger.debug("_init_osmesa: imports completed")
|
||||
|
||||
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
|
||||
if not ctx:
|
||||
raise RuntimeError("OSMesaCreateContextExt() failed")
|
||||
|
||||
width, height = 64, 64
|
||||
buffer = (ctypes.c_ubyte * (width * height * 4))()
|
||||
|
||||
logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
|
||||
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
|
||||
OSMesaDestroyContext(ctx)
|
||||
raise RuntimeError("OSMesaMakeCurrent() failed")
|
||||
|
||||
logger.debug("_init_osmesa: completed successfully")
|
||||
return ctx, buffer
|
||||
|
||||
|
||||
class GLContext:
|
||||
"""Manages OpenGL context and resources for shader execution.
|
||||
|
||||
Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software).
|
||||
"""
|
||||
"""Manages an OpenGL ES 3.0 context via EGL/ANGLE (singleton)."""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
@ -327,131 +240,105 @@ class GLContext:
|
||||
|
||||
def __init__(self):
|
||||
if GLContext._initialized:
|
||||
logger.debug("GLContext.__init__: already initialized, skipping")
|
||||
return
|
||||
|
||||
logger.debug("GLContext.__init__: starting initialization")
|
||||
|
||||
global glfw, EGL
|
||||
|
||||
import time
|
||||
start = time.perf_counter()
|
||||
|
||||
self._backend = None
|
||||
self._window = None
|
||||
self._egl_display = None
|
||||
self._egl_context = None
|
||||
self._egl_surface = None
|
||||
self._osmesa_ctx = None
|
||||
self._osmesa_buffer = None
|
||||
self._display = None
|
||||
self._surface = None
|
||||
self._context = None
|
||||
self._vao = None
|
||||
|
||||
# Try backends in order: GLFW → EGL → OSMesa
|
||||
errors = []
|
||||
|
||||
logger.debug("GLContext.__init__: trying GLFW backend")
|
||||
try:
|
||||
self._window, glfw = _init_glfw()
|
||||
self._backend = "glfw"
|
||||
logger.debug("GLContext.__init__: GLFW backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
|
||||
errors.append(("GLFW", e))
|
||||
self._display, self._egl_major, self._egl_minor = _get_egl_display()
|
||||
|
||||
if self._backend is None:
|
||||
logger.debug("GLContext.__init__: trying EGL backend")
|
||||
try:
|
||||
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
|
||||
self._backend = "egl"
|
||||
logger.debug("GLContext.__init__: EGL backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
|
||||
errors.append(("EGL", e))
|
||||
if not EGL.eglBindAPI(EGL.EGL_OPENGL_ES_API):
|
||||
raise RuntimeError("eglBindAPI(EGL_OPENGL_ES_API) failed")
|
||||
|
||||
if self._backend is None:
|
||||
logger.debug("GLContext.__init__: trying OSMesa backend")
|
||||
try:
|
||||
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
|
||||
self._backend = "osmesa"
|
||||
logger.debug("GLContext.__init__: OSMesa backend succeeded")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
|
||||
errors.append(("OSMesa", e))
|
||||
config = EGL.EGLConfig()
|
||||
n_configs = ctypes.c_int32(0)
|
||||
if not EGL.eglChooseConfig(
|
||||
self._display,
|
||||
_egl_attribs(
|
||||
EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_ES3_BIT,
|
||||
EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT,
|
||||
EGL.EGL_RED_SIZE, 8, EGL.EGL_GREEN_SIZE, 8,
|
||||
EGL.EGL_BLUE_SIZE, 8, EGL.EGL_ALPHA_SIZE, 8,
|
||||
),
|
||||
ctypes.byref(config), 1, ctypes.byref(n_configs),
|
||||
) or n_configs.value == 0:
|
||||
raise RuntimeError("eglChooseConfig() failed")
|
||||
|
||||
if self._backend is None:
|
||||
if sys.platform == "win32":
|
||||
platform_help = (
|
||||
"Windows: Ensure GPU drivers are installed and display is available.\n"
|
||||
" CPU-only/headless mode is not supported on Windows."
|
||||
)
|
||||
elif sys.platform == "darwin":
|
||||
platform_help = (
|
||||
"macOS: GLFW is not supported.\n"
|
||||
" Install OSMesa via Homebrew: brew install mesa\n"
|
||||
" Then: pip install PyOpenGL PyOpenGL-accelerate"
|
||||
)
|
||||
else:
|
||||
platform_help = (
|
||||
"Linux: Install one of these backends:\n"
|
||||
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
|
||||
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
|
||||
" Headless (CPU): sudo apt install libosmesa6"
|
||||
)
|
||||
|
||||
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
|
||||
raise RuntimeError(
|
||||
f"Failed to create OpenGL context.\n\n"
|
||||
f"Backend errors:\n{error_details}\n\n"
|
||||
f"{platform_help}"
|
||||
self._surface = EGL.eglCreatePbufferSurface(
|
||||
self._display, config,
|
||||
_egl_attribs(EGL.EGL_WIDTH, 64, EGL.EGL_HEIGHT, 64),
|
||||
)
|
||||
if not self._surface:
|
||||
raise RuntimeError("eglCreatePbufferSurface() failed")
|
||||
|
||||
# Now import OpenGL.GL (after context is current)
|
||||
logger.debug("GLContext.__init__: importing OpenGL.GL")
|
||||
_import_opengl()
|
||||
self._context = EGL.eglCreateContext(
|
||||
self._display, config, EGL.EGL_NO_CONTEXT,
|
||||
_egl_attribs(EGL.EGL_CONTEXT_CLIENT_VERSION, 3),
|
||||
)
|
||||
if not self._context:
|
||||
raise RuntimeError("eglCreateContext() failed")
|
||||
|
||||
# Create VAO (required for core profile, but OSMesa may use compat profile)
|
||||
logger.debug("GLContext.__init__: creating VAO")
|
||||
try:
|
||||
vao = gl.glGenVertexArrays(1)
|
||||
gl.glBindVertexArray(vao)
|
||||
self._vao = vao # Only store after successful bind
|
||||
logger.debug("GLContext.__init__: VAO created successfully")
|
||||
except Exception as e:
|
||||
logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
|
||||
# OSMesa with older Mesa may not support VAOs
|
||||
# Clean up if we created but couldn't bind
|
||||
if vao:
|
||||
try:
|
||||
gl.glDeleteVertexArrays(1, [vao])
|
||||
except Exception:
|
||||
pass
|
||||
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
|
||||
raise RuntimeError("eglMakeCurrent() failed")
|
||||
|
||||
self._vao = gl.glGenVertexArrays(1)
|
||||
gl.glBindVertexArray(self._vao)
|
||||
|
||||
except Exception:
|
||||
self._cleanup()
|
||||
raise
|
||||
|
||||
elapsed = (time.perf_counter() - start) * 1000
|
||||
|
||||
# Log device info
|
||||
renderer = gl.glGetString(gl.GL_RENDERER)
|
||||
vendor = gl.glGetString(gl.GL_VENDOR)
|
||||
version = gl.glGetString(gl.GL_VERSION)
|
||||
renderer = renderer.decode() if renderer else "Unknown"
|
||||
vendor = vendor.decode() if vendor else "Unknown"
|
||||
version = version.decode() if version else "Unknown"
|
||||
renderer = _gl_str(gl.GL_RENDERER)
|
||||
vendor = _gl_str(gl.GL_VENDOR)
|
||||
version = _gl_str(gl.GL_VERSION)
|
||||
|
||||
GLContext._initialized = True
|
||||
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
|
||||
logger.info(f"GLSL context initialized in {elapsed:.1f}ms - EGL {self._egl_major}.{self._egl_minor}, {renderer} ({vendor}), GL {version}")
|
||||
|
||||
def make_current(self):
|
||||
if self._backend == "glfw":
|
||||
glfw.make_context_current(self._window)
|
||||
elif self._backend == "egl":
|
||||
from OpenGL.EGL import eglMakeCurrent
|
||||
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
|
||||
elif self._backend == "osmesa":
|
||||
from OpenGL.osmesa import OSMesaMakeCurrent
|
||||
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
|
||||
|
||||
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
|
||||
err = EGL.eglGetError()
|
||||
raise RuntimeError(f"eglMakeCurrent() failed (EGL error: 0x{err:04X})")
|
||||
if self._vao is not None:
|
||||
gl.glBindVertexArray(self._vao)
|
||||
|
||||
def _cleanup(self):
|
||||
if not self._display:
|
||||
return
|
||||
try:
|
||||
if self._vao is not None:
|
||||
gl.glDeleteVertexArrays(1, [self._vao])
|
||||
self._vao = None
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
EGL.eglMakeCurrent(self._display, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self._context:
|
||||
EGL.eglDestroyContext(self._display, self._context)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self._surface:
|
||||
EGL.eglDestroySurface(self._display, self._surface)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
EGL.eglTerminate(self._display)
|
||||
except Exception:
|
||||
pass
|
||||
self._display = None
|
||||
|
||||
|
||||
def _compile_shader(source: str, shader_type: int) -> int:
|
||||
"""Compile a shader and return its ID."""
|
||||
@ -459,8 +346,10 @@ def _compile_shader(source: str, shader_type: int) -> int:
|
||||
gl.glShaderSource(shader, source)
|
||||
gl.glCompileShader(shader)
|
||||
|
||||
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
|
||||
error = gl.glGetShaderInfoLog(shader).decode()
|
||||
if not gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS):
|
||||
error = gl.glGetShaderInfoLog(shader)
|
||||
if isinstance(error, bytes):
|
||||
error = error.decode(errors="replace")
|
||||
gl.glDeleteShader(shader)
|
||||
raise RuntimeError(f"Shader compilation failed:\n{error}")
|
||||
|
||||
@ -484,8 +373,10 @@ def _create_program(vertex_source: str, fragment_source: str) -> int:
|
||||
gl.glDeleteShader(vertex_shader)
|
||||
gl.glDeleteShader(fragment_shader)
|
||||
|
||||
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
|
||||
error = gl.glGetProgramInfoLog(program).decode()
|
||||
if not gl.glGetProgramiv(program, gl.GL_LINK_STATUS):
|
||||
error = gl.glGetProgramInfoLog(program)
|
||||
if isinstance(error, bytes):
|
||||
error = error.decode(errors="replace")
|
||||
gl.glDeleteProgram(program)
|
||||
raise RuntimeError(f"Program linking failed:\n{error}")
|
||||
|
||||
@ -530,9 +421,6 @@ def _render_shader_batch(
|
||||
ctx = GLContext()
|
||||
ctx.make_current()
|
||||
|
||||
# Convert from GLSL ES to desktop GLSL 330
|
||||
fragment_source = _convert_es_to_desktop(fragment_code)
|
||||
|
||||
# Detect how many outputs the shader actually uses
|
||||
num_outputs = _detect_output_count(fragment_code)
|
||||
|
||||
@ -558,9 +446,9 @@ def _render_shader_batch(
|
||||
try:
|
||||
# Compile shaders (once for all batches)
|
||||
try:
|
||||
program = _create_program(VERTEX_SHADER, fragment_source)
|
||||
program = _create_program(VERTEX_SHADER, fragment_code)
|
||||
except RuntimeError:
|
||||
logger.error(f"Fragment shader:\n{fragment_source}")
|
||||
logger.error(f"Fragment shader:\n{fragment_code}")
|
||||
raise
|
||||
|
||||
gl.glUseProgram(program)
|
||||
@ -723,13 +611,13 @@ def _render_shader_batch(
|
||||
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
|
||||
|
||||
# Read back outputs for this batch
|
||||
# (glGetTexImage is synchronous, implicitly waits for rendering)
|
||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
|
||||
batch_outputs = []
|
||||
for tex in output_textures:
|
||||
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
||||
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
|
||||
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
|
||||
batch_outputs.append(img[::-1, :, :].copy())
|
||||
for i in range(num_outputs):
|
||||
gl.glReadBuffer(gl.GL_COLOR_ATTACHMENT0 + i)
|
||||
buf = np.empty((height, width, 4), dtype=np.float32)
|
||||
gl.glReadPixels(0, 0, width, height, gl.GL_RGBA, gl.GL_FLOAT, buf)
|
||||
batch_outputs.append(buf[::-1, :, :].copy())
|
||||
|
||||
# Pad with black images for unused outputs
|
||||
black_img = np.zeros((height, width, 4), dtype=np.float32)
|
||||
@ -750,18 +638,18 @@ def _render_shader_batch(
|
||||
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
|
||||
gl.glUseProgram(0)
|
||||
|
||||
for tex in input_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in curve_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in output_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in ping_pong_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
if input_textures:
|
||||
gl.glDeleteTextures(len(input_textures), input_textures)
|
||||
if curve_textures:
|
||||
gl.glDeleteTextures(len(curve_textures), curve_textures)
|
||||
if output_textures:
|
||||
gl.glDeleteTextures(len(output_textures), output_textures)
|
||||
if ping_pong_textures:
|
||||
gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures)
|
||||
if fbo is not None:
|
||||
gl.glDeleteFramebuffers(1, [fbo])
|
||||
for pp_fbo in ping_pong_fbos:
|
||||
gl.glDeleteFramebuffers(1, [pp_fbo])
|
||||
if ping_pong_fbos:
|
||||
gl.glDeleteFramebuffers(len(ping_pong_fbos), ping_pong_fbos)
|
||||
if program is not None:
|
||||
gl.glDeleteProgram(program)
|
||||
|
||||
|
||||
55
execution.py
55
execution.py
@ -1113,32 +1113,6 @@ def full_type_name(klass):
|
||||
return klass.__qualname__
|
||||
return module + '.' + klass.__qualname__
|
||||
|
||||
def node_not_executable_reason(class_def, class_type):
|
||||
"""Return a human-readable reason the node cannot be executed, or None if it's fine.
|
||||
|
||||
Catches a node whose declared entry point doesn't resolve to a real method
|
||||
(e.g. a V1 ``FUNCTION = "invert"`` where the method is misspelled, or a V3 node
|
||||
missing its ``execute`` override). Running this during validation surfaces the
|
||||
problem before execution starts, instead of after upstream nodes have run.
|
||||
|
||||
Only the class is inspected; the node is never instantiated here, so a node's
|
||||
``__init__`` side effects cannot run (or fail) during validation.
|
||||
"""
|
||||
try:
|
||||
if issubclass(class_def, _ComfyNodeInternal):
|
||||
# V3: validates that execute()/define_schema() overrides exist.
|
||||
class_def.VALIDATE_CLASS()
|
||||
return None
|
||||
# V1: FUNCTION names the method to call; it must exist on the class.
|
||||
function_name = getattr(class_def, "FUNCTION", None)
|
||||
if function_name is None:
|
||||
return f"'{class_type}' does not define FUNCTION"
|
||||
if not callable(getattr(class_def, function_name, None)):
|
||||
return f"'{class_type}' has no method '{function_name}' (declared in FUNCTION)"
|
||||
return None
|
||||
except Exception as ex:
|
||||
return str(ex)
|
||||
|
||||
async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]):
|
||||
outputs = set()
|
||||
for x in prompt:
|
||||
@ -1174,35 +1148,6 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[
|
||||
}
|
||||
return (False, error, [], {})
|
||||
|
||||
# Make sure the node is actually executable (its FUNCTION/execute entry
|
||||
# point resolves to a real method) before we touch any schema-derived
|
||||
# attributes below or start execution. Catches code typos up front and
|
||||
# attributes the error to the offending node.
|
||||
not_executable = node_not_executable_reason(class_, class_type)
|
||||
if not_executable is not None:
|
||||
node_title = prompt[x].get('_meta', {}).get('title', class_type)
|
||||
error = {
|
||||
"type": "invalid_node_definition",
|
||||
"message": "Node is not executable",
|
||||
"details": f"{not_executable} (Node ID '#{x}')",
|
||||
"extra_info": {
|
||||
"node_id": x,
|
||||
"class_type": class_type,
|
||||
"node_title": node_title,
|
||||
}
|
||||
}
|
||||
node_errors = {x: {
|
||||
"errors": [{
|
||||
"type": "invalid_node_definition",
|
||||
"message": "Node is not executable",
|
||||
"details": not_executable,
|
||||
"extra_info": {},
|
||||
}],
|
||||
"dependent_outputs": [],
|
||||
"class_type": class_type,
|
||||
}}
|
||||
return (False, error, [], node_errors)
|
||||
|
||||
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
|
||||
if partial_execution_list is None or x in partial_execution_list:
|
||||
outputs.add(x)
|
||||
|
||||
483
openapi.yaml
483
openapi.yaml
@ -230,6 +230,93 @@ components:
|
||||
- base_version
|
||||
- workflow_json
|
||||
type: object
|
||||
DownloadEnqueueRequest:
|
||||
description: Request body for enqueuing a server-side model download.
|
||||
properties:
|
||||
allow_any_extension:
|
||||
default: false
|
||||
description: Permit a non-model file extension (default only allows known model extensions).
|
||||
type: boolean
|
||||
credential_id:
|
||||
description: Explicit per-host credential to use; otherwise auto-resolved by host. Still subject to the per-hop host match.
|
||||
nullable: true
|
||||
type: string
|
||||
expected_sha256:
|
||||
description: Optional hub-provided SHA256 to verify the completed file against (fail-closed).
|
||||
nullable: true
|
||||
type: string
|
||||
model_id:
|
||||
description: Destination as "<directory>/<filename>", resolving to a registered model folder (e.g. "loras/my_lora.safetensors").
|
||||
type: string
|
||||
priority:
|
||||
default: 0
|
||||
description: Scheduling priority; higher is admitted first.
|
||||
type: integer
|
||||
url:
|
||||
description: Source URL; must be on the allowlist (host + scheme + extension).
|
||||
type: string
|
||||
required:
|
||||
- url
|
||||
- model_id
|
||||
type: object
|
||||
DownloadStatus:
|
||||
description: Current state and live progress of a single download.
|
||||
properties:
|
||||
bytes_done:
|
||||
type: integer
|
||||
created_at:
|
||||
type: integer
|
||||
download_id:
|
||||
format: uuid
|
||||
type: string
|
||||
error:
|
||||
nullable: true
|
||||
type: string
|
||||
eta_seconds:
|
||||
nullable: true
|
||||
type: number
|
||||
model_id:
|
||||
type: string
|
||||
priority:
|
||||
type: integer
|
||||
progress:
|
||||
description: Fraction in [0,1]; null until total size is known.
|
||||
nullable: true
|
||||
type: number
|
||||
segments:
|
||||
description: Per-segment progress (segmented downloads only).
|
||||
items:
|
||||
properties:
|
||||
bytes_done:
|
||||
type: integer
|
||||
idx:
|
||||
type: integer
|
||||
length:
|
||||
type: integer
|
||||
type: object
|
||||
nullable: true
|
||||
type: array
|
||||
speed_bps:
|
||||
nullable: true
|
||||
type: number
|
||||
status:
|
||||
enum:
|
||||
- queued
|
||||
- active
|
||||
- paused
|
||||
- verifying
|
||||
- completed
|
||||
- failed
|
||||
- cancelled
|
||||
type: string
|
||||
total_bytes:
|
||||
nullable: true
|
||||
type: integer
|
||||
updated_at:
|
||||
type: integer
|
||||
url:
|
||||
type: string
|
||||
type: object
|
||||
ErrorResponse:
|
||||
description: Standard error response with a machine-readable code and human-readable message.
|
||||
properties:
|
||||
@ -511,6 +598,78 @@ components:
|
||||
required:
|
||||
- history
|
||||
type: object
|
||||
HostCredentialUpsert:
|
||||
description: Request body for upserting a per-host credential. The secret is write-only.
|
||||
properties:
|
||||
auth_scheme:
|
||||
default: bearer
|
||||
description: How the secret is attached to requests.
|
||||
enum:
|
||||
- bearer
|
||||
- header
|
||||
- query
|
||||
type: string
|
||||
enabled:
|
||||
default: true
|
||||
type: boolean
|
||||
header_name:
|
||||
description: Header name when auth_scheme=header (defaults to Authorization).
|
||||
nullable: true
|
||||
type: string
|
||||
host:
|
||||
description: Normalized hostname the key applies to (e.g. "civitai.com").
|
||||
type: string
|
||||
label:
|
||||
description: User-friendly name for display.
|
||||
nullable: true
|
||||
type: string
|
||||
match_subdomains:
|
||||
default: false
|
||||
description: Also match label-boundary subdomains of host (off by default; unsafe for hub CDNs).
|
||||
type: boolean
|
||||
query_param:
|
||||
description: Query parameter name when auth_scheme=query.
|
||||
nullable: true
|
||||
type: string
|
||||
secret:
|
||||
description: The API key. Write-only — never returned by any endpoint.
|
||||
type: string
|
||||
required:
|
||||
- host
|
||||
- secret
|
||||
type: object
|
||||
HostCredentialView:
|
||||
description: Masked, API-safe view of a stored credential. Never includes the secret.
|
||||
properties:
|
||||
auth_scheme:
|
||||
type: string
|
||||
created_at:
|
||||
type: integer
|
||||
enabled:
|
||||
type: boolean
|
||||
header_name:
|
||||
nullable: true
|
||||
type: string
|
||||
host:
|
||||
type: string
|
||||
id:
|
||||
format: uuid
|
||||
type: string
|
||||
label:
|
||||
nullable: true
|
||||
type: string
|
||||
match_subdomains:
|
||||
type: boolean
|
||||
query_param:
|
||||
nullable: true
|
||||
type: string
|
||||
secret_last4:
|
||||
description: Last 4 characters of the secret, for masked display only.
|
||||
nullable: true
|
||||
type: string
|
||||
updated_at:
|
||||
type: integer
|
||||
type: object
|
||||
JobCancelResponse:
|
||||
description: Response for POST /api/jobs/{job_id}/cancel. Returned on both fresh cancels and idempotent no-ops.
|
||||
properties:
|
||||
@ -2350,6 +2509,330 @@ paths:
|
||||
summary: Get tag histogram for filtered assets
|
||||
tags:
|
||||
- file
|
||||
/api/download:
|
||||
get:
|
||||
description: List all known downloads (queued, active, paused, and terminal) with live progress.
|
||||
operationId: listDownloads
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
downloads:
|
||||
items:
|
||||
$ref: '#/components/schemas/DownloadStatus'
|
||||
type: array
|
||||
type: object
|
||||
description: List of downloads
|
||||
summary: List downloads
|
||||
tags:
|
||||
- download
|
||||
/api/download/availability:
|
||||
post:
|
||||
description: |
|
||||
Bulk per-id availability for a set of model_ids declared in a workflow.
|
||||
Returns whether each model is available on disk, currently downloading
|
||||
(with progress), or missing, plus whether its URL is on the allowlist.
|
||||
operationId: getModelsAvailability
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
models:
|
||||
additionalProperties:
|
||||
type: string
|
||||
description: Map of "<directory>/<filename>" model_id to its declared source URL.
|
||||
type: object
|
||||
type: object
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
models:
|
||||
additionalProperties: true
|
||||
type: object
|
||||
type: object
|
||||
description: Per-id availability map
|
||||
summary: Bulk model availability + status
|
||||
tags:
|
||||
- download
|
||||
/api/download/credentials:
|
||||
get:
|
||||
description: List stored per-host credentials. Secrets are never returned; only masked metadata (last 4 chars, scheme, label).
|
||||
operationId: listDownloadCredentials
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
credentials:
|
||||
items:
|
||||
$ref: '#/components/schemas/HostCredentialView'
|
||||
type: array
|
||||
type: object
|
||||
description: Masked credential list
|
||||
summary: List host credentials (masked)
|
||||
tags:
|
||||
- download
|
||||
post:
|
||||
description: |
|
||||
Upsert (by host) a per-host API key used to authenticate downloads.
|
||||
The secret is write-only: it is stored once here and never returned by any endpoint.
|
||||
operationId: upsertDownloadCredential
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HostCredentialUpsert'
|
||||
responses:
|
||||
"201":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HostCredentialView'
|
||||
description: Credential stored (masked view returned)
|
||||
"400":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Invalid credential
|
||||
summary: Upsert a host credential
|
||||
tags:
|
||||
- download
|
||||
/api/download/credentials/{id}:
|
||||
delete:
|
||||
description: Delete a stored host credential.
|
||||
operationId: deleteDownloadCredential
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
deleted:
|
||||
type: boolean
|
||||
type: object
|
||||
description: Deleted
|
||||
"404":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: No such credential
|
||||
summary: Delete a host credential
|
||||
tags:
|
||||
- download
|
||||
get:
|
||||
description: Get a single host credential (masked; never includes the secret).
|
||||
operationId: getDownloadCredential
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HostCredentialView'
|
||||
description: Masked credential
|
||||
"404":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: No such credential
|
||||
summary: Get a host credential (masked)
|
||||
tags:
|
||||
- download
|
||||
/api/download/enqueue:
|
||||
post:
|
||||
description: |
|
||||
Enqueue a server-side model download. The URL must be on the allowlist
|
||||
(host + scheme + extension) and the model_id must be "<directory>/<filename>"
|
||||
resolving to a registered model folder. Returns immediately; track progress
|
||||
via GET /api/download/{id} or the "download_progress" websocket event.
|
||||
operationId: enqueueDownload
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/DownloadEnqueueRequest'
|
||||
responses:
|
||||
"202":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
accepted:
|
||||
type: boolean
|
||||
download_id:
|
||||
format: uuid
|
||||
type: string
|
||||
type: object
|
||||
description: Download accepted and queued
|
||||
"400":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Invalid request (bad URL, model_id, or not allowlisted)
|
||||
"409":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Already on disk or already downloading
|
||||
summary: Enqueue a model download
|
||||
tags:
|
||||
- download
|
||||
/api/download/{id}:
|
||||
get:
|
||||
description: Get the current status + progress of a single download.
|
||||
operationId: getDownloadStatus
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
format: uuid
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/DownloadStatus'
|
||||
description: Download status
|
||||
"404":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: No such download
|
||||
summary: Get download status
|
||||
tags:
|
||||
- download
|
||||
/api/download/{id}/cancel:
|
||||
post:
|
||||
description: Cancel a download. The partial file is removed.
|
||||
operationId: cancelDownload
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
format: uuid
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
ok:
|
||||
type: boolean
|
||||
type: object
|
||||
description: Cancelled
|
||||
summary: Cancel a download
|
||||
tags:
|
||||
- download
|
||||
/api/download/{id}/pause:
|
||||
post:
|
||||
description: Pause a download. The partial file and per-segment offsets are retained for resume.
|
||||
operationId: pauseDownload
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
format: uuid
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
ok:
|
||||
type: boolean
|
||||
type: object
|
||||
description: Paused
|
||||
summary: Pause a download
|
||||
tags:
|
||||
- download
|
||||
/api/download/{id}/priority:
|
||||
post:
|
||||
description: Set a download's scheduling priority. Higher priority is admitted first when a slot frees.
|
||||
operationId: setDownloadPriority
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
format: uuid
|
||||
type: string
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
priority:
|
||||
type: integer
|
||||
required:
|
||||
- priority
|
||||
type: object
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
ok:
|
||||
type: boolean
|
||||
type: object
|
||||
description: Priority updated
|
||||
summary: Set download priority
|
||||
tags:
|
||||
- download
|
||||
/api/download/{id}/resume:
|
||||
post:
|
||||
description: Resume a paused (or failed) download from its persisted offsets.
|
||||
operationId: resumeDownload
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
format: uuid
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
ok:
|
||||
type: boolean
|
||||
type: object
|
||||
description: Resumed
|
||||
summary: Resume a download
|
||||
tags:
|
||||
- download
|
||||
/api/embeddings:
|
||||
get:
|
||||
description: Returns the list of text-encoder embeddings available on disk.
|
||||
|
||||
@ -22,7 +22,7 @@ alembic
|
||||
SQLAlchemy>=2.0.0
|
||||
filelock
|
||||
av>=16.0.0
|
||||
comfy-kitchen==0.2.12
|
||||
comfy-kitchen==0.2.13
|
||||
comfy-aimdo==0.4.10
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
@ -33,5 +33,5 @@ kornia>=0.7.1
|
||||
spandrel
|
||||
pydantic~=2.0
|
||||
pydantic-settings~=2.0
|
||||
PyOpenGL
|
||||
glfw
|
||||
PyOpenGL>=3.1.8
|
||||
comfy-angle
|
||||
|
||||
21
server.py
21
server.py
@ -45,6 +45,8 @@ from app.frontend_management import FrontendManager, parse_version
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
from app.assets.seeder import asset_seeder
|
||||
from app.assets.api.routes import register_assets_routes
|
||||
from app.model_downloader.api.routes import register_routes as register_model_downloader_routes
|
||||
from app.model_downloader.manager import DOWNLOAD_MANAGER
|
||||
from app.assets.services.ingest import register_file_in_place
|
||||
from app.assets.services.asset_management import resolve_hash_to_path
|
||||
|
||||
@ -256,6 +258,7 @@ class PromptServer():
|
||||
else:
|
||||
register_assets_routes(self.app)
|
||||
asset_seeder.disable()
|
||||
register_model_downloader_routes(self.app)
|
||||
routes = web.RouteTableDef()
|
||||
self.routes = routes
|
||||
self.last_node_id = None
|
||||
@ -1182,6 +1185,24 @@ class PromptServer():
|
||||
async def setup(self):
|
||||
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
||||
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
||||
await self._setup_model_downloader()
|
||||
|
||||
async def _setup_model_downloader(self):
|
||||
"""Start the download manager: push progress over the websocket and
|
||||
resume any downloads interrupted by a previous run."""
|
||||
def _notify(download_id: str) -> None:
|
||||
try:
|
||||
view = DOWNLOAD_MANAGER.status_sync(download_id)
|
||||
if view is not None:
|
||||
self.send_sync("download_progress", view)
|
||||
except Exception:
|
||||
logging.debug("download progress notify failed", exc_info=True)
|
||||
|
||||
DOWNLOAD_MANAGER.set_notify(_notify)
|
||||
try:
|
||||
await DOWNLOAD_MANAGER.start()
|
||||
except Exception as e:
|
||||
logging.warning("Failed to start model download manager: %s", e)
|
||||
|
||||
def add_routes(self):
|
||||
self.user_manager.add_routes(self.routes)
|
||||
|
||||
@ -1,137 +0,0 @@
|
||||
"""Tests for pre-execution validation that a node is actually executable.
|
||||
|
||||
validate_prompt rejects a node whose declared entry point does not resolve to a
|
||||
real method (a V1 FUNCTION typo, or a V3 node missing its execute override) before
|
||||
any node runs, attributing the error to the offending node.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
import nodes
|
||||
from comfy_api.latest import io
|
||||
from execution import node_not_executable_reason, validate_prompt
|
||||
|
||||
|
||||
class _GoodV1Node:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "run"
|
||||
OUTPUT_NODE = True
|
||||
CATEGORY = "Test"
|
||||
|
||||
def run(self):
|
||||
return (None,)
|
||||
|
||||
|
||||
class _TypoV1Node:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "invert" # method below is misspelled
|
||||
OUTPUT_NODE = True
|
||||
CATEGORY = "Test"
|
||||
|
||||
def invvert(self):
|
||||
return (None,)
|
||||
|
||||
|
||||
class _SideEffectInitV1Node:
|
||||
"""Valid class-level method, but a constructor that must never run in validation."""
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "run"
|
||||
OUTPUT_NODE = True
|
||||
CATEGORY = "Test"
|
||||
|
||||
def __init__(self):
|
||||
raise RuntimeError("__init__ must not run during validation")
|
||||
|
||||
def run(self):
|
||||
return (None,)
|
||||
|
||||
|
||||
def _v3_schema(node_id):
|
||||
return io.Schema(
|
||||
node_id=node_id,
|
||||
display_name=node_id,
|
||||
category="Test",
|
||||
inputs=[],
|
||||
outputs=[io.Image.Output()],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
|
||||
class _GoodV3Node(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return _v3_schema("GoodV3Node")
|
||||
|
||||
@classmethod
|
||||
def execute(cls):
|
||||
return io.NodeOutput(None)
|
||||
|
||||
|
||||
class _TypoV3Node(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return _v3_schema("TypoV3Node")
|
||||
|
||||
@classmethod
|
||||
def exicute(cls): # typo: should be "execute"
|
||||
return io.NodeOutput(None)
|
||||
|
||||
|
||||
def _register(class_type, class_def):
|
||||
nodes.NODE_CLASS_MAPPINGS[class_type] = class_def
|
||||
|
||||
|
||||
def _validate(class_type):
|
||||
prompt = {"1": {"class_type": class_type, "inputs": {}}}
|
||||
return asyncio.run(validate_prompt("pid", prompt, None))
|
||||
|
||||
|
||||
def test_good_node_passes():
|
||||
_register("GoodV1Node", _GoodV1Node)
|
||||
assert node_not_executable_reason(_GoodV1Node, "GoodV1Node") is None
|
||||
valid, _, _, _ = _validate("GoodV1Node")
|
||||
assert valid is True
|
||||
|
||||
|
||||
def test_typo_node_rejected_with_node_error():
|
||||
_register("TypoV1Node", _TypoV1Node)
|
||||
valid, error, _, node_errors = _validate("TypoV1Node")
|
||||
assert valid is False
|
||||
assert error["type"] == "invalid_node_definition"
|
||||
assert node_errors["1"]["class_type"] == "TypoV1Node"
|
||||
assert node_errors["1"]["errors"][0]["type"] == "invalid_node_definition"
|
||||
assert "invert" in node_errors["1"]["errors"][0]["details"]
|
||||
|
||||
|
||||
def test_validation_does_not_instantiate_node():
|
||||
"""A valid node is not constructed during validation, so __init__ never runs."""
|
||||
_register("SideEffectInitV1Node", _SideEffectInitV1Node)
|
||||
assert node_not_executable_reason(_SideEffectInitV1Node, "SideEffectInitV1Node") is None
|
||||
valid, _, _, _ = _validate("SideEffectInitV1Node")
|
||||
assert valid is True
|
||||
|
||||
|
||||
def test_good_v3_node_passes():
|
||||
_register("GoodV3Node", _GoodV3Node)
|
||||
assert node_not_executable_reason(_GoodV3Node, "GoodV3Node") is None
|
||||
valid, _, _, _ = _validate("GoodV3Node")
|
||||
assert valid is True
|
||||
|
||||
|
||||
def test_typo_v3_node_rejected_with_node_error():
|
||||
_register("TypoV3Node", _TypoV3Node)
|
||||
valid, error, _, node_errors = _validate("TypoV3Node")
|
||||
assert valid is False
|
||||
assert error["type"] == "invalid_node_definition"
|
||||
assert node_errors["1"]["errors"][0]["type"] == "invalid_node_definition"
|
||||
63
tests-unit/model_downloader_test/conftest.py
Normal file
63
tests-unit/model_downloader_test/conftest.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""Shared fixtures for the model download manager tests.
|
||||
|
||||
These run in-process (no ComfyUI subprocess): a file-backed SQLite DB is
|
||||
initialized once, a temp model folder is registered with ``folder_paths``, and
|
||||
the shared aiohttp session is reset between tests so each async test gets a
|
||||
session bound to its own event loop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def _init_db():
|
||||
import app.database.db as db
|
||||
from comfy.cli_args import args
|
||||
|
||||
db_path = tempfile.mktemp(suffix="-dlmgr-test.sqlite3")
|
||||
args.database_url = f"sqlite:///{db_path}"
|
||||
db.init_db()
|
||||
yield
|
||||
try:
|
||||
os.remove(db_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_runtime():
|
||||
"""Reset module singletons that hold event-loop-bound or cross-test state."""
|
||||
import app.model_downloader.net.session as ns
|
||||
from app.model_downloader.scheduler import SCHEDULER
|
||||
|
||||
ns._session = None
|
||||
SCHEDULER._jobs.clear()
|
||||
SCHEDULER._tasks.clear()
|
||||
SCHEDULER._backoff_until.clear()
|
||||
SCHEDULER._started = False
|
||||
yield
|
||||
ns._session = None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_root(tmp_path):
|
||||
"""Register a temp 'loras' model folder and return its absolute path."""
|
||||
import folder_paths
|
||||
|
||||
root = tmp_path / "loras"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
saved = folder_paths.folder_names_and_paths.get("loras")
|
||||
folder_paths.folder_names_and_paths["loras"] = (
|
||||
[str(root)],
|
||||
{".safetensors", ".sft", ".ckpt", ".pt", ".pth"},
|
||||
)
|
||||
yield str(root)
|
||||
if saved is not None:
|
||||
folder_paths.folder_names_and_paths["loras"] = saved
|
||||
else:
|
||||
folder_paths.folder_names_and_paths.pop("loras", None)
|
||||
110
tests-unit/model_downloader_test/test_credentials.py
Normal file
110
tests-unit/model_downloader_test/test_credentials.py
Normal file
@ -0,0 +1,110 @@
|
||||
"""Unit tests for the credential store and the per-hop credential resolver.
|
||||
|
||||
Covers the critical rule: a secret is only ever attached when the current
|
||||
hop's host matches a stored credential, and never over a non-https hop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from app.model_downloader.credentials import resolver
|
||||
from app.model_downloader.credentials.store import (
|
||||
CREDENTIAL_STORE,
|
||||
CredentialValidationError,
|
||||
normalize_host,
|
||||
)
|
||||
from app.model_downloader.database.models import HostCredential
|
||||
|
||||
|
||||
# ----- pure host normalization + matching -----
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"raw,expected",
|
||||
[
|
||||
("Civitai.com", "civitai.com"),
|
||||
("HuggingFace.co:443", "huggingface.co"),
|
||||
(" Example.COM ", "example.com"),
|
||||
],
|
||||
)
|
||||
def test_normalize_host(raw, expected):
|
||||
assert normalize_host(raw) == expected
|
||||
|
||||
|
||||
def _cred(**kw) -> HostCredential:
|
||||
base = dict(
|
||||
id="x", host="civitai.com", match_subdomains=False, auth_scheme="bearer",
|
||||
secret="SECRET", enabled=True,
|
||||
)
|
||||
base.update(kw)
|
||||
return HostCredential(**base)
|
||||
|
||||
|
||||
def test_matches_exact_only_by_default():
|
||||
c = _cred(host="civitai.com")
|
||||
assert resolver._matches(c, "civitai.com") is True
|
||||
assert resolver._matches(c, "api.civitai.com") is False
|
||||
assert resolver._matches(c, "evil-civitai.com") is False
|
||||
|
||||
|
||||
def test_matches_subdomain_label_boundary():
|
||||
c = _cred(host="example.com", match_subdomains=True)
|
||||
assert resolver._matches(c, "api.example.com") is True
|
||||
assert resolver._matches(c, "example.com") is True
|
||||
# not a label boundary -> no match
|
||||
assert resolver._matches(c, "evil-example.com") is False
|
||||
|
||||
|
||||
def test_build_auth_shapes():
|
||||
assert resolver._build_auth(_cred(auth_scheme="bearer")).headers == {
|
||||
"Authorization": "Bearer SECRET"
|
||||
}
|
||||
assert resolver._build_auth(
|
||||
_cred(auth_scheme="header", header_name="X-Api-Key")
|
||||
).headers == {"X-Api-Key": "SECRET"}
|
||||
q = resolver._build_auth(_cred(auth_scheme="query", query_param="token"))
|
||||
assert q.query == {"token": "SECRET"}
|
||||
assert q.apply_to_url("https://civitai.com/x") == "https://civitai.com/x?token=SECRET"
|
||||
|
||||
|
||||
# ----- DB-backed store + resolver -----
|
||||
|
||||
|
||||
def test_store_upsert_is_write_only_and_masked():
|
||||
async def _run():
|
||||
view = await CREDENTIAL_STORE.upsert("civitai.com", "abcd1234", label="my key")
|
||||
# The view never carries the secret, only the last 4.
|
||||
assert not hasattr(view, "secret")
|
||||
assert view.secret_last4 == "1234"
|
||||
assert view.host == "civitai.com"
|
||||
listed = await CREDENTIAL_STORE.list()
|
||||
assert any(v.host == "civitai.com" for v in listed)
|
||||
await CREDENTIAL_STORE.delete(view.id)
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_query_scheme_requires_param():
|
||||
async def _run():
|
||||
with pytest.raises(CredentialValidationError):
|
||||
await CREDENTIAL_STORE.upsert("civitai.com", "k", auth_scheme="query")
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_resolver_never_crosses_host_boundary():
|
||||
async def _run():
|
||||
view = await CREDENTIAL_STORE.upsert("huggingface.co", "hf_secret_key")
|
||||
try:
|
||||
# matching host over https -> attached
|
||||
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
|
||||
assert auth is not None
|
||||
assert auth.headers["Authorization"] == "Bearer hf_secret_key"
|
||||
# CDN redirect host -> dropped
|
||||
assert await resolver.resolve_auth_for_hop("cdn-lfs.huggingface.co", "https") is None
|
||||
# non-https hop -> never attached
|
||||
assert await resolver.resolve_auth_for_hop("huggingface.co", "http") is None
|
||||
finally:
|
||||
await CREDENTIAL_STORE.delete(view.id)
|
||||
asyncio.run(_run())
|
||||
270
tests-unit/model_downloader_test/test_engine_integration.py
Normal file
270
tests-unit/model_downloader_test/test_engine_integration.py
Normal file
@ -0,0 +1,270 @@
|
||||
"""Integration tests for the download engine against a local aiohttp server.
|
||||
|
||||
Covers single-stream and segmented transfers, deterministic resume from a
|
||||
partial file, and cancel rollback. Async tests are driven via ``asyncio.run``
|
||||
so no pytest-asyncio plugin is required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from comfy.cli_args import args
|
||||
from app.model_downloader.constants import DownloadStatus
|
||||
from app.model_downloader.database import queries
|
||||
from app.model_downloader.engine.job import DownloadJob, JobSpec
|
||||
from app.model_downloader.net.session import close_session
|
||||
from app.model_downloader.security import paths
|
||||
|
||||
PAYLOAD_ETAG = '"v1"'
|
||||
|
||||
|
||||
def _payload(n: int) -> bytes:
|
||||
return bytes((i * 37 + 11) % 256 for i in range(n))
|
||||
|
||||
|
||||
def _range_handler(payload: bytes):
|
||||
async def handler(request: web.Request) -> web.Response:
|
||||
rng = request.headers.get("Range")
|
||||
if rng:
|
||||
spec = rng.split("=", 1)[1]
|
||||
s, _, e = spec.partition("-")
|
||||
start = int(s)
|
||||
end = int(e) if e else len(payload) - 1
|
||||
chunk = payload[start : end + 1]
|
||||
return web.Response(
|
||||
status=206,
|
||||
body=chunk,
|
||||
headers={
|
||||
"Content-Range": f"bytes {start}-{end}/{len(payload)}",
|
||||
"Accept-Ranges": "bytes",
|
||||
"ETag": PAYLOAD_ETAG,
|
||||
},
|
||||
)
|
||||
return web.Response(
|
||||
status=200, body=payload, headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG}
|
||||
)
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def _noranges_handler(payload: bytes):
|
||||
async def handler(request: web.Request) -> web.Response:
|
||||
# Always full body, never advertises Accept-Ranges -> single-stream.
|
||||
return web.Response(status=200, body=payload)
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def _slow_handler(payload: bytes, chunk: int = 16384, delay: float = 0.01):
|
||||
async def handler(request: web.Request) -> web.StreamResponse:
|
||||
resp = web.StreamResponse(
|
||||
status=200, headers={"Content-Length": str(len(payload))}
|
||||
)
|
||||
await resp.prepare(request)
|
||||
for i in range(0, len(payload), chunk):
|
||||
await resp.write(payload[i : i + chunk])
|
||||
await asyncio.sleep(delay)
|
||||
await resp.write_eof()
|
||||
return resp
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
async def _serve(handler):
|
||||
app = web.Application()
|
||||
app.router.add_route("*", "/{name:.*}", handler)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "127.0.0.1", 0)
|
||||
await site.start()
|
||||
port = site._server.sockets[0].getsockname()[1]
|
||||
return runner, port
|
||||
|
||||
|
||||
def _insert(model_id: str, url: str, status: str = DownloadStatus.QUEUED) -> tuple[str, str, str]:
|
||||
final_path, temp_path = paths.resolve_destination(model_id)
|
||||
download_id = str(uuid.uuid4())
|
||||
queries.insert_download(
|
||||
{
|
||||
"id": download_id,
|
||||
"url": url,
|
||||
"model_id": model_id,
|
||||
"dest_path": final_path,
|
||||
"temp_path": temp_path,
|
||||
"status": status,
|
||||
}
|
||||
)
|
||||
return download_id, final_path, temp_path
|
||||
|
||||
|
||||
# ----- single-stream -----
|
||||
|
||||
|
||||
def test_single_stream_download(model_root):
|
||||
payload = _payload(300_000)
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_noranges_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, _temp = _insert("loras/single.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/single.safetensors",
|
||||
dest_path=final_path, temp_path=_temp,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.COMPLETED, queries.get_download(did).error
|
||||
assert os.path.exists(final_path)
|
||||
assert open(final_path, "rb").read() == payload
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
# ----- segmented -----
|
||||
|
||||
|
||||
def test_segmented_download(model_root):
|
||||
payload = _payload(4 * 1024 * 1024) # 4 MiB -> multiple segments
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_range_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/seg.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/seg.safetensors",
|
||||
dest_path=final_path, temp_path=temp,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.COMPLETED, queries.get_download(did).error
|
||||
assert open(final_path, "rb").read() == payload
|
||||
# More than one segment row was planned.
|
||||
assert len(queries.list_segments(did)) > 1
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
# ----- deterministic resume from a partial file -----
|
||||
|
||||
|
||||
def test_resume_from_partial(model_root):
|
||||
payload = _payload(512 * 1024) # < 1 MiB -> single segment, but ranges work
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_range_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/resume.safetensors", url)
|
||||
# Simulate a prior partial: first 200 KiB already written, offset persisted.
|
||||
prefix = 200 * 1024
|
||||
os.makedirs(os.path.dirname(temp), exist_ok=True)
|
||||
with open(temp, "wb") as f:
|
||||
f.write(payload[:prefix])
|
||||
queries.update_download(did, bytes_done=prefix, etag=PAYLOAD_ETAG)
|
||||
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/resume.safetensors",
|
||||
dest_path=final_path, temp_path=temp, etag=PAYLOAD_ETAG,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.COMPLETED, queries.get_download(did).error
|
||||
assert open(final_path, "rb").read() == payload
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
# ----- cancel rollback -----
|
||||
|
||||
|
||||
def test_cancel_rollback(model_root, monkeypatch):
|
||||
monkeypatch.setattr(args, "download_chunk_size", 16384, raising=False)
|
||||
payload = _payload(1024 * 1024)
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_slow_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/cancel.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/cancel.safetensors",
|
||||
dest_path=final_path, temp_path=temp,
|
||||
))
|
||||
task = asyncio.ensure_future(job.run())
|
||||
# Wait until some bytes have been written, then cancel.
|
||||
for _ in range(200):
|
||||
await asyncio.sleep(0.01)
|
||||
if job.state.bytes_done > 0:
|
||||
break
|
||||
job.request_cancel()
|
||||
status = await task
|
||||
assert status == DownloadStatus.CANCELLED
|
||||
assert not os.path.exists(temp)
|
||||
assert not os.path.exists(final_path)
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
# ----- manager + scheduler end-to-end -----
|
||||
|
||||
|
||||
def test_manager_enqueue_to_completion(model_root):
|
||||
payload = _payload(2 * 1024 * 1024)
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
from app.model_downloader.manager import DOWNLOAD_MANAGER
|
||||
|
||||
runner, port = await _serve(_range_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did = await DOWNLOAD_MANAGER.enqueue(url, "loras/e2e.safetensors")
|
||||
# Wait for completion.
|
||||
final_path, _ = paths.resolve_destination("loras/e2e.safetensors")
|
||||
for _ in range(500):
|
||||
await asyncio.sleep(0.02)
|
||||
row = queries.get_download(did)
|
||||
if row.status in DownloadStatus.TERMINAL:
|
||||
break
|
||||
row = queries.get_download(did)
|
||||
assert row.status == DownloadStatus.COMPLETED, row.error
|
||||
assert open(final_path, "rb").read() == payload
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_manager_rejects_disallowed_url(model_root):
|
||||
async def _run():
|
||||
from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError
|
||||
|
||||
with pytest.raises(DownloadError) as ei:
|
||||
await DOWNLOAD_MANAGER.enqueue(
|
||||
"https://evil.example.com/x.safetensors", "loras/bad.safetensors"
|
||||
)
|
||||
assert ei.value.code == "URL_NOT_ALLOWED"
|
||||
|
||||
asyncio.run(_run())
|
||||
71
tests-unit/model_downloader_test/test_planner_structural.py
Normal file
71
tests-unit/model_downloader_test/test_planner_structural.py
Normal file
@ -0,0 +1,71 @@
|
||||
"""Unit tests for the segment planner and structural safetensors validation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import struct
|
||||
|
||||
import pytest
|
||||
|
||||
from app.model_downloader.engine.planner import (
|
||||
effective_segment_count,
|
||||
plan_segments,
|
||||
)
|
||||
from app.model_downloader.verify import structural
|
||||
|
||||
|
||||
# ----- planner -----
|
||||
|
||||
|
||||
def test_plan_segments_covers_full_range_contiguously():
|
||||
total = 1000
|
||||
plans = plan_segments(total, 4)
|
||||
assert len(plans) == 4
|
||||
assert plans[0].start == 0
|
||||
assert plans[-1].end == total - 1
|
||||
# contiguous, no gaps/overlaps
|
||||
for a, b in zip(plans, plans[1:]):
|
||||
assert b.start == a.end + 1
|
||||
assert sum(p.length for p in plans) == total
|
||||
|
||||
|
||||
def test_effective_segment_count_falls_back_to_single():
|
||||
# No range support -> single
|
||||
assert effective_segment_count(10_000_000, False, 8) == 1
|
||||
# Unknown size -> single
|
||||
assert effective_segment_count(None, True, 8) == 1
|
||||
# Tiny file -> fewer segments than configured
|
||||
assert effective_segment_count(1024, True, 8) == 1
|
||||
# Large file with range support -> configured count
|
||||
assert effective_segment_count(1_000_000_000, True, 8) == 8
|
||||
|
||||
|
||||
# ----- structural -----
|
||||
|
||||
|
||||
def _make_safetensors(tensor_data_len: int, *, corrupt_size: bool = False) -> bytes:
|
||||
header = {"t": {"dtype": "F32", "shape": [tensor_data_len], "data_offsets": [0, tensor_data_len]}}
|
||||
header_bytes = json.dumps(header).encode("utf-8")
|
||||
body = b"\x00" * tensor_data_len
|
||||
if corrupt_size:
|
||||
body = body[:-1] # truncate one byte
|
||||
return struct.pack("<Q", len(header_bytes)) + header_bytes + body
|
||||
|
||||
|
||||
def test_structural_valid_safetensors(tmp_path):
|
||||
p = tmp_path / "ok.safetensors"
|
||||
p.write_bytes(_make_safetensors(256))
|
||||
structural.validate(str(p)) # no raise
|
||||
|
||||
|
||||
def test_structural_detects_truncation(tmp_path):
|
||||
p = tmp_path / "bad.safetensors"
|
||||
p.write_bytes(_make_safetensors(256, corrupt_size=True))
|
||||
with pytest.raises(structural.StructuralError):
|
||||
structural.validate(str(p))
|
||||
|
||||
|
||||
def test_structural_skips_unknown_extension(tmp_path):
|
||||
p = tmp_path / "weights.bin"
|
||||
p.write_bytes(b"anything")
|
||||
structural.validate(str(p)) # no structural check, no raise
|
||||
111
tests-unit/model_downloader_test/test_security.py
Normal file
111
tests-unit/model_downloader_test/test_security.py
Normal file
@ -0,0 +1,111 @@
|
||||
"""Unit tests for the security layer: allowlist, SSRF checks, path safety."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.model_downloader.security import allowlist, paths
|
||||
from app.model_downloader.security.ssrf import (
|
||||
SSRFError,
|
||||
check_redirect_hop,
|
||||
is_blocked_ip,
|
||||
)
|
||||
|
||||
|
||||
# ----- allowlist -----
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"url,allowed",
|
||||
[
|
||||
("https://huggingface.co/org/repo/resolve/main/model.safetensors", True),
|
||||
("https://civitai.com/api/download/x/model.safetensors", True),
|
||||
("http://localhost/model.safetensors", True),
|
||||
# off-list host
|
||||
("https://evil.example.com/model.safetensors", False),
|
||||
# http to a non-loopback allowlisted host is not permitted (https only)
|
||||
("http://huggingface.co/org/repo/resolve/main/model.safetensors", False),
|
||||
# bad extension on an allowed host
|
||||
("https://huggingface.co/org/repo/resolve/main/config.json", False),
|
||||
# userinfo trick: real host is the metadata IP, not 127.0.0.1
|
||||
("http://127.0.0.1@169.254.169.254/x.safetensors", False),
|
||||
],
|
||||
)
|
||||
def test_is_url_allowed(url, allowed):
|
||||
assert allowlist.is_url_allowed(url) is allowed
|
||||
|
||||
|
||||
def test_allow_any_extension_relaxes_extension_only():
|
||||
url = "https://huggingface.co/org/repo/resolve/main/weights.bin"
|
||||
assert allowlist.is_url_allowed(url) is True # .bin is in the known set
|
||||
odd = "https://huggingface.co/org/repo/resolve/main/weights.zip"
|
||||
assert allowlist.is_url_allowed(odd) is False
|
||||
assert allowlist.is_url_allowed(odd, allow_any_extension=True) is True
|
||||
|
||||
|
||||
# ----- SSRF: blocked IPs -----
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ip,blocked",
|
||||
[
|
||||
("169.254.169.254", True), # cloud metadata / link-local
|
||||
("127.0.0.1", True),
|
||||
("10.0.0.5", True),
|
||||
("192.168.1.1", True),
|
||||
("172.16.0.1", True),
|
||||
("::1", True),
|
||||
("0.0.0.0", True),
|
||||
("8.8.8.8", False),
|
||||
("1.1.1.1", False),
|
||||
("not-an-ip", True), # unparseable -> refuse
|
||||
],
|
||||
)
|
||||
def test_is_blocked_ip(ip, blocked):
|
||||
assert is_blocked_ip(ip) is blocked
|
||||
|
||||
|
||||
# ----- SSRF: redirect hop validation -----
|
||||
|
||||
|
||||
def test_check_redirect_hop_rejects_bad_scheme_and_userinfo():
|
||||
with pytest.raises(SSRFError):
|
||||
check_redirect_hop("ftp://huggingface.co/x.safetensors")
|
||||
with pytest.raises(SSRFError):
|
||||
check_redirect_hop("https://user:pass@cdn.example.com/x")
|
||||
# A CDN host that is NOT on the allowlist is allowed as a redirect target
|
||||
# (private-IP protection is the resolver's job; credential leak is prevented
|
||||
# by exact host matching).
|
||||
assert check_redirect_hop("https://cdn-lfs.huggingface.co/abc") is not None
|
||||
|
||||
|
||||
# ----- path safety -----
|
||||
|
||||
|
||||
def test_parse_model_id_valid(model_root):
|
||||
directory, filename = paths.parse_model_id("loras/my_lora.safetensors")
|
||||
assert directory == "loras"
|
||||
assert filename == "my_lora.safetensors"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_id",
|
||||
[
|
||||
"loras/../etc/passwd.safetensors", # traversal
|
||||
"loras/sub/dir.safetensors", # nested
|
||||
"unknownfolder/x.safetensors", # unknown folder
|
||||
"loras/model.txt", # bad extension
|
||||
"noslash.safetensors", # missing directory
|
||||
"loras/", # empty filename
|
||||
],
|
||||
)
|
||||
def test_parse_model_id_rejects(model_root, model_id):
|
||||
with pytest.raises(paths.InvalidModelId):
|
||||
paths.parse_model_id(model_id)
|
||||
|
||||
|
||||
def test_resolve_destination_stays_in_root(model_root):
|
||||
final_path, temp_path = paths.resolve_destination("loras/x.safetensors")
|
||||
assert final_path.startswith(model_root)
|
||||
assert temp_path.startswith(model_root)
|
||||
assert temp_path != final_path
|
||||
Reference in New Issue
Block a user