Compare commits

..

3 Commits

Author SHA1 Message Date
8d18675e75 Add initial commit for model downloader. 2026-06-27 13:10:05 +02:00
603d891eaf Update GLSL node to use ANGLE library (CORE-162) (#13195) 2026-06-27 08:40:31 +08:00
470ac36a0a Fix int8 loras causing lower quality requant with wrong settings. (#14650)
* Update comfy-kitchen

* Support requantizing with same settings as orig quant.
2026-06-26 16:41:29 -07:00
38 changed files with 4203 additions and 547 deletions

View File

@ -0,0 +1,118 @@
"""
Download manager schema.
Adds the three tables that back the server-side model download manager
(PRD section 7): transient job/queue state (``downloads`` + per-segment
``download_segments``) and one-API-key-per-host auth (``host_credentials``).
The local file catalog / dedup index is intentionally NOT added here — it
is owned by the assets system (``assets`` / ``asset_references``).
Revision ID: 0005_download_manager
Revises: 0004_drop_tag_type
Create Date: 2026-06-27
"""
from alembic import op
import sqlalchemy as sa
revision = "0005_download_manager"
down_revision = "0004_drop_tag_type"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"downloads",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("url", sa.Text(), nullable=False),
sa.Column("final_url", sa.Text(), nullable=True),
sa.Column("model_id", sa.String(length=1024), nullable=False),
sa.Column("dest_path", sa.Text(), nullable=False),
sa.Column("temp_path", sa.Text(), nullable=False),
sa.Column("status", sa.String(length=16), nullable=False),
sa.Column("priority", sa.Integer(), nullable=False, server_default="0"),
sa.Column("total_bytes", sa.BigInteger(), nullable=True),
sa.Column("bytes_done", sa.BigInteger(), nullable=False, server_default="0"),
sa.Column("etag", sa.String(length=512), nullable=True),
sa.Column("last_modified", sa.String(length=128), nullable=True),
sa.Column(
"accept_ranges", sa.Boolean(), nullable=False, server_default=sa.text("false")
),
sa.Column("expected_sha256", sa.String(length=64), nullable=True),
sa.Column("credential_id", sa.String(length=36), nullable=True),
sa.Column(
"allow_any_extension",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
sa.Column("attempts", sa.Integer(), nullable=False, server_default="0"),
sa.Column("error", sa.Text(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
sa.CheckConstraint("bytes_done >= 0", name="ck_downloads_bytes_done_nonneg"),
sa.CheckConstraint(
"total_bytes IS NULL OR total_bytes >= 0",
name="ck_downloads_total_bytes_nonneg",
),
)
op.create_index("ix_downloads_status", "downloads", ["status"])
op.create_index("ix_downloads_priority", "downloads", ["priority"])
op.create_index("ix_downloads_model_id", "downloads", ["model_id"])
op.create_table(
"download_segments",
sa.Column(
"download_id",
sa.String(length=36),
sa.ForeignKey("downloads.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("idx", sa.Integer(), nullable=False),
sa.Column("start_offset", sa.BigInteger(), nullable=False),
sa.Column("end_offset", sa.BigInteger(), nullable=False),
sa.Column("bytes_done", sa.BigInteger(), nullable=False, server_default="0"),
sa.PrimaryKeyConstraint("download_id", "idx", name="pk_download_segments"),
sa.CheckConstraint("bytes_done >= 0", name="ck_segments_bytes_done_nonneg"),
sa.CheckConstraint("end_offset >= start_offset", name="ck_segments_range"),
)
op.create_table(
"host_credentials",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("host", sa.String(length=255), nullable=False),
sa.Column(
"match_subdomains",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
sa.Column("label", sa.String(length=255), nullable=True),
sa.Column(
"auth_scheme", sa.String(length=16), nullable=False, server_default="bearer"
),
sa.Column("header_name", sa.String(length=255), nullable=True),
sa.Column("query_param", sa.String(length=255), nullable=True),
sa.Column("secret", sa.Text(), nullable=False),
sa.Column("secret_last4", sa.String(length=4), nullable=True),
sa.Column("enabled", sa.Boolean(), nullable=False, server_default=sa.text("true")),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
)
op.create_index(
"uq_host_credentials_host", "host_credentials", ["host"], unique=True
)
def downgrade() -> None:
op.drop_index("uq_host_credentials_host", table_name="host_credentials")
op.drop_table("host_credentials")
op.drop_table("download_segments")
op.drop_index("ix_downloads_model_id", table_name="downloads")
op.drop_index("ix_downloads_priority", table_name="downloads")
op.drop_index("ix_downloads_status", table_name="downloads")
op.drop_table("downloads")

View File

@ -21,6 +21,7 @@ try:
from app.database.models import Base
import app.assets.database.models # noqa: F401 — register models with Base.metadata
import app.model_downloader.database.models # noqa: F401 — register models with Base.metadata
_DB_AVAILABLE = True
except ImportError as e:

View File

@ -0,0 +1,203 @@
"""aiohttp routes for the download manager.
Endpoint surface (all under ``/api/download``), mirroring the response
envelope used by ``app/assets/api/routes.py``:
POST /api/download/enqueue
GET /api/download
POST /api/download/availability
POST /api/download/credentials
GET /api/download/credentials
GET /api/download/credentials/{id}
DELETE /api/download/credentials/{id}
GET /api/download/{id}
POST /api/download/{id}/pause
POST /api/download/{id}/resume
POST /api/download/{id}/cancel
POST /api/download/{id}/priority
Note on ordering: the static ``credentials`` routes are registered before the
dynamic ``/api/download/{id}`` route so a request to ``.../credentials`` is not
captured as ``id == "credentials"``.
"""
from __future__ import annotations
import json
from aiohttp import web
from pydantic import BaseModel, ValidationError
from app.model_downloader.api import schemas_in, schemas_out
from app.model_downloader.credentials.store import (
CREDENTIAL_STORE,
CredentialValidationError,
)
from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError
ROUTES = web.RouteTableDef()
def register_routes(app: web.Application) -> None:
"""Wire the download-manager routes into the running aiohttp app."""
app.add_routes(ROUTES)
# ----- envelope helpers (same shape as app/assets/api/routes.py) -----
def _error(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
return web.json_response(
{"error": {"code": code, "message": message, "details": details or {}}},
status=status,
)
def _ok(payload, status: int = 200) -> web.Response:
return web.json_response(payload, status=status)
async def _parse(request: web.Request, model: type[BaseModel]):
try:
raw = await request.json()
except json.JSONDecodeError:
return _error(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
return model.model_validate(raw)
except ValidationError as ve:
return _error(400, "INVALID_BODY", "Validation failed.", {"errors": json.loads(ve.json())})
def _from_download_error(e: DownloadError) -> web.Response:
return _error(e.http_status, e.code, e.message)
# ----- downloads: collection + enqueue + availability -----
@ROUTES.post("/api/download/enqueue")
async def enqueue(request: web.Request) -> web.Response:
parsed = await _parse(request, schemas_in.EnqueueRequest)
if isinstance(parsed, web.Response):
return parsed
try:
download_id = await DOWNLOAD_MANAGER.enqueue(
parsed.url,
parsed.model_id,
priority=parsed.priority,
expected_sha256=parsed.expected_sha256,
allow_any_extension=parsed.allow_any_extension,
credential_id=parsed.credential_id,
)
except DownloadError as e:
return _from_download_error(e)
return _ok({"download_id": download_id, "accepted": True}, status=202)
@ROUTES.get("/api/download")
async def list_downloads(request: web.Request) -> web.Response:
return _ok({"downloads": await DOWNLOAD_MANAGER.list()})
@ROUTES.post("/api/download/availability")
async def availability(request: web.Request) -> web.Response:
parsed = await _parse(request, schemas_in.AvailabilityRequest)
if isinstance(parsed, web.Response):
return parsed
return _ok({"models": await DOWNLOAD_MANAGER.availability(parsed.models)})
# ----- credentials (secrets are write-only) — must precede /{id} -----
@ROUTES.post("/api/download/credentials")
async def upsert_credential(request: web.Request) -> web.Response:
parsed = await _parse(request, schemas_in.CredentialUpsertRequest)
if isinstance(parsed, web.Response):
return parsed
try:
view = await CREDENTIAL_STORE.upsert(
parsed.host,
parsed.secret,
auth_scheme=parsed.auth_scheme,
header_name=parsed.header_name,
query_param=parsed.query_param,
label=parsed.label,
match_subdomains=parsed.match_subdomains,
enabled=parsed.enabled,
)
except CredentialValidationError as e:
return _error(400, "INVALID_CREDENTIAL", str(e))
return _ok(schemas_out.credential_to_dict(view), status=201)
@ROUTES.get("/api/download/credentials")
async def list_credentials(request: web.Request) -> web.Response:
views = await CREDENTIAL_STORE.list()
return _ok({"credentials": [schemas_out.credential_to_dict(v) for v in views]})
@ROUTES.get("/api/download/credentials/{id}")
async def get_credential(request: web.Request) -> web.Response:
view = await CREDENTIAL_STORE.get(request.match_info["id"])
if view is None:
return _error(404, "NOT_FOUND", "No such credential.")
return _ok(schemas_out.credential_to_dict(view))
@ROUTES.delete("/api/download/credentials/{id}")
async def delete_credential(request: web.Request) -> web.Response:
deleted = await CREDENTIAL_STORE.delete(request.match_info["id"])
if not deleted:
return _error(404, "NOT_FOUND", "No such credential.")
return _ok({"deleted": True})
# ----- single download by id (dynamic; registered last) -----
@ROUTES.get("/api/download/{id}")
async def get_download(request: web.Request) -> web.Response:
view = await DOWNLOAD_MANAGER.status(request.match_info["id"])
if view is None:
return _error(404, "NOT_FOUND", "No such download.")
return _ok(view)
@ROUTES.post("/api/download/{id}/pause")
async def pause(request: web.Request) -> web.Response:
try:
await DOWNLOAD_MANAGER.pause(request.match_info["id"])
except DownloadError as e:
return _from_download_error(e)
return _ok({"ok": True})
@ROUTES.post("/api/download/{id}/resume")
async def resume(request: web.Request) -> web.Response:
try:
await DOWNLOAD_MANAGER.resume(request.match_info["id"])
except DownloadError as e:
return _from_download_error(e)
return _ok({"ok": True})
@ROUTES.post("/api/download/{id}/cancel")
async def cancel(request: web.Request) -> web.Response:
try:
await DOWNLOAD_MANAGER.cancel(request.match_info["id"])
except DownloadError as e:
return _from_download_error(e)
return _ok({"ok": True})
@ROUTES.post("/api/download/{id}/priority")
async def set_priority(request: web.Request) -> web.Response:
parsed = await _parse(request, schemas_in.PriorityRequest)
if isinstance(parsed, web.Response):
return parsed
try:
await DOWNLOAD_MANAGER.set_priority(request.match_info["id"], parsed.priority)
except DownloadError as e:
return _from_download_error(e)
return _ok({"ok": True})

View File

@ -0,0 +1,51 @@
"""Request schemas for the download manager API.
Pydantic enforces shape at the boundary; handlers operate only on validated
values past that point.
"""
from __future__ import annotations
from typing import Optional
from pydantic import BaseModel, Field
from app.model_downloader.constants import AUTH_SCHEME_BEARER
class EnqueueRequest(BaseModel):
url: str
model_id: str
priority: int = 0
expected_sha256: Optional[str] = None
allow_any_extension: bool = False
credential_id: Optional[str] = None
class PriorityRequest(BaseModel):
priority: int
class AvailabilityRequest(BaseModel):
"""``{model_id: url}`` — the URLs declared in the workflow JSON."""
models: dict[str, str] = Field(default_factory=dict)
class CredentialUpsertRequest(BaseModel):
host: str
secret: str
auth_scheme: str = AUTH_SCHEME_BEARER
header_name: Optional[str] = None
query_param: Optional[str] = None
label: Optional[str] = None
match_subdomains: bool = False
enabled: bool = True
__all__ = [
"EnqueueRequest",
"PriorityRequest",
"AvailabilityRequest",
"CredentialUpsertRequest",
]

View File

@ -0,0 +1,26 @@
"""Response helpers for the download manager API.
The download/status read models are plain dicts produced by the manager. This
module only needs to mask credentials for output (the secret is never returned).
"""
from __future__ import annotations
from app.model_downloader.credentials.store import CredentialView
def credential_to_dict(view: CredentialView) -> dict:
"""API-safe credential representation — never includes the secret."""
return {
"id": view.id,
"host": view.host,
"auth_scheme": view.auth_scheme,
"header_name": view.header_name,
"query_param": view.query_param,
"label": view.label,
"match_subdomains": view.match_subdomains,
"enabled": view.enabled,
"secret_last4": view.secret_last4,
"created_at": view.created_at,
"updated_at": view.updated_at,
}

View File

@ -0,0 +1,38 @@
"""Shared constants for the download manager.
Status values are persisted as TEXT in the ``downloads`` table; keep them
stable. The lifecycle is (PRD section 6):
queued -> active -> verifying -> completed
| |-> paused -> (resume) -> active
| |-> failed (network, retryable) -> queued (backoff)
|-> cancelled
"""
from __future__ import annotations
# Auth schemes for HostCredential (PRD section 9.4.1).
AUTH_SCHEME_BEARER = "bearer"
AUTH_SCHEME_HEADER = "header"
AUTH_SCHEME_QUERY = "query"
AUTH_SCHEMES = (AUTH_SCHEME_BEARER, AUTH_SCHEME_HEADER, AUTH_SCHEME_QUERY)
class DownloadStatus:
QUEUED = "queued"
ACTIVE = "active"
PAUSED = "paused"
VERIFYING = "verifying"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
#: States from which a worker is doing (or about to do) network I/O.
LIVE = (QUEUED, ACTIVE, VERIFYING)
#: Terminal states — the job will not transition again on its own.
TERMINAL = (COMPLETED, FAILED, CANCELLED)
# Default temp-file suffix. Distinctive so the startup orphan sweep only
# removes files THIS subsystem created, never unrelated *.tmp files.
TMP_SUFFIX = ".comfy-download.part"

View File

@ -0,0 +1,99 @@
"""Turn a stored credential into a per-hop request modifier (PRD section 9.4.2).
The critical rule: a credential is only ever attached when *the current hop's
host* matches a stored credential, and only over https. This is recomputed
from scratch on every redirect hop, so a token bound to ``huggingface.co`` is
silently dropped when the request is redirected to a presigned CDN host —
which is exactly what these hubs expect.
"""
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from typing import Optional
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
from app.model_downloader.constants import (
AUTH_SCHEME_BEARER,
AUTH_SCHEME_HEADER,
AUTH_SCHEME_QUERY,
)
from app.model_downloader.credentials.store import normalize_host
from app.model_downloader.database import queries
from app.model_downloader.database.models import HostCredential
@dataclass
class RequestAuth:
"""How to modify a single request to carry a credential."""
headers: dict[str, str] = field(default_factory=dict)
query: dict[str, str] = field(default_factory=dict)
def apply_to_url(self, url: str) -> str:
if not self.query:
return url
parts = urlsplit(url)
params = dict(parse_qsl(parts.query, keep_blank_values=True))
params.update(self.query)
return urlunsplit(parts._replace(query=urlencode(params)))
def _matches(cred: HostCredential, hop_host: str) -> bool:
cred_host = cred.host
if hop_host == cred_host:
return True
if cred.match_subdomains:
# Label-boundary suffix: api.example.com matches example.com, but
# evil-example.com does NOT.
return hop_host.endswith("." + cred_host)
return False
def _build_auth(cred: HostCredential) -> RequestAuth:
if cred.auth_scheme == AUTH_SCHEME_BEARER:
return RequestAuth(headers={"Authorization": f"Bearer {cred.secret}"})
if cred.auth_scheme == AUTH_SCHEME_HEADER:
name = cred.header_name or "Authorization"
return RequestAuth(headers={name: cred.secret})
if cred.auth_scheme == AUTH_SCHEME_QUERY and cred.query_param:
return RequestAuth(query={cred.query_param: cred.secret})
return RequestAuth()
def _resolve_sync(
host: str, scheme: str, explicit_credential_id: Optional[str]
) -> Optional[RequestAuth]:
# Never attach a secret over a non-https hop (PRD section 9.4.2).
if scheme.lower() != "https":
return None
hop_host = normalize_host(host)
if not hop_host:
return None
if explicit_credential_id is not None:
cred = queries.get_credential(explicit_credential_id)
# An explicit credential is still subject to the per-hop host check —
# it is not forced onto a non-matching host.
if cred is None or not cred.enabled or not _matches(cred, hop_host):
return None
return _build_auth(cred)
# Auto-resolve: exact host first, then any subdomain-matching credential.
cred = queries.get_credential_by_host(hop_host)
if cred is not None and cred.enabled:
return _build_auth(cred)
for sub in queries.list_subdomain_credentials():
if sub.enabled and _matches(sub, hop_host):
return _build_auth(sub)
return None
async def resolve_auth_for_hop(
host: str, scheme: str, *, explicit_credential_id: Optional[str] = None
) -> Optional[RequestAuth]:
"""Resolve the credential (if any) to attach for one request hop."""
return await asyncio.to_thread(
_resolve_sync, host, scheme, explicit_credential_id
)

View File

@ -0,0 +1,137 @@
"""The credential store: one API key per host (PRD section 9.4).
Secrets are write-only over the API — :class:`CredentialView` carries only
masked metadata (``secret_last4`` + scheme + label), never the secret itself.
At-rest protection for v1 is filesystem permissions on the shared DB (the DB
is the trust boundary); encryption-at-rest is a noted future seam.
"""
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from typing import Optional
from app.model_downloader.constants import (
AUTH_SCHEME_BEARER,
AUTH_SCHEME_HEADER,
AUTH_SCHEME_QUERY,
AUTH_SCHEMES,
)
from app.model_downloader.database import queries
from app.model_downloader.database.models import HostCredential
def normalize_host(host: str) -> str:
"""Lowercase, strip port, IDNA-encode (PRD section 9.4.3)."""
if not host:
return ""
host = host.strip().lower()
if host.startswith("[") and "]" in host: # bracketed IPv6 literal
host = host[1 : host.index("]")]
elif host.count(":") == 1: # host:port (not IPv6)
host = host.split(":", 1)[0]
try:
host = host.encode("idna").decode("ascii")
except (UnicodeError, ValueError):
pass
return host
@dataclass(frozen=True)
class CredentialView:
"""Masked, API-safe view of a credential — never includes the secret."""
id: str
host: str
auth_scheme: str
header_name: Optional[str]
query_param: Optional[str]
label: Optional[str]
match_subdomains: bool
enabled: bool
secret_last4: Optional[str]
created_at: int
updated_at: int
def _to_view(row: HostCredential) -> CredentialView:
return CredentialView(
id=row.id,
host=row.host,
auth_scheme=row.auth_scheme,
header_name=row.header_name,
query_param=row.query_param,
label=row.label,
match_subdomains=row.match_subdomains,
enabled=row.enabled,
secret_last4=row.secret_last4,
created_at=row.created_at,
updated_at=row.updated_at,
)
class CredentialValidationError(ValueError):
"""A credential upsert had inconsistent fields."""
class CredentialStore:
"""Async facade over the ``host_credentials`` table.
DB access is synchronous (SQLite) and offloaded via ``asyncio.to_thread``.
"""
async def upsert(
self,
host: str,
secret: str,
*,
auth_scheme: str = AUTH_SCHEME_BEARER,
header_name: Optional[str] = None,
query_param: Optional[str] = None,
label: Optional[str] = None,
match_subdomains: bool = False,
enabled: bool = True,
) -> CredentialView:
host = normalize_host(host)
if not host:
raise CredentialValidationError("host is required")
if not secret:
raise CredentialValidationError("secret is required")
if auth_scheme not in AUTH_SCHEMES:
raise CredentialValidationError(
f"auth_scheme must be one of {AUTH_SCHEMES}, got {auth_scheme!r}"
)
if auth_scheme == AUTH_SCHEME_HEADER and not header_name:
header_name = "Authorization"
if auth_scheme == AUTH_SCHEME_QUERY and not query_param:
raise CredentialValidationError(
"query_param is required when auth_scheme='query'"
)
values = {
"host": host,
"secret": secret,
"secret_last4": secret[-4:] if len(secret) >= 4 else secret,
"auth_scheme": auth_scheme,
"header_name": header_name,
"query_param": query_param,
"label": label,
"match_subdomains": match_subdomains,
"enabled": enabled,
}
row = await asyncio.to_thread(queries.upsert_credential, values)
return _to_view(row)
async def list(self) -> list[CredentialView]:
rows = await asyncio.to_thread(queries.list_credentials)
return [_to_view(r) for r in rows]
async def get(self, credential_id: str) -> Optional[CredentialView]:
row = await asyncio.to_thread(queries.get_credential, credential_id)
return _to_view(row) if row is not None else None
async def delete(self, credential_id: str) -> bool:
return await asyncio.to_thread(queries.delete_credential, credential_id)
CREDENTIAL_STORE = CredentialStore()

View File

@ -0,0 +1,162 @@
"""SQLAlchemy models for the download manager.
Three tables (PRD section 7):
- ``downloads`` one row per requested file (job + queue state).
- ``download_segments`` per-segment byte progress, for segmented resume.
- ``host_credentials`` one API key per host, reused across downloads.
The local file catalog / dedup index is NOT here — that is owned by the
assets system (``assets`` / ``asset_references``). On completion a finished
file is registered into the assets catalog; ``downloads`` is kept only as
job history.
"""
from __future__ import annotations
import time
import uuid
from sqlalchemy import (
BigInteger,
Boolean,
CheckConstraint,
ForeignKey,
Index,
Integer,
String,
Text,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database.models import Base
def _uuid() -> str:
return str(uuid.uuid4())
def _now() -> int:
return int(time.time())
class Download(Base):
__tablename__ = "downloads"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
# Original requested URL and the final URL after validated redirects.
url: Mapped[str] = mapped_column(Text, nullable=False)
final_url: Mapped[str | None] = mapped_column(Text, nullable=True)
# Canonical "<directory>/<filename>" identifier (resolved via folder_paths).
model_id: Mapped[str] = mapped_column(String(1024), nullable=False)
# Final on-disk location and the .part write target.
dest_path: Mapped[str] = mapped_column(Text, nullable=False)
temp_path: Mapped[str] = mapped_column(Text, nullable=False)
status: Mapped[str] = mapped_column(String(16), nullable=False)
priority: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
total_bytes: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
bytes_done: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
etag: Mapped[str | None] = mapped_column(String(512), nullable=True)
last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True)
accept_ranges: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
# Optional hub-provided checksum to verify against (NOT the dedup key).
expected_sha256: Mapped[str | None] = mapped_column(String(64), nullable=True)
# Explicit credential override; otherwise auto-resolved by host.
credential_id: Mapped[str | None] = mapped_column(String(36), nullable=True)
allow_any_extension: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
# How many retryable failures we have seen (for backoff capping).
attempts: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
error: Mapped[str | None] = mapped_column(Text, nullable=True)
created_at: Mapped[int] = mapped_column(BigInteger, nullable=False, default=_now)
updated_at: Mapped[int] = mapped_column(
BigInteger, nullable=False, default=_now, onupdate=_now
)
segments: Mapped[list[DownloadSegment]] = relationship(
"DownloadSegment",
back_populates="download",
cascade="all,delete-orphan",
passive_deletes=True,
order_by="DownloadSegment.idx",
)
__table_args__ = (
Index("ix_downloads_status", "status"),
Index("ix_downloads_priority", "priority"),
Index("ix_downloads_model_id", "model_id"),
CheckConstraint("bytes_done >= 0", name="ck_downloads_bytes_done_nonneg"),
CheckConstraint(
"total_bytes IS NULL OR total_bytes >= 0",
name="ck_downloads_total_bytes_nonneg",
),
)
def __repr__(self) -> str:
return f"<Download id={self.id} model_id={self.model_id!r} status={self.status}>"
class DownloadSegment(Base):
__tablename__ = "download_segments"
download_id: Mapped[str] = mapped_column(
String(36),
ForeignKey("downloads.id", ondelete="CASCADE"),
primary_key=True,
)
idx: Mapped[int] = mapped_column(Integer, primary_key=True)
start_offset: Mapped[int] = mapped_column(BigInteger, nullable=False)
end_offset: Mapped[int] = mapped_column(BigInteger, nullable=False)
bytes_done: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
download: Mapped[Download] = relationship("Download", back_populates="segments")
__table_args__ = (
CheckConstraint("bytes_done >= 0", name="ck_segments_bytes_done_nonneg"),
CheckConstraint("end_offset >= start_offset", name="ck_segments_range"),
)
def __repr__(self) -> str:
return (
f"<DownloadSegment {self.download_id}#{self.idx} "
f"{self.start_offset}-{self.end_offset} done={self.bytes_done}>"
)
class HostCredential(Base):
__tablename__ = "host_credentials"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
# Normalized lowercase hostname, e.g. "civitai.com".
host: Mapped[str] = mapped_column(String(255), nullable=False)
match_subdomains: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
label: Mapped[str | None] = mapped_column(String(255), nullable=True)
auth_scheme: Mapped[str] = mapped_column(
String(16), nullable=False, default="bearer"
)
header_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
query_param: Mapped[str | None] = mapped_column(String(255), nullable=True)
# The API key itself. Write-only over the API; never returned. See PRD 9.4.4.
secret: Mapped[str] = mapped_column(Text, nullable=False)
secret_last4: Mapped[str | None] = mapped_column(String(4), nullable=True)
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
created_at: Mapped[int] = mapped_column(BigInteger, nullable=False, default=_now)
updated_at: Mapped[int] = mapped_column(
BigInteger, nullable=False, default=_now, onupdate=_now
)
__table_args__ = (
Index("uq_host_credentials_host", "host", unique=True),
)
def __repr__(self) -> str:
return f"<HostCredential id={self.id} host={self.host!r} scheme={self.auth_scheme}>"

View File

@ -0,0 +1,235 @@
"""Synchronous DB access for the download manager.
All functions open their own short-lived session via ``create_session`` and
commit before returning, mirroring ``app/assets`` usage. They are blocking
(SQLite) and should be called from async code through ``asyncio.to_thread``.
"""
from __future__ import annotations
import time
from typing import Optional
from sqlalchemy import select
from app.database.db import create_session
from app.model_downloader.constants import DownloadStatus
from app.model_downloader.database.models import (
Download,
DownloadSegment,
HostCredential,
)
# ----- downloads -----
def insert_download(values: dict) -> None:
with create_session() as session:
session.add(Download(**values))
session.commit()
def get_download(download_id: str) -> Optional[Download]:
with create_session() as session:
row = session.get(Download, download_id)
if row is not None:
session.expunge_all()
return row
def list_downloads() -> list[Download]:
with create_session() as session:
rows = list(
session.execute(
select(Download).order_by(Download.created_at.desc())
).scalars()
)
session.expunge_all()
return rows
def list_segments(download_id: str) -> list[DownloadSegment]:
with create_session() as session:
rows = list(
session.execute(
select(DownloadSegment)
.where(DownloadSegment.download_id == download_id)
.order_by(DownloadSegment.idx)
).scalars()
)
session.expunge_all()
return rows
def update_download(download_id: str, **fields) -> None:
if not fields:
return
fields.setdefault("updated_at", int(time.time()))
with create_session() as session:
row = session.get(Download, download_id)
if row is None:
return
for key, value in fields.items():
setattr(row, key, value)
session.commit()
def delete_download(download_id: str) -> None:
with create_session() as session:
row = session.get(Download, download_id)
if row is not None:
session.delete(row)
session.commit()
def replace_segments(download_id: str, segments: list[dict]) -> None:
"""Atomically replace the segment plan for a download."""
with create_session() as session:
session.query(DownloadSegment).filter(
DownloadSegment.download_id == download_id
).delete()
for seg in segments:
session.add(DownloadSegment(download_id=download_id, **seg))
session.commit()
def update_segment_progress(download_id: str, idx: int, bytes_done: int) -> None:
with create_session() as session:
row = session.get(DownloadSegment, {"download_id": download_id, "idx": idx})
if row is None:
return
row.bytes_done = bytes_done
session.commit()
def list_queued_downloads() -> list[Download]:
"""Queued rows ordered for admission (priority desc, then FIFO)."""
with create_session() as session:
rows = list(
session.execute(
select(Download)
.where(Download.status == DownloadStatus.QUEUED)
.order_by(Download.priority.desc(), Download.created_at.asc())
).scalars()
)
session.expunge_all()
return rows
def reconcile_live_downloads() -> list[Download]:
"""Reset any ``active``/``verifying`` rows left by a previous run.
On a clean restart there can be no live worker, so anything still marked
live is stale. Move it back to ``queued`` (offsets are preserved on the
segment rows) so the scheduler re-admits it. Returns the rows that should
be re-queued by the scheduler (queued + paused).
"""
with create_session() as session:
stale = list(
session.execute(
select(Download).where(
Download.status.in_([DownloadStatus.ACTIVE, DownloadStatus.VERIFYING])
)
).scalars()
)
now = int(time.time())
for row in stale:
row.status = DownloadStatus.QUEUED
row.updated_at = now
session.commit()
resumable = list(
session.execute(
select(Download)
.where(Download.status == DownloadStatus.QUEUED)
.order_by(Download.priority.desc(), Download.created_at.asc())
).scalars()
)
session.expunge_all()
return resumable
# ----- host credentials -----
def get_credential(credential_id: str) -> Optional[HostCredential]:
with create_session() as session:
row = session.get(HostCredential, credential_id)
if row is not None:
session.expunge_all()
return row
def get_credential_by_host(host: str) -> Optional[HostCredential]:
with create_session() as session:
row = (
session.execute(
select(HostCredential).where(HostCredential.host == host).limit(1)
)
.scalars()
.first()
)
if row is not None:
session.expunge_all()
return row
def list_credentials() -> list[HostCredential]:
with create_session() as session:
rows = list(
session.execute(
select(HostCredential).order_by(HostCredential.host)
).scalars()
)
session.expunge_all()
return rows
def list_subdomain_credentials() -> list[HostCredential]:
"""Credentials that opted into subdomain matching, for suffix checks."""
with create_session() as session:
rows = list(
session.execute(
select(HostCredential).where(HostCredential.match_subdomains.is_(True))
).scalars()
)
session.expunge_all()
return rows
def upsert_credential(values: dict) -> HostCredential:
"""Insert or update a credential keyed by ``host``."""
host = values["host"]
now = int(time.time())
with create_session() as session:
row = (
session.execute(
select(HostCredential).where(HostCredential.host == host).limit(1)
)
.scalars()
.first()
)
if row is None:
row = HostCredential(**values)
row.created_at = now
row.updated_at = now
session.add(row)
else:
for key, value in values.items():
setattr(row, key, value)
row.updated_at = now
session.commit()
session.refresh(row)
session.expunge(row)
return row
def delete_credential(credential_id: str) -> bool:
with create_session() as session:
row = session.get(HostCredential, credential_id)
if row is None:
return False
session.delete(row)
session.commit()
return True

View File

@ -0,0 +1,443 @@
"""The per-download worker (PRD sections 5, 6, 8, 12).
One :class:`DownloadJob` drives a single file from probe to verified, cataloged
completion. It supports cooperative pause / resume / cancel, segmented
multi-connection transfer with positioned writes, and a verification gate
(size + structural + optional sha256) before the atomic rename into place.
Control is cooperative: external callers flip ``_control`` via
:meth:`request_pause` / :meth:`request_cancel`; segment loops observe it between
chunks and raise, which unwinds cleanly and persists resume offsets.
"""
from __future__ import annotations
import asyncio
import logging
import os
import time
from dataclasses import dataclass, field
from typing import Callable, Optional
from comfy.cli_args import args
from app.model_downloader.constants import DownloadStatus
from app.model_downloader.database import queries
from app.model_downloader.engine.planner import (
SegmentPlan,
effective_segment_count,
plan_segments,
)
from app.model_downloader.engine.writer import FileWriter
from app.model_downloader.net.http import open_validated
from app.model_downloader.net.probe import probe
from app.model_downloader.verify import checksum, dedup, structural
_RETRYABLE_STATUSES = {408, 429, 500, 502, 503, 504}
_PERSIST_INTERVAL = 2.0 # seconds between throttled progress persists
class Paused(Exception):
pass
class Cancelled(Exception):
pass
class RemoteChanged(Exception):
"""The remote file changed under a resume (got 200 where 206 expected)."""
class RetryableError(Exception):
pass
class FatalError(Exception):
"""Non-retryable: 4xx, checksum mismatch, structural failure, gated, etc."""
@dataclass
class SegmentRuntime:
idx: int
start: int
end: int # inclusive; may be -1 for unknown-size single stream
bytes_done: int = 0
@property
def length(self) -> int:
return self.end - self.start + 1
@dataclass
class RuntimeState:
download_id: str
model_id: str
url: str
priority: int
status: str
total_bytes: Optional[int] = None
bytes_done: int = 0
error: Optional[str] = None
segments: list[SegmentRuntime] = field(default_factory=list)
started_at: float = field(default_factory=time.monotonic)
_last_bytes: int = 0
_last_time: float = field(default_factory=time.monotonic)
speed_bps: float = 0.0
@property
def progress(self) -> Optional[float]:
if not self.total_bytes:
return None
return min(1.0, self.bytes_done / self.total_bytes)
@property
def eta_seconds(self) -> Optional[float]:
if not self.total_bytes or self.speed_bps <= 0:
return None
remaining = max(0, self.total_bytes - self.bytes_done)
return remaining / self.speed_bps
@dataclass
class JobSpec:
download_id: str
url: str
model_id: str
dest_path: str
temp_path: str
priority: int = 0
credential_id: Optional[str] = None
expected_sha256: Optional[str] = None
allow_any_extension: bool = False
etag: Optional[str] = None
attempts: int = 0
class DownloadJob:
def __init__(
self, spec: JobSpec, notify_cb: Optional[Callable[[str], None]] = None
) -> None:
self.spec = spec
self._notify = notify_cb
self._control = "run" # run | pause | cancel
self.state = RuntimeState(
download_id=spec.download_id,
model_id=spec.model_id,
url=spec.url,
priority=spec.priority,
status=DownloadStatus.QUEUED,
)
self._writer: Optional[FileWriter] = None
self._etag: Optional[str] = spec.etag
self._last_persist = 0.0
# ----- external control -----
def request_pause(self) -> None:
if self._control == "run":
self._control = "pause"
def request_cancel(self) -> None:
self._control = "cancel"
def _check_control(self) -> None:
if self._control == "cancel":
raise Cancelled()
if self._control == "pause":
raise Paused()
# ----- lifecycle -----
async def run(self) -> str:
"""Run to a terminal/paused state; returns the final status string."""
self._set_status(DownloadStatus.ACTIVE, error=None)
try:
pr = await self._probe_and_plan()
await self._transfer(pr)
await self._finalize()
self._set_status(DownloadStatus.COMPLETED)
except Paused:
await self._persist_progress(force=True)
self._set_status(DownloadStatus.PAUSED)
except Cancelled:
await self._close_writer()
self._remove_temp()
self._set_status(DownloadStatus.CANCELLED)
except RemoteChanged:
await self._reset_for_restart()
self._set_status(
DownloadStatus.QUEUED, error="remote file changed; restarting"
)
except RetryableError as e:
await self._persist_progress(force=True)
self._set_status(DownloadStatus.QUEUED, error=str(e))
except FatalError as e:
await self._close_writer()
self._remove_temp()
self._set_status(DownloadStatus.FAILED, error=str(e))
except Exception as e: # unexpected -> treat as retryable
logging.warning(
"[model_downloader] %s unexpected error: %s",
self.spec.model_id, e, exc_info=True,
)
await self._persist_progress(force=True)
self._set_status(DownloadStatus.QUEUED, error=f"{type(e).__name__}: {e}")
finally:
await self._close_writer()
return self.state.status
# ----- probe + plan -----
async def _probe_and_plan(self):
pr = await probe(self.spec.url, credential_id=self.spec.credential_id)
if not pr.ok:
if pr.gated:
raise FatalError(
f"{self.spec.url} requires authentication. Add an API key for "
f"this host at /api/download/credentials and retry."
)
if pr.status == 0 or pr.status in _RETRYABLE_STATUSES:
raise RetryableError(pr.error or "probe failed")
raise FatalError(pr.error or f"probe returned HTTP {pr.status}")
self._etag = pr.etag or self._etag
self.state.total_bytes = pr.total_bytes
queries.update_download(
self.spec.download_id,
final_url=pr.final_url,
total_bytes=pr.total_bytes,
accept_ranges=pr.accept_ranges,
etag=pr.etag,
last_modified=pr.last_modified,
)
seg_count = effective_segment_count(
pr.total_bytes, pr.accept_ranges, max(1, args.download_segments)
)
existing = queries.list_segments(self.spec.download_id)
if (
seg_count > 1
and existing
and pr.total_bytes is not None
and existing[-1].end_offset == pr.total_bytes - 1
):
# Resume an existing segmented plan.
self.state.segments = [
SegmentRuntime(s.idx, s.start_offset, s.end_offset, s.bytes_done)
for s in existing
]
elif seg_count > 1 and pr.total_bytes is not None:
plans = plan_segments(pr.total_bytes, seg_count)
queries.replace_segments(
self.spec.download_id,
[
{"idx": p.idx, "start_offset": p.start, "end_offset": p.end, "bytes_done": 0}
for p in plans
],
)
self.state.segments = [SegmentRuntime(p.idx, p.start, p.end, 0) for p in plans]
else:
# Single-stream: one logical segment; bytes_done tracked on the row.
row = queries.get_download(self.spec.download_id)
resume_from = row.bytes_done if row else 0
end = (pr.total_bytes - 1) if pr.total_bytes else -1
self.state.segments = [SegmentRuntime(0, 0, end, resume_from)]
self._recompute_bytes_done()
return pr
# ----- transfer -----
async def _transfer(self, pr) -> None:
self._writer = FileWriter(self.spec.temp_path)
await self._writer.open()
segmented = len(self.state.segments) > 1
if segmented and self.state.total_bytes:
await self._writer.preallocate(self.state.total_bytes)
await self._run_segmented()
else:
await self._run_single()
await self._writer.flush()
async def _run_segmented(self) -> None:
pending = [
asyncio.ensure_future(self._run_segment(seg))
for seg in self.state.segments
if seg.bytes_done < seg.length
]
if not pending:
return
done, not_done = await asyncio.wait(
pending, return_when=asyncio.FIRST_EXCEPTION
)
first_exc: Optional[BaseException] = None
for task in done:
exc = task.exception()
if exc is not None and first_exc is None:
first_exc = exc
if first_exc is not None:
for task in not_done:
task.cancel()
await asyncio.gather(*not_done, return_exceptions=True)
raise first_exc
async def _run_segment(self, seg: SegmentRuntime) -> None:
offset = seg.start + seg.bytes_done
headers = {
"Range": f"bytes={offset}-{seg.end}",
"Accept-Encoding": "identity",
}
if self._etag:
headers["If-Range"] = self._etag
async with open_validated(
"GET", self.spec.url, credential_id=self.spec.credential_id, headers=headers
) as (resp, _final):
if resp.status == 200:
# Server ignored the range -> remote changed / no resume support.
raise RemoteChanged()
if resp.status not in (206,):
self._raise_for_status(resp.status)
async for chunk in resp.content.iter_chunked(args.download_chunk_size):
self._check_control()
await self._writer.write_at(offset, chunk)
offset += len(chunk)
seg.bytes_done += len(chunk)
self._recompute_bytes_done()
await self._persist_progress()
async def _run_single(self) -> None:
seg = self.state.segments[0]
offset = seg.bytes_done # resume from here for single-stream
headers = {"Accept-Encoding": "identity"}
if offset > 0:
headers["Range"] = f"bytes={offset}-"
if self._etag:
headers["If-Range"] = self._etag
async with open_validated(
"GET", self.spec.url, credential_id=self.spec.credential_id, headers=headers
) as (resp, _final):
if offset > 0 and resp.status == 200:
# Resume not honoured -> start over from the beginning.
offset = 0
seg.bytes_done = 0
elif offset > 0 and resp.status != 206:
self._raise_for_status(resp.status)
elif offset == 0 and resp.status != 200:
self._raise_for_status(resp.status)
async for chunk in resp.content.iter_chunked(args.download_chunk_size):
self._check_control()
await self._writer.write_at(offset, chunk)
offset += len(chunk)
seg.bytes_done = offset
self.state.bytes_done = offset
await self._persist_progress()
def _raise_for_status(self, status: int) -> None:
if status in (401, 403):
raise FatalError(
f"{self.spec.url} returned {status}; add/update an API key for "
f"this host at /api/download/credentials."
)
if status in _RETRYABLE_STATUSES:
raise RetryableError(f"HTTP {status}")
raise FatalError(f"unexpected HTTP {status}")
# ----- finalize / verify (PRD section 8.4) -----
async def _finalize(self) -> None:
self._check_control()
await self._close_writer()
self._set_status(DownloadStatus.VERIFYING)
total = self.state.total_bytes
actual_size = os.path.getsize(self.spec.temp_path)
if total is not None and actual_size != total:
raise FatalError(
f"size mismatch: wrote {actual_size} of {total} bytes"
)
# Structural gate (cheap, no full read) then optional sha256 (full read).
await asyncio.to_thread(structural.validate, self.spec.temp_path)
if self.spec.expected_sha256:
await asyncio.to_thread(
checksum.verify_sha256, self.spec.temp_path, self.spec.expected_sha256
)
os.makedirs(os.path.dirname(self.spec.dest_path), exist_ok=True)
os.replace(self.spec.temp_path, self.spec.dest_path)
logging.info(
"[model_downloader] completed %s (%d bytes)",
self.spec.model_id, actual_size,
)
# Catalog into the assets system (blake3 dedup identity). Best-effort.
await dedup.register_completed(self.spec.dest_path)
# ----- helpers -----
def _recompute_bytes_done(self) -> None:
self.state.bytes_done = sum(s.bytes_done for s in self.state.segments)
now = time.monotonic()
dt = now - self.state._last_time
if dt >= 0.5:
self.state.speed_bps = (self.state.bytes_done - self.state._last_bytes) / dt
self.state._last_bytes = self.state.bytes_done
self.state._last_time = now
async def _persist_progress(self, force: bool = False) -> None:
now = time.monotonic()
if not force and now - self._last_persist < _PERSIST_INTERVAL:
if self._notify:
self._notify(self.spec.download_id)
return
self._last_persist = now
queries.update_download(self.spec.download_id, bytes_done=self.state.bytes_done)
for seg in self.state.segments:
if seg.end >= seg.start: # skip unknown-size sentinel
queries.update_segment_progress(
self.spec.download_id, seg.idx, seg.bytes_done
)
if self._notify:
self._notify(self.spec.download_id)
async def _reset_for_restart(self) -> None:
await self._close_writer()
self._remove_temp()
for seg in self.state.segments:
seg.bytes_done = 0
self.state.bytes_done = 0
queries.update_download(self.spec.download_id, bytes_done=0)
if queries.list_segments(self.spec.download_id):
queries.replace_segments(self.spec.download_id, [])
async def _close_writer(self) -> None:
if self._writer is not None:
try:
await self._writer.close()
except Exception:
logging.debug("[model_downloader] writer close error", exc_info=True)
self._writer = None
def _remove_temp(self) -> None:
try:
os.remove(self.spec.temp_path)
except FileNotFoundError:
pass
except OSError as e:
logging.warning(
"[model_downloader] could not remove %s: %s", self.spec.temp_path, e
)
def _set_status(self, status: str, error: Optional[str] = None) -> None:
self.state.status = status
if error is not None:
self.state.error = error
fields = {"status": status, "bytes_done": self.state.bytes_done}
if error is not None:
fields["error"] = error
if status == DownloadStatus.QUEUED:
fields["attempts"] = self.spec.attempts + 1
self.spec.attempts += 1
queries.update_download(self.spec.download_id, **fields)
if self._notify:
self._notify(self.spec.download_id)

View File

@ -0,0 +1,51 @@
"""Segment planning (PRD section 5.2).
Split a known byte range into S roughly-equal segments, each fetched by its
own coroutine with ``Range: bytes=start-end``. Falls back to a single segment
when the server doesn't support ranges or the size is unknown/too small for
segmentation to be worthwhile.
"""
from __future__ import annotations
from dataclasses import dataclass
# Below this size, the per-connection setup cost outweighs any parallelism.
_MIN_SEGMENT_BYTES = 1 * 1024 * 1024
@dataclass(frozen=True)
class SegmentPlan:
idx: int
start: int
end: int # inclusive
@property
def length(self) -> int:
return self.end - self.start + 1
def effective_segment_count(
total_bytes: int | None, accept_ranges: bool, configured: int
) -> int:
"""How many segments to actually use for this file."""
if not accept_ranges or total_bytes is None or total_bytes <= 0:
return 1
by_size = max(1, total_bytes // _MIN_SEGMENT_BYTES)
return max(1, min(configured, by_size))
def plan_segments(total_bytes: int, num_segments: int) -> list[SegmentPlan]:
"""Return ``num_segments`` contiguous, inclusive byte ranges covering [0, total)."""
if total_bytes <= 0 or num_segments <= 1:
return [SegmentPlan(idx=0, start=0, end=max(0, total_bytes - 1))]
base = total_bytes // num_segments
plans: list[SegmentPlan] = []
start = 0
for i in range(num_segments):
# Last segment soaks up the remainder.
length = base if i < num_segments - 1 else total_bytes - start
end = start + length - 1
plans.append(SegmentPlan(idx=i, start=start, end=end))
start = end + 1
return plans

View File

@ -0,0 +1,61 @@
"""Positioned, off-loop file writes (PRD section 4 + 5.2).
Network I/O stays on the event loop; every blocking disk op (preallocate,
positioned write, fsync) is run in a bounded thread pool via
``run_in_executor`` so downloads never stall inference or the web server.
A single file descriptor is opened for the whole download. Segments write to
their own offsets with ``os.pwrite`` — which is offset-addressed and atomic
per call, so concurrent segment writers need no extra locking. Per-chunk
fsync is avoided; we fsync once at completion.
"""
from __future__ import annotations
import asyncio
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
# One shared, bounded pool for all download disk I/O.
_EXECUTOR = ThreadPoolExecutor(max_workers=8, thread_name_prefix="dl-writer")
class FileWriter:
"""Owns the ``.part`` file descriptor for one download."""
def __init__(self, path: str) -> None:
self.path = path
self._fd: Optional[int] = None
def _open(self) -> None:
os.makedirs(os.path.dirname(self.path), exist_ok=True)
self._fd = os.open(self.path, os.O_RDWR | os.O_CREAT, 0o644)
async def open(self) -> None:
await asyncio.get_running_loop().run_in_executor(_EXECUTOR, self._open)
async def preallocate(self, size: int) -> None:
"""Grow the file to ``size`` so segments write to their offsets."""
if self._fd is None or size <= 0:
return
await asyncio.get_running_loop().run_in_executor(
_EXECUTOR, os.ftruncate, self._fd, size
)
async def write_at(self, offset: int, data: bytes) -> None:
assert self._fd is not None, "writer not opened"
await asyncio.get_running_loop().run_in_executor(
_EXECUTOR, os.pwrite, self._fd, data, offset
)
async def flush(self) -> None:
if self._fd is None:
return
await asyncio.get_running_loop().run_in_executor(_EXECUTOR, os.fsync, self._fd)
async def close(self) -> None:
if self._fd is None:
return
fd, self._fd = self._fd, None
await asyncio.get_running_loop().run_in_executor(_EXECUTOR, os.close, fd)

View File

@ -0,0 +1,294 @@
"""Public facade for the download manager (PRD section 10).
This is the only object the server imports. It validates requests, owns the
:class:`Scheduler`, and exposes a small async API plus read models for status.
"""
from __future__ import annotations
import asyncio
import logging
import uuid
from typing import Callable, Optional
from app.model_downloader.constants import DownloadStatus
from app.model_downloader.database import queries
from app.model_downloader.scheduler import SCHEDULER
from app.model_downloader.security import paths
from app.model_downloader.security.allowlist import is_url_allowed
from app.model_downloader.security.paths import InvalidModelId
# Non-terminal statuses: an existing row in one of these blocks a re-enqueue.
_LIVE_STATUSES = (
DownloadStatus.QUEUED,
DownloadStatus.ACTIVE,
DownloadStatus.PAUSED,
DownloadStatus.VERIFYING,
)
class DownloadError(Exception):
"""A user-facing error with a stable machine-readable code."""
def __init__(self, code: str, message: str, status: int = 400) -> None:
super().__init__(message)
self.code = code
self.message = message
self.http_status = status
class DownloadManager:
def __init__(self) -> None:
self._scheduler = SCHEDULER
self._notify_cb: Optional[Callable[[str], None]] = None
def set_notify(self, cb: Optional[Callable[[str], None]]) -> None:
self._notify_cb = cb
self._scheduler.set_notify(cb)
async def start(self) -> None:
await self._scheduler.start()
# ----- enqueue -----
async def enqueue(
self,
url: str,
model_id: str,
*,
priority: int = 0,
expected_sha256: Optional[str] = None,
allow_any_extension: bool = False,
credential_id: Optional[str] = None,
) -> str:
if not is_url_allowed(url, allow_any_extension):
raise DownloadError(
"URL_NOT_ALLOWED",
"URL is not on the download allowlist (host/scheme/extension).",
)
try:
paths.parse_model_id(model_id, allow_any_extension)
dest_path, temp_path = paths.resolve_destination(model_id, allow_any_extension)
except InvalidModelId as e:
raise DownloadError("INVALID_MODEL_ID", str(e))
if await asyncio.to_thread(
paths.resolve_existing, model_id, allow_any_extension
):
raise DownloadError(
"ALREADY_AVAILABLE",
f"Model already exists on disk: {model_id}",
status=409,
)
if await self._has_live_download(model_id):
raise DownloadError(
"ALREADY_DOWNLOADING",
f"A download for {model_id} is already in progress.",
status=409,
)
download_id = str(uuid.uuid4())
await asyncio.to_thread(
queries.insert_download,
{
"id": download_id,
"url": url,
"model_id": model_id,
"dest_path": dest_path,
"temp_path": temp_path,
"status": DownloadStatus.QUEUED,
"priority": priority,
"expected_sha256": expected_sha256,
"credential_id": credential_id,
"allow_any_extension": allow_any_extension,
},
)
logging.info("[model_downloader] enqueued %s -> %s", url, model_id)
await self._scheduler.pump()
return download_id
async def _has_live_download(self, model_id: str) -> bool:
rows = await asyncio.to_thread(queries.list_downloads)
return any(
r.model_id == model_id and r.status in _LIVE_STATUSES for r in rows
)
# ----- control -----
async def pause(self, download_id: str) -> None:
job = self._scheduler.get_job(download_id)
if job is not None:
job.request_pause()
return
row = await asyncio.to_thread(queries.get_download, download_id)
if row is None:
raise DownloadError("NOT_FOUND", "No such download.", status=404)
if row.status == DownloadStatus.QUEUED:
await asyncio.to_thread(
queries.update_download, download_id, status=DownloadStatus.PAUSED
)
async def resume(self, download_id: str) -> None:
row = await asyncio.to_thread(queries.get_download, download_id)
if row is None:
raise DownloadError("NOT_FOUND", "No such download.", status=404)
if row.status in (DownloadStatus.PAUSED, DownloadStatus.FAILED):
await asyncio.to_thread(
queries.update_download,
download_id,
status=DownloadStatus.QUEUED,
error=None,
)
await self._scheduler.pump()
async def cancel(self, download_id: str) -> None:
job = self._scheduler.get_job(download_id)
if job is not None:
job.request_cancel()
return
row = await asyncio.to_thread(queries.get_download, download_id)
if row is None:
raise DownloadError("NOT_FOUND", "No such download.", status=404)
if row.status in _LIVE_STATUSES:
import os
try:
os.remove(row.temp_path)
except OSError:
pass
await asyncio.to_thread(
queries.update_download, download_id, status=DownloadStatus.CANCELLED
)
async def set_priority(self, download_id: str, priority: int) -> None:
row = await asyncio.to_thread(queries.get_download, download_id)
if row is None:
raise DownloadError("NOT_FOUND", "No such download.", status=404)
await asyncio.to_thread(
queries.update_download, download_id, priority=priority
)
# Admission-order only (PRD section 13 default); a higher priority is
# picked up the next time a slot frees. Pump in case a slot is free now.
await self._scheduler.pump()
# ----- read models -----
def _view(self, row) -> dict:
"""Combine the persisted row with live in-memory progress, if running."""
job = self._scheduler.get_job(row.id)
bytes_done = row.bytes_done
total = row.total_bytes
speed = None
eta = None
segments = None
if job is not None:
st = job.state
bytes_done = st.bytes_done
total = st.total_bytes if st.total_bytes is not None else total
speed = st.speed_bps
eta = st.eta_seconds
segments = [
{"idx": s.idx, "bytes_done": s.bytes_done, "length": s.length}
for s in st.segments
if s.end >= s.start
]
progress = (bytes_done / total) if total else None
return {
"download_id": row.id,
"model_id": row.model_id,
"url": row.url,
"status": row.status,
"priority": row.priority,
"total_bytes": total,
"bytes_done": bytes_done,
"progress": progress,
"speed_bps": speed,
"eta_seconds": eta,
"segments": segments,
"error": row.error,
"created_at": row.created_at,
"updated_at": row.updated_at,
}
def _view_from_state(self, job) -> dict:
"""Build a view purely from the live in-memory job state (no DB)."""
st = job.state
return {
"download_id": st.download_id,
"model_id": st.model_id,
"url": st.url,
"status": st.status,
"priority": st.priority,
"total_bytes": st.total_bytes,
"bytes_done": st.bytes_done,
"progress": st.progress,
"speed_bps": st.speed_bps,
"eta_seconds": st.eta_seconds,
"segments": [
{"idx": s.idx, "bytes_done": s.bytes_done, "length": s.length}
for s in st.segments
if s.end >= s.start
],
"error": st.error,
}
def status_sync(self, download_id: str) -> Optional[dict]:
"""Synchronous status read for the websocket notify path.
Uses live in-memory state when the job is running (no DB round-trip on
the hot path); falls back to a quick DB read otherwise.
"""
job = self._scheduler.get_job(download_id)
if job is not None:
return self._view_from_state(job)
row = queries.get_download(download_id)
return self._view(row) if row is not None else None
async def status(self, download_id: str) -> Optional[dict]:
row = await asyncio.to_thread(queries.get_download, download_id)
return self._view(row) if row is not None else None
async def list(self) -> list[dict]:
rows = await asyncio.to_thread(queries.list_downloads)
return [self._view(r) for r in rows]
async def availability(self, models: dict[str, str]) -> dict[str, dict]:
"""Bulk per-id ``{state, progress, ...}`` for the frontend poll.
``state`` is ``available`` (on disk), ``downloading`` (live row), or
``missing``. Cheap: a path lookup plus an in-memory/DB status check.
"""
rows = await asyncio.to_thread(queries.list_downloads)
by_model: dict[str, object] = {}
for r in rows:
if r.status in _LIVE_STATUSES or r.model_id not in by_model:
by_model[r.model_id] = r
out: dict[str, dict] = {}
for model_id, url in models.items():
try:
exists = await asyncio.to_thread(paths.resolve_existing, model_id)
except InvalidModelId:
out[model_id] = {"state": "missing", "url_allowed": is_url_allowed(url)}
continue
if exists:
out[model_id] = {"state": "available", "url_allowed": is_url_allowed(url)}
continue
row = by_model.get(model_id)
if row is not None and row.status in _LIVE_STATUSES:
view = self._view(row)
out[model_id] = {
"state": "downloading",
"url_allowed": is_url_allowed(url),
"download_id": view["download_id"],
"progress": view["progress"],
"bytes_done": view["bytes_done"],
"total_bytes": view["total_bytes"],
"speed_bps": view["speed_bps"],
}
else:
out[model_id] = {"state": "missing", "url_allowed": is_url_allowed(url)}
return out
DOWNLOAD_MANAGER = DownloadManager()

View File

@ -0,0 +1,110 @@
"""Manual, validated redirect-following request opener.
Automatic redirects are disabled (PRD section 9.2): we follow hops ourselves
so that on *every* hop we (a) re-validate scheme + reject credentials-in-URL,
(b) recompute which stored credential — if any — applies to that hop's host,
and (c) let the connector's resolver screen the IP. This is the single place
that attaches credentials, so a token can never ride a redirect to a CDN host.
"""
from __future__ import annotations
import logging
from contextlib import asynccontextmanager
from typing import AsyncIterator, Optional
from urllib.parse import urljoin, urlsplit, urlunsplit
import aiohttp
from app.model_downloader.credentials.resolver import resolve_auth_for_hop
from app.model_downloader.net.session import get_session
from app.model_downloader.security.ssrf import (
MAX_REDIRECTS,
SSRFError,
check_redirect_hop,
)
_REDIRECT_CODES = {301, 302, 303, 307, 308}
DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=None, sock_connect=30, sock_read=120)
def redact_url(url: str) -> str:
"""Drop the query string so a query-scheme secret is never logged/stored."""
try:
parts = urlsplit(url)
except ValueError:
return "<unparseable-url>"
return urlunsplit(parts._replace(query=""))
async def _resolve_final_response(
method: str,
url: str,
credential_id: Optional[str],
base_headers: dict[str, str],
timeout: aiohttp.ClientTimeout,
) -> tuple[aiohttp.ClientResponse, str]:
"""Follow redirects manually until a non-redirect response.
Each intermediate redirect response is released before the next hop.
Returns the final ``(response, final_url)``; the caller owns releasing it.
"""
session = await get_session()
current = url
hops = 0
while True:
check_redirect_hop(current)
parts = urlsplit(current)
auth = await resolve_auth_for_hop(
parts.hostname or "", parts.scheme, explicit_credential_id=credential_id
)
req_headers = dict(base_headers)
req_url = current
if auth is not None:
req_headers.update(auth.headers)
req_url = auth.apply_to_url(current)
resp = await session.request(
method,
req_url,
allow_redirects=False,
headers=req_headers,
timeout=timeout,
)
if resp.status in _REDIRECT_CODES and resp.headers.get("Location"):
next_url = urljoin(str(resp.url), resp.headers["Location"])
await resp.release()
hops += 1
if hops > MAX_REDIRECTS:
raise SSRFError(
f"too many redirects (> {MAX_REDIRECTS}) for {redact_url(url)}"
)
current = next_url
continue
return resp, redact_url(str(resp.url))
@asynccontextmanager
async def open_validated(
method: str,
url: str,
*,
credential_id: Optional[str] = None,
headers: Optional[dict[str, str]] = None,
timeout: aiohttp.ClientTimeout = DEFAULT_TIMEOUT,
) -> AsyncIterator[tuple[aiohttp.ClientResponse, str]]:
"""Open ``method url`` following redirects manually and validated.
Yields ``(response, final_url)`` where ``final_url`` is redacted of any
query string. The response is released automatically on exit.
"""
resp, final_url = await _resolve_final_response(
method, url, credential_id, dict(headers or {}), timeout
)
try:
yield resp, final_url
finally:
try:
await resp.release()
except Exception: # pragma: no cover - best-effort cleanup
logging.debug("[model_downloader] response release error", exc_info=True)

View File

@ -0,0 +1,90 @@
"""Pre-download probe (PRD section 5.1).
Issues a tiny ranged GET (``Range: bytes=0-0``) — which doubles as a
range-support test — to discover ``Content-Length``, ``Accept-Ranges``,
``ETag``/``Last-Modified``, and the final post-redirect URL. For HuggingFace
LFS files the true size also appears in the non-standard ``X-Linked-Size``
header, which we read as a fallback.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Optional
import aiohttp
from app.model_downloader.net.http import open_validated
from app.model_downloader.net.session import parse_int_header
_PROBE_TIMEOUT = aiohttp.ClientTimeout(total=60, sock_connect=30, sock_read=30)
@dataclass
class ProbeResult:
ok: bool
status: int
final_url: Optional[str] = None
total_bytes: Optional[int] = None
accept_ranges: bool = False
etag: Optional[str] = None
last_modified: Optional[str] = None
gated: bool = False # 401/403 — needs (or has wrong) credentials
error: Optional[str] = None
def _total_from_content_range(value: Optional[str]) -> Optional[int]:
# "bytes 0-0/12345" -> 12345 ; "bytes 0-0/*" -> None
if not value or "/" not in value:
return None
total = value.rsplit("/", 1)[1].strip()
return parse_int_header(total)
async def probe(url: str, *, credential_id: Optional[str] = None) -> ProbeResult:
"""Probe ``url`` and return discovered metadata, failing soft."""
try:
async with open_validated(
"GET",
url,
credential_id=credential_id,
headers={"Range": "bytes=0-0", "Accept-Encoding": "identity"},
timeout=_PROBE_TIMEOUT,
) as (resp, final_url):
if resp.status in (401, 403):
return ProbeResult(
ok=False, status=resp.status, final_url=final_url, gated=True,
error=f"host returned {resp.status} (authentication required)",
)
if resp.status not in (200, 206):
return ProbeResult(
ok=False, status=resp.status, final_url=final_url,
error=f"probe returned HTTP {resp.status}",
)
headers = resp.headers
accept_ranges = False
total: Optional[int] = None
if resp.status == 206:
accept_ranges = True
total = _total_from_content_range(headers.get("Content-Range"))
else: # 200: server ignored the range
accept_ranges = headers.get("Accept-Ranges", "").lower() == "bytes"
total = parse_int_header(headers.get("Content-Length"))
if total is None:
total = parse_int_header(headers.get("X-Linked-Size"))
return ProbeResult(
ok=True,
status=resp.status,
final_url=final_url,
total_bytes=total,
accept_ranges=accept_ranges,
etag=headers.get("ETag"),
last_modified=headers.get("Last-Modified"),
)
except Exception as e: # network / SSRF / timeout
logging.debug("[model_downloader] probe failed for %s: %s", url, e)
return ProbeResult(ok=False, status=0, error=f"{type(e).__name__}: {e}")

View File

@ -0,0 +1,72 @@
"""Lazily-created shared :class:`aiohttp.ClientSession`.
A single session reuses TLS handshakes and TCP connections across the probe
and the many segment GETs to the same host (HuggingFace is the dominant
case), which is a large speedup on cold connections and exactly the
connection-reuse strategy that lets us match aria2c (PRD section 5.2).
The connector uses :class:`ValidatingResolver` so every connection — initial
or post-redirect — is screened for private/special-use IPs at connect time.
TLS is pinned to certifi's CA bundle because the OS trust store is not wired
up on some Python installs (python.org macOS, slim containers).
"""
from __future__ import annotations
import asyncio
import ssl
from typing import Optional
import aiohttp
try:
import certifi
_CA_FILE = certifi.where()
except Exception: # pragma: no cover - certifi is a transitive dep of aiohttp
_CA_FILE = None
from comfy.cli_args import args
from app.model_downloader.security.ssrf import ValidatingResolver
_session: Optional[aiohttp.ClientSession] = None
_lock = asyncio.Lock()
def ssl_context() -> ssl.SSLContext:
if _CA_FILE is not None:
return ssl.create_default_context(cafile=_CA_FILE)
return ssl.create_default_context()
async def get_session() -> aiohttp.ClientSession:
"""Return the shared session, creating it on first use."""
global _session
if _session is not None and not _session.closed:
return _session
async with _lock:
if _session is None or _session.closed:
connector = aiohttp.TCPConnector(
limit_per_host=max(1, getattr(args, "download_max_connections_per_host", 16)),
ssl=ssl_context(),
resolver=ValidatingResolver(),
)
_session = aiohttp.ClientSession(connector=connector)
return _session
async def close_session() -> None:
global _session
if _session is not None and not _session.closed:
await _session.close()
_session = None
def parse_int_header(value: Optional[str]) -> Optional[int]:
"""Parse a non-negative integer header value, or None if bad/absent."""
if not value:
return None
try:
n = int(value)
except (TypeError, ValueError):
return None
return n if n >= 0 else None

View File

@ -0,0 +1,160 @@
"""Priority scheduler + lifecycle (PRD sections 4, 6, 12).
Owns the set of running jobs and admits queued downloads up to a global
concurrency limit (K), highest priority first, FIFO within a priority. Runs
entirely on the existing ComfyUI asyncio loop; blocking work (disk, hashing,
DB) is offloaded by the job/writer layers.
On startup it reconciles DB vs. disk: ``active``/``verifying`` rows left by a
previous run are reset to ``queued`` and resumed from persisted offsets, and
orphaned ``.part`` files with no live download row are swept.
"""
from __future__ import annotations
import asyncio
import logging
import os
import random
import time
from typing import Callable, Optional
from comfy.cli_args import args
from app.model_downloader.constants import DownloadStatus
from app.model_downloader.database import queries
from app.model_downloader.engine.job import DownloadJob, JobSpec
from app.model_downloader.security import paths
# Backoff for retryable failures (PRD section 12).
_BACKOFF_BASE = 2.0
_BACKOFF_CAP = 300.0
_MAX_ATTEMPTS = 6
class Scheduler:
def __init__(self) -> None:
self._jobs: dict[str, DownloadJob] = {}
self._tasks: dict[str, asyncio.Task] = {}
self._backoff_until: dict[str, float] = {}
self._pump_lock = asyncio.Lock()
self._notify_cb: Optional[Callable[[str], None]] = None
self._started = False
@property
def max_active(self) -> int:
return max(1, getattr(args, "download_max_active", 3))
def set_notify(self, cb: Optional[Callable[[str], None]]) -> None:
self._notify_cb = cb
def get_job(self, download_id: str) -> Optional[DownloadJob]:
return self._jobs.get(download_id)
def is_active(self, download_id: str) -> bool:
return download_id in self._tasks
# ----- startup -----
async def start(self) -> None:
if self._started:
return
self._started = True
try:
await asyncio.to_thread(queries.reconcile_live_downloads)
await asyncio.to_thread(self._sweep_orphan_temp_files)
except Exception as e:
logging.warning("[model_downloader] startup reconcile failed: %s", e)
await self.pump()
@staticmethod
def _sweep_orphan_temp_files() -> None:
"""Remove ``.part`` files not referenced by a resumable download row.
Resumable partials (queued/paused rows) are preserved; only truly
orphaned temp files from crashed runs are deleted.
"""
live = {
row.temp_path
for row in queries.list_downloads()
if row.status in (DownloadStatus.QUEUED, DownloadStatus.PAUSED)
}
for path in paths.iter_all_tmp_paths():
if path in live:
continue
try:
os.remove(path)
logging.info("[model_downloader] removed orphan temp file: %s", path)
except OSError as e:
logging.warning("[model_downloader] could not remove %s: %s", path, e)
# ----- admission -----
async def pump(self) -> None:
async with self._pump_lock:
slots = self.max_active - len(self._tasks)
if slots <= 0:
return
now = time.monotonic()
candidates = await asyncio.to_thread(queries.list_queued_downloads)
for row in candidates:
if slots <= 0:
break
if row.id in self._tasks:
continue
if self._backoff_until.get(row.id, 0.0) > now:
continue
self._admit(row)
slots -= 1
def _admit(self, row) -> None:
spec = JobSpec(
download_id=row.id,
url=row.url,
model_id=row.model_id,
dest_path=row.dest_path,
temp_path=row.temp_path,
priority=row.priority,
credential_id=row.credential_id,
expected_sha256=row.expected_sha256,
allow_any_extension=row.allow_any_extension,
etag=row.etag,
attempts=row.attempts,
)
job = DownloadJob(spec, notify_cb=self._notify_cb)
self._jobs[row.id] = job
self._tasks[row.id] = asyncio.ensure_future(self._run_job(job))
async def _run_job(self, job: DownloadJob) -> None:
download_id = job.spec.download_id
status = DownloadStatus.FAILED
try:
status = await job.run()
except Exception as e: # run() is defensive, but never let a task die silently
logging.error("[model_downloader] job %s crashed: %s", download_id, e)
finally:
self._tasks.pop(download_id, None)
self._jobs.pop(download_id, None)
if status == DownloadStatus.QUEUED:
if job.spec.attempts >= _MAX_ATTEMPTS:
queries.update_download(
download_id,
status=DownloadStatus.FAILED,
error=f"giving up after {job.spec.attempts} attempts",
)
if self._notify_cb:
self._notify_cb(download_id)
else:
delay = min(
_BACKOFF_CAP, _BACKOFF_BASE ** job.spec.attempts
) + random.uniform(0, 1.0)
self._backoff_until[download_id] = time.monotonic() + delay
asyncio.ensure_future(self._delayed_pump(delay))
await self.pump()
async def _delayed_pump(self, delay: float) -> None:
await asyncio.sleep(delay)
await self.pump()
SCHEDULER = Scheduler()

View File

@ -0,0 +1,84 @@
"""URL allowlist for server-side model fetches (PRD section 9.1).
Default-deny. A URL is downloadable only when its parsed host + scheme are
allowlisted AND (unless explicitly relaxed) its final filename ends in a
known model extension.
The built-in host defaults mirror the frontend's ``isModelDownloadable``
allowlist so the two flows agree on what is eligible; ``--download-allowed-hosts``
extends it for self-hosted mirrors. Matching is done on ``urlparse().hostname``
(never a raw string prefix) so userinfo tricks like
``http://127.0.0.1@169.254.169.254/x.safetensors`` — whose real host is the
metadata IP — cannot slip past.
"""
from __future__ import annotations
from urllib.parse import urlparse
from comfy.cli_args import args
# host -> set of allowed schemes. Frontend parity (HuggingFace / Civitai /
# localhost). Extra hosts from --download-allowed-hosts are https-only.
_DEFAULT_ALLOWED_HOSTS: dict[str, set[str]] = {
"huggingface.co": {"https"},
"civitai.com": {"https"},
"localhost": {"http", "https"},
"127.0.0.1": {"http", "https"},
}
# Hosts for which loopback addresses are intentionally permitted (the localhost
# "download a local model" feature). Every other host's loopback resolution is
# rejected by the SSRF resolver.
LOOPBACK_HOSTS = frozenset({"localhost", "127.0.0.1", "::1"})
# Known model file extensions (frontend parity). Checked on the final filename.
ALLOWED_MODEL_EXTENSIONS = (
".safetensors",
".sft",
".ckpt",
".pth",
".pt",
".gguf",
".bin",
)
def _allowed_hosts() -> dict[str, set[str]]:
hosts = {h: set(s) for h, s in _DEFAULT_ALLOWED_HOSTS.items()}
for extra in getattr(args, "download_allowed_hosts", []) or []:
host = extra.strip().lower()
if host:
hosts.setdefault(host, set()).add("https")
return hosts
def is_host_allowed(host: str | None, scheme: str | None) -> bool:
"""True iff ``host`` is allowlisted for ``scheme``.
Used both for the initial URL and re-checked on every redirect hop
(PRD section 9.2), so a whitelisted URL cannot 30x into an off-list host.
"""
if not host or not scheme:
return False
allowed = _allowed_hosts().get(host.lower())
return allowed is not None and scheme.lower() in allowed
def has_allowed_extension(path: str, allow_any_extension: bool = False) -> bool:
if allow_any_extension:
return True
return path.lower().endswith(ALLOWED_MODEL_EXTENSIONS)
def is_url_allowed(url: str, allow_any_extension: bool = False) -> bool:
"""Check whether ``url`` is permitted as a server-side download source."""
if not isinstance(url, str) or not url:
return False
try:
parsed = urlparse(url)
except ValueError:
return False
if not is_host_allowed(parsed.hostname, parsed.scheme):
return False
return has_allowed_extension(parsed.path, allow_any_extension)

View File

@ -0,0 +1,110 @@
"""Path resolution + traversal safety for downloads (PRD section 9.3).
A ``model_id`` is a *relative destination path* of the form
``<directory>/<filename>`` (e.g. ``loras/my_lora.safetensors``). This module
turns one into an absolute on-disk path under one of ComfyUI's registered
model folders, rejecting unknown folders, path traversal, and symlink escape.
This is the only thing that composes destination paths, so the engine never
touches user-supplied path strings directly.
"""
from __future__ import annotations
import os
import re
from typing import Iterator, Optional
import folder_paths
from app.model_downloader.constants import TMP_SUFFIX
from app.model_downloader.security.allowlist import ALLOWED_MODEL_EXTENSIONS
# A model_id component is a single path segment of safe characters — no slashes,
# no "..", no leading dots that could escape the target directory.
_SEGMENT_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$")
class InvalidModelId(ValueError):
"""Raised when a model_id is malformed or names an unknown model folder."""
def parse_model_id(model_id: str, allow_any_extension: bool = False) -> tuple[str, str]:
"""Split ``<directory>/<filename>`` and validate both components.
Returns ``(directory, filename)``. Does not touch the filesystem.
"""
if not isinstance(model_id, str) or "/" not in model_id:
raise InvalidModelId(
f"model_id must be '<directory>/<filename>', got {model_id!r}"
)
directory, _, filename = model_id.partition("/")
if "/" in filename or not directory or not filename:
raise InvalidModelId(
f"model_id must have exactly one '/' separator, got {model_id!r}"
)
if not _SEGMENT_RE.match(directory):
raise InvalidModelId(f"invalid directory segment {directory!r}")
if not _SEGMENT_RE.match(filename):
raise InvalidModelId(f"invalid filename segment {filename!r}")
if not allow_any_extension and not filename.lower().endswith(
ALLOWED_MODEL_EXTENSIONS
):
raise InvalidModelId(
f"filename must end with a known model extension "
f"{ALLOWED_MODEL_EXTENSIONS}, got {filename!r}"
)
if directory not in folder_paths.folder_names_and_paths:
raise InvalidModelId(f"unknown model folder {directory!r}")
return directory, filename
def resolve_existing(model_id: str, allow_any_extension: bool = False) -> Optional[str]:
"""Return the absolute path of an installed model, or None if missing.
Honours ``extra_model_paths.yaml`` transparently via ``get_full_path``.
"""
directory, filename = parse_model_id(model_id, allow_any_extension)
return folder_paths.get_full_path(directory, filename)
def resolve_destination(
model_id: str, allow_any_extension: bool = False
) -> tuple[str, str]:
"""Return ``(final_path, temp_path)`` for a download.
Downloads land at the first registered path for the model's directory
(the "primary" location). ``temp_path`` is a sibling ``.part`` file that
is atomically renamed onto ``final_path`` on success. The result is
asserted to stay within the registered root (defence in depth on top of
the segment regex).
"""
directory, filename = parse_model_id(model_id, allow_any_extension)
roots = folder_paths.get_folder_paths(directory)
if not roots:
raise InvalidModelId(f"no on-disk path registered for folder {directory!r}")
root = os.path.realpath(roots[0])
final_path = os.path.realpath(os.path.join(root, filename))
if final_path != root and not final_path.startswith(root + os.sep):
raise InvalidModelId(f"resolved path escapes model root: {model_id!r}")
temp_path = f"{final_path}{TMP_SUFFIX}"
return final_path, temp_path
def iter_all_tmp_paths() -> Iterator[str]:
"""Yield this subsystem's temp files under every registered model folder.
Matches only the distinctive ``TMP_SUFFIX`` so the startup orphan sweep
can never delete temp files created by other tools.
"""
seen_roots: set[str] = set()
for directory in list(folder_paths.folder_names_and_paths.keys()):
for root in folder_paths.get_folder_paths(directory):
if root in seen_roots or not os.path.isdir(root):
continue
seen_roots.add(root)
try:
for entry in os.scandir(root):
if entry.is_file() and entry.name.endswith(TMP_SUFFIX):
yield entry.path
except OSError:
continue

View File

@ -0,0 +1,111 @@
"""SSRF / exfiltration defenses (PRD section 9.2).
Two cooperating layers:
1. :class:`ValidatingResolver` is installed on the shared connector. Every
connection — the initial probe and every segment GET, including ones made
after a redirect — resolves its host through this resolver, which rejects
any address that lands on a private / special-use IP range. Because the
resolve and the connect happen together inside the connector, there is no
check-then-connect window for DNS rebinding to exploit.
2. :func:`check_redirect_hop` re-validates every redirect hop. The host
allowlist gates only the *initial* user-supplied URL (anti-SSRF for
arbitrary input); legitimate downloads from allowlisted origins redirect
to presigned CDN hosts that are deliberately NOT on the allowlist (HF ->
``cdn-lfs*.huggingface.co``, Civitai -> signed Cloudflare/S3), so hops are
instead screened for scheme, embedded credentials, and — via the resolver
above — private IPs. Credentials are only ever attached when a hop's host
exactly matches a stored credential, so they are dropped on the CDN hop.
"""
from __future__ import annotations
import ipaddress
import socket
from urllib.parse import urlparse
from aiohttp.abc import AbstractResolver
from aiohttp.resolver import DefaultResolver
from app.model_downloader.security.allowlist import LOOPBACK_HOSTS
# Cap the redirect chain length and the schemes a hop may use.
MAX_REDIRECTS = 5
ALLOWED_SCHEMES = ("https", "http")
class SSRFError(Exception):
"""A hop failed an SSRF / allowlist check."""
def is_blocked_ip(ip_str: str) -> bool:
"""True for any address we refuse to connect to.
Covers loopback, link-local (incl. 169.254.169.254 cloud metadata),
RFC1918 private ranges, unique-local (ULA), unspecified (0.0.0.0/::),
multicast and other reserved ranges.
"""
try:
ip = ipaddress.ip_address(ip_str)
except ValueError:
return True # unparseable -> refuse
return (
ip.is_private
or ip.is_loopback
or ip.is_link_local
or ip.is_multicast
or ip.is_reserved
or ip.is_unspecified
)
class ValidatingResolver(AbstractResolver):
"""Delegating resolver that drops blocked IPs from every resolution.
If a hostname resolves only to blocked addresses, the connection fails
closed with an :class:`OSError`, which aiohttp surfaces as a connection
error to the caller.
"""
def __init__(self) -> None:
self._inner = DefaultResolver()
async def resolve(self, host, port=0, family=socket.AF_INET):
infos = await self._inner.resolve(host, port, family)
# localhost/127.0.0.1 are an explicit, opt-in allowlist feature.
if isinstance(host, str) and host.lower() in LOOPBACK_HOSTS:
return infos
safe = [info for info in infos if not is_blocked_ip(info["host"])]
if not safe:
raise OSError(
f"refusing to connect to {host!r}: resolves only to "
f"private/special-use addresses"
)
return safe
async def close(self) -> None:
await self._inner.close()
def check_redirect_hop(url: str) -> str:
"""Validate one redirect hop's URL.
Returns the URL unchanged on success; raises :class:`SSRFError` otherwise.
Enforces an allowed scheme and forbids credentials-in-URL. The host is NOT
re-checked against the allowlist (CDN redirect targets are off-list by
design); private-IP protection is provided by the connector's resolver,
and credential leakage is prevented by exact host matching at attach time.
The landing filename's extension is gated separately by the caller.
"""
try:
parsed = urlparse(url)
except ValueError as e:
raise SSRFError(f"unparseable redirect URL {url!r}: {e}") from e
if parsed.scheme.lower() not in ALLOWED_SCHEMES:
raise SSRFError(f"redirect to disallowed scheme {parsed.scheme!r}")
if parsed.username or parsed.password:
raise SSRFError("credentials-in-URL are not allowed")
if not parsed.hostname:
raise SSRFError(f"redirect URL has no host: {url!r}")
return url

View File

@ -0,0 +1,49 @@
"""Hub-checksum verification = SHA256 (PRD section 8.1).
Only used to confirm a download matches a *provided* ``expected_sha256``. It
is NOT the dedup key (that is blake3, owned by the assets system). The full
sequential read happens at most once, here, only when a checksum was supplied.
"""
from __future__ import annotations
import hashlib
from typing import Callable, Optional
_CHUNK = 8 * 1024 * 1024
InterruptCheck = Callable[[], bool]
class ChecksumError(Exception):
"""The computed SHA256 did not match the expected value."""
def sha256_file(path: str, interrupt_check: Optional[InterruptCheck] = None) -> Optional[str]:
"""Stream the file and return its lowercase hex SHA256.
Returns ``None`` if interrupted via ``interrupt_check``.
"""
h = hashlib.sha256()
with open(path, "rb") as f:
while True:
if interrupt_check is not None and interrupt_check():
return None
chunk = f.read(_CHUNK)
if not chunk:
break
h.update(chunk)
return h.hexdigest()
def verify_sha256(
path: str, expected: str, interrupt_check: Optional[InterruptCheck] = None
) -> None:
"""Raise :class:`ChecksumError` unless the file's SHA256 matches ``expected``."""
actual = sha256_file(path, interrupt_check)
if actual is None:
return # interrupted; caller will re-verify on resume
if actual.lower() != expected.lower():
raise ChecksumError(
f"sha256 mismatch: expected {expected.lower()}, got {actual.lower()}"
)

View File

@ -0,0 +1,53 @@
"""Dedup + catalog handoff — reuse the assets system (PRD section 8.5).
We do NOT build a parallel indexer. "Do I already have it?" is answered by
``resolve_existing`` (path) at enqueue time and, where a hash is known, by the
assets blake3 catalog. After a completed download we register the file
through the assets ingest path so it is cataloged and (eventually) hashed by
the existing enrichment worker.
"""
from __future__ import annotations
import asyncio
import logging
import os
from typing import Optional
def _register_sync(abs_path: str) -> Optional[str]:
"""Register a finished file into the assets catalog. Returns asset hash."""
try:
from app.assets.services.ingest import register_file_in_place
except Exception as e: # assets package import failure — non-fatal
logging.debug("[model_downloader] assets ingest unavailable: %s", e)
return None
try:
result = register_file_in_place(abs_path, name=os.path.basename(abs_path), tags=[])
return result.asset.hash if result and result.asset else None
except Exception as e:
# The file is already safely on disk; cataloging is best-effort.
logging.warning(
"[model_downloader] could not register %s into assets catalog: %s",
abs_path, e,
)
return None
async def register_completed(abs_path: str) -> Optional[str]:
"""Catalog a completed download via the assets system (off the event loop)."""
return await asyncio.to_thread(_register_sync, abs_path)
def _find_by_hash_sync(blake3_hex: str) -> Optional[str]:
try:
from app.assets.services.asset_management import get_asset_by_hash
except Exception:
return None
asset = get_asset_by_hash("blake3:" + blake3_hex)
return asset.hash if asset is not None else None
async def find_existing_by_hash(blake3_hex: str) -> Optional[str]:
"""Pure DB lookup — never triggers hashing on the hot path."""
return await asyncio.to_thread(_find_by_hash_sync, blake3_hex)

View File

@ -0,0 +1,65 @@
"""Cheap structural validation, no full read (PRD section 8.2).
For ``.safetensors``/``.sft`` we parse the header (first few KB): it carries
the tensor table and the byte length of the data region. We assert
``file_size == 8 + header_len + data_region_len``. This detects truncation
and most corruption for free, before any crypto hashing. Other extensions
have no cheap structural check and pass through.
"""
from __future__ import annotations
import json
import os
import struct
_SAFETENSORS_EXTS = (".safetensors", ".sft")
# A sane upper bound so a corrupt header length can't make us read gigabytes.
_MAX_HEADER_BYTES = 100 * 1024 * 1024
class StructuralError(Exception):
"""The file failed its structural integrity check."""
def validate(path: str) -> None:
"""Validate the file at ``path``. Raises :class:`StructuralError` on failure."""
lower = path.lower()
if lower.endswith(_SAFETENSORS_EXTS):
_validate_safetensors(path)
# No structural check for other formats; the size + (optional) checksum
# gates in the engine cover those.
def _validate_safetensors(path: str) -> None:
file_size = os.path.getsize(path)
if file_size < 8:
raise StructuralError(f"file too small to be safetensors ({file_size} bytes)")
with open(path, "rb") as f:
header_len = struct.unpack("<Q", f.read(8))[0]
if header_len <= 0 or header_len > _MAX_HEADER_BYTES:
raise StructuralError(f"implausible safetensors header length {header_len}")
if 8 + header_len > file_size:
raise StructuralError("safetensors header extends past end of file")
try:
header = json.loads(f.read(header_len).decode("utf-8"))
except (UnicodeDecodeError, json.JSONDecodeError) as e:
raise StructuralError(f"safetensors header is not valid JSON: {e}") from e
data_len = 0
for name, entry in header.items():
if name == "__metadata__":
continue
if not isinstance(entry, dict) or "data_offsets" not in entry:
raise StructuralError(f"tensor {name!r} missing data_offsets")
offsets = entry["data_offsets"]
if not (isinstance(offsets, list) and len(offsets) == 2):
raise StructuralError(f"tensor {name!r} has malformed data_offsets")
data_len = max(data_len, int(offsets[1]))
expected = 8 + header_len + data_len
if file_size != expected:
raise StructuralError(
f"size mismatch: file is {file_size} bytes, header implies {expected} "
f"(8 + {header_len} header + {data_len} data)"
)

View File

@ -243,6 +243,14 @@ parser.add_argument("--enable-assets", action="store_true", help="Enable the ass
parser.add_argument("--feature-flag", type=str, action='append', default=[], metavar="KEY[=VALUE]", help="Set a server feature flag. Use KEY=VALUE to set an explicit value, or bare KEY to set it to true. Can be specified multiple times. Boolean values (true/false) and numbers are auto-converted. Examples: --feature-flag show_signin_button=true or --feature-flag show_signin_button")
parser.add_argument("--list-feature-flags", action="store_true", help="Print the registry of known CLI-settable feature flags as JSON and exit.")
# ----- Model download manager (PRD: docs/prd-download-manager.md) -----
parser.add_argument("--download-segments", type=int, default=8, metavar="N", help="Number of parallel HTTP range segments per file for the model download manager (default: 8).")
parser.add_argument("--download-max-active", type=int, default=3, metavar="N", help="Maximum number of model downloads running concurrently (default: 3).")
parser.add_argument("--download-max-connections-per-host", type=int, default=16, metavar="N", help="Maximum simultaneous connections to a single host for the download manager (default: 16).")
parser.add_argument("--download-chunk-size", type=int, default=4 * 1024 * 1024, metavar="BYTES", help="Read chunk size in bytes for the download manager (default: 4 MiB).")
parser.add_argument("--download-allowed-hosts", type=str, nargs="*", default=[], metavar="HOST", help="Additional hostnames to add to the download manager allowlist (https only). The built-in defaults always include huggingface.co and civitai.com.")
parser.add_argument("--download-allow-any-extension", action="store_true", help="Allow the download manager to fetch files with any extension (default: only known model extensions like .safetensors).")
if comfy.options.args_parsing:
args = parser.parse_args()
else:

View File

@ -256,7 +256,7 @@ def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, w
if (want_requant and len(fns) == 0 or update_weight):
seed = comfy.utils.string_to_seed(s.seed_key)
if isinstance(orig, QuantizedTensor):
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
y = orig.requantize_from_float(x, scale="recalculate", stochastic_rounding=seed)
else:
y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
if want_requant and len(fns) == 0:
@ -1306,8 +1306,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None:
# dtype is now implicit in the layout class
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
weight = self.weight.requantize_from_float(weight, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
else:
weight = weight.to(self.weight.dtype)
if return_weight:

View File

@ -104,6 +104,7 @@ _CORE_FEATURE_FLAGS: dict[str, Any] = {
"extension": {"manager": {"supports_v4": True}},
"node_replacements": True,
"assets": args.enable_assets,
"server_side_model_downloads": True,
}
# CLI-provided flags cannot overwrite core flags

View File

@ -1,85 +1,68 @@
import os
import sys
import re
import ctypes
import logging
import ctypes.util
import importlib.util
from typing import TypedDict
import numpy as np
import torch
import nodes
import comfy_angle
from comfy_api.latest import ComfyExtension, io, ui
from typing_extensions import override
from utils.install_util import get_missing_requirements_message
logger = logging.getLogger(__name__)
def _check_opengl_availability():
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
logger.debug("_check_opengl_availability: starting")
missing = []
def _preload_angle():
egl_path = comfy_angle.get_egl_path()
gles_path = comfy_angle.get_glesv2_path()
# Check Python packages (using find_spec to avoid importing)
logger.debug("_check_opengl_availability: checking for glfw package")
if importlib.util.find_spec("glfw") is None:
missing.append("glfw")
if sys.platform == "win32":
angle_dir = comfy_angle.get_lib_dir()
os.add_dll_directory(angle_dir)
os.environ["PATH"] = angle_dir + os.pathsep + os.environ.get("PATH", "")
logger.debug("_check_opengl_availability: checking for OpenGL package")
if importlib.util.find_spec("OpenGL") is None:
missing.append("PyOpenGL")
if missing:
raise RuntimeError(
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
)
# On Linux without display, check if headless backends are available
logger.debug(f"_check_opengl_availability: platform={sys.platform}")
if sys.platform.startswith("linux"):
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
if not has_display:
# Check for EGL or OSMesa libraries
logger.debug("_check_opengl_availability: checking for EGL library")
has_egl = ctypes.util.find_library("EGL")
logger.debug("_check_opengl_availability: checking for OSMesa library")
has_osmesa = ctypes.util.find_library("OSMesa")
# Error disabled for CI as it fails this check
# if not has_egl and not has_osmesa:
# raise RuntimeError(
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
# "See error below for installation instructions."
# )
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
logger.debug("_check_opengl_availability: completed")
mode = 0 if sys.platform == "win32" else ctypes.RTLD_GLOBAL
ctypes.CDLL(str(egl_path), mode=mode)
ctypes.CDLL(str(gles_path), mode=mode)
# Run early check at import time
logger.debug("nodes_glsl: running _check_opengl_availability at import time")
_check_opengl_availability()
# OpenGL modules - initialized lazily when context is created
gl = None
glfw = None
EGL = None
# Pre-load ANGLE *before* any PyOpenGL import so that the EGL platform
# plugin picks up ANGLE's libEGL / libGLESv2 instead of system libs.
_preload_angle()
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
def _import_opengl():
"""Import OpenGL module. Called after context is created."""
global gl
if gl is None:
logger.debug("_import_opengl: importing OpenGL.GL")
import OpenGL.GL as _gl
gl = _gl
logger.debug("_import_opengl: import completed")
return gl
import OpenGL
OpenGL.USE_ACCELERATE = False
def _patch_find_library():
"""PyOpenGL's EGL platform looks for 'EGL' and 'GLESv2' by short name
via ctypes.util.find_library, but ANGLE ships as 'libEGL' and
'libGLESv2'. Patch find_library to return the full ANGLE paths so
PyOpenGL loads the same libraries we pre-loaded."""
if sys.platform == "linux":
return
import ctypes.util
_orig = ctypes.util.find_library
def _patched(name):
if name == 'EGL':
return comfy_angle.get_egl_path()
if name == 'GLESv2':
return comfy_angle.get_glesv2_path()
return _orig(name)
ctypes.util.find_library = _patched
_patch_find_library()
from OpenGL import EGL
from OpenGL import GLES3 as gl
class SizeModeInput(TypedDict):
size_mode: str
width: int
@ -102,7 +85,7 @@ MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
# (-1,-1)---(3,-1)
#
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
VERTEX_SHADER = """#version 330 core
VERTEX_SHADER = """#version 300 es
out vec2 v_texCoord;
void main() {
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
@ -126,14 +109,99 @@ void main() {
"""
def _convert_es_to_desktop(source: str) -> str:
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
# Remove any existing #version directive
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
# Remove precision qualifiers (not needed in desktop GLSL)
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
# Prepend desktop GLSL version
return "#version 330 core\n" + source
def _egl_attribs(*values):
"""Build an EGL_NONE-terminated EGLint attribute array."""
vals = list(values) + [EGL.EGL_NONE]
return (ctypes.c_int32 * len(vals))(*vals)
# EGL platform extension constants
EGL_PLATFORM_ANGLE_ANGLE = 0x3202
EGL_PLATFORM_ANGLE_TYPE_ANGLE = 0x3203
EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE = 0x3450
EGL_MESA_PLATFORM_SURFACELESS = 0x31DD
_eglGetPlatformDisplayEXT = None
def _get_egl_platform_display_ext(platform, native_display, attribs):
"""Call eglGetPlatformDisplayEXT via ctypes (extension, not in PyOpenGL)."""
global _eglGetPlatformDisplayEXT
if _eglGetPlatformDisplayEXT is None:
from OpenGL import platform as _plat
egl_lib = _plat.PLATFORM.EGL
_get_proc = egl_lib.eglGetProcAddress
_get_proc.restype = ctypes.c_void_p
_get_proc.argtypes = [ctypes.c_char_p]
ptr = _get_proc(b"eglGetPlatformDisplayEXT")
if not ptr:
return None
func_type = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_uint32, ctypes.c_void_p, ctypes.c_void_p)
_eglGetPlatformDisplayEXT = func_type(ptr)
raw = _eglGetPlatformDisplayEXT(platform, native_display, attribs)
if not raw:
return None
return ctypes.cast(raw, EGL.EGLDisplay)
def _get_egl_display():
"""Get an EGL display, trying the default first then ANGLE's Vulkan
platform for headless environments without a display server."""
failures = []
# Try the default display first (works when X11/Wayland is available)
display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY)
if display:
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
try:
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
return display, major.value, minor.value
except Exception as e:
failures.append(f"default: {e}")
logger.info("Default EGL display unavailable, trying headless fallbacks")
# Headless fallback strategies, tried in order:
headless_strategies = [
("surfaceless", EGL_MESA_PLATFORM_SURFACELESS, None, None),
("ANGLE Vulkan", EGL_PLATFORM_ANGLE_ANGLE, None,
_egl_attribs(EGL_PLATFORM_ANGLE_TYPE_ANGLE, EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE)),
]
for name, platform, native_display, attribs in headless_strategies:
display = _get_egl_platform_display_ext(platform, native_display, attribs)
if not display:
failures.append(f"{name}: eglGetPlatformDisplayEXT returned no display")
continue
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
try:
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
logger.info(f"Using EGL {name} platform (headless)")
return display, major.value, minor.value
failures.append(f"{name}: eglInitialize returned false")
except Exception as e:
failures.append(f"{name}: {e}")
continue
details = "\n".join(f" - {f}" for f in failures)
raise RuntimeError(
"Failed to initialize EGL display.\n"
"No display server and no headless EGL platform available.\n"
f"Tried:\n{details}\n"
"Ensure GPU drivers are installed or set DISPLAY for a virtual framebuffer."
)
def _gl_str(name):
"""Get an OpenGL string parameter."""
v = gl.glGetString(name)
if not v:
return "Unknown"
if isinstance(v, bytes):
return v.decode(errors="replace")
return ctypes.string_at(v).decode(errors="replace")
def _detect_output_count(source: str) -> int:
@ -159,163 +227,8 @@ def _detect_pass_count(source: str) -> int:
return 1
def _init_glfw():
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
logger.debug("_init_glfw: starting")
# On macOS, glfw.init() must be called from main thread or it hangs forever
if sys.platform == "darwin":
logger.debug("_init_glfw: skipping on macOS")
raise RuntimeError("GLFW backend not supported on macOS")
logger.debug("_init_glfw: importing glfw module")
import glfw as _glfw
logger.debug("_init_glfw: calling glfw.init()")
if not _glfw.init():
raise RuntimeError("glfw.init() failed")
try:
logger.debug("_init_glfw: setting window hints")
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
logger.debug("_init_glfw: calling create_window()")
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
if not window:
raise RuntimeError("glfw.create_window() failed")
logger.debug("_init_glfw: calling make_context_current()")
_glfw.make_context_current(window)
logger.debug("_init_glfw: completed successfully")
return window, _glfw
except Exception:
logger.debug("_init_glfw: failed, terminating glfw")
_glfw.terminate()
raise
def _init_egl():
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
logger.debug("_init_egl: starting")
from OpenGL import EGL as _EGL
from OpenGL.EGL import (
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
eglTerminate, eglDestroyContext, eglDestroySurface,
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
)
logger.debug("_init_egl: imports completed")
display = None
context = None
surface = None
try:
logger.debug("_init_egl: calling eglGetDisplay()")
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
if display == _EGL.EGL_NO_DISPLAY:
raise RuntimeError("eglGetDisplay() failed")
logger.debug("_init_egl: calling eglInitialize()")
major, minor = _EGL.EGLint(), _EGL.EGLint()
if not eglInitialize(display, major, minor):
display = None # Not initialized, don't terminate
raise RuntimeError("eglInitialize() failed")
logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
config_attribs = [
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
EGL_DEPTH_SIZE, 0, EGL_NONE
]
configs = (_EGL.EGLConfig * 1)()
num_configs = _EGL.EGLint()
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
raise RuntimeError("eglChooseConfig() failed")
config = configs[0]
logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
if not eglBindAPI(EGL_OPENGL_API):
raise RuntimeError("eglBindAPI() failed")
logger.debug("_init_egl: calling eglCreateContext()")
context_attribs = [
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
EGL_NONE
]
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
if context == EGL_NO_CONTEXT:
raise RuntimeError("eglCreateContext() failed")
logger.debug("_init_egl: calling eglCreatePbufferSurface()")
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
if surface == _EGL.EGL_NO_SURFACE:
raise RuntimeError("eglCreatePbufferSurface() failed")
logger.debug("_init_egl: calling eglMakeCurrent()")
if not eglMakeCurrent(display, surface, surface, context):
raise RuntimeError("eglMakeCurrent() failed")
logger.debug("_init_egl: completed successfully")
return display, context, surface, _EGL
except Exception:
logger.debug("_init_egl: failed, cleaning up")
# Clean up any resources on failure
if surface is not None:
eglDestroySurface(display, surface)
if context is not None:
eglDestroyContext(display, context)
if display is not None:
eglTerminate(display)
raise
def _init_osmesa():
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
import ctypes
logger.debug("_init_osmesa: starting")
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
logger.debug("_init_osmesa: importing OpenGL.osmesa")
from OpenGL import GL as _gl
from OpenGL.osmesa import (
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
OSMESA_RGBA,
)
logger.debug("_init_osmesa: imports completed")
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
if not ctx:
raise RuntimeError("OSMesaCreateContextExt() failed")
width, height = 64, 64
buffer = (ctypes.c_ubyte * (width * height * 4))()
logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
OSMesaDestroyContext(ctx)
raise RuntimeError("OSMesaMakeCurrent() failed")
logger.debug("_init_osmesa: completed successfully")
return ctx, buffer
class GLContext:
"""Manages OpenGL context and resources for shader execution.
Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software).
"""
"""Manages an OpenGL ES 3.0 context via EGL/ANGLE (singleton)."""
_instance = None
_initialized = False
@ -327,131 +240,105 @@ class GLContext:
def __init__(self):
if GLContext._initialized:
logger.debug("GLContext.__init__: already initialized, skipping")
return
logger.debug("GLContext.__init__: starting initialization")
global glfw, EGL
import time
start = time.perf_counter()
self._backend = None
self._window = None
self._egl_display = None
self._egl_context = None
self._egl_surface = None
self._osmesa_ctx = None
self._osmesa_buffer = None
self._display = None
self._surface = None
self._context = None
self._vao = None
# Try backends in order: GLFW → EGL → OSMesa
errors = []
logger.debug("GLContext.__init__: trying GLFW backend")
try:
self._window, glfw = _init_glfw()
self._backend = "glfw"
logger.debug("GLContext.__init__: GLFW backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
errors.append(("GLFW", e))
self._display, self._egl_major, self._egl_minor = _get_egl_display()
if self._backend is None:
logger.debug("GLContext.__init__: trying EGL backend")
try:
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
self._backend = "egl"
logger.debug("GLContext.__init__: EGL backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
errors.append(("EGL", e))
if not EGL.eglBindAPI(EGL.EGL_OPENGL_ES_API):
raise RuntimeError("eglBindAPI(EGL_OPENGL_ES_API) failed")
if self._backend is None:
logger.debug("GLContext.__init__: trying OSMesa backend")
try:
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
self._backend = "osmesa"
logger.debug("GLContext.__init__: OSMesa backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
errors.append(("OSMesa", e))
config = EGL.EGLConfig()
n_configs = ctypes.c_int32(0)
if not EGL.eglChooseConfig(
self._display,
_egl_attribs(
EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_ES3_BIT,
EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT,
EGL.EGL_RED_SIZE, 8, EGL.EGL_GREEN_SIZE, 8,
EGL.EGL_BLUE_SIZE, 8, EGL.EGL_ALPHA_SIZE, 8,
),
ctypes.byref(config), 1, ctypes.byref(n_configs),
) or n_configs.value == 0:
raise RuntimeError("eglChooseConfig() failed")
if self._backend is None:
if sys.platform == "win32":
platform_help = (
"Windows: Ensure GPU drivers are installed and display is available.\n"
" CPU-only/headless mode is not supported on Windows."
)
elif sys.platform == "darwin":
platform_help = (
"macOS: GLFW is not supported.\n"
" Install OSMesa via Homebrew: brew install mesa\n"
" Then: pip install PyOpenGL PyOpenGL-accelerate"
)
else:
platform_help = (
"Linux: Install one of these backends:\n"
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
" Headless (CPU): sudo apt install libosmesa6"
)
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
raise RuntimeError(
f"Failed to create OpenGL context.\n\n"
f"Backend errors:\n{error_details}\n\n"
f"{platform_help}"
self._surface = EGL.eglCreatePbufferSurface(
self._display, config,
_egl_attribs(EGL.EGL_WIDTH, 64, EGL.EGL_HEIGHT, 64),
)
if not self._surface:
raise RuntimeError("eglCreatePbufferSurface() failed")
# Now import OpenGL.GL (after context is current)
logger.debug("GLContext.__init__: importing OpenGL.GL")
_import_opengl()
self._context = EGL.eglCreateContext(
self._display, config, EGL.EGL_NO_CONTEXT,
_egl_attribs(EGL.EGL_CONTEXT_CLIENT_VERSION, 3),
)
if not self._context:
raise RuntimeError("eglCreateContext() failed")
# Create VAO (required for core profile, but OSMesa may use compat profile)
logger.debug("GLContext.__init__: creating VAO")
try:
vao = gl.glGenVertexArrays(1)
gl.glBindVertexArray(vao)
self._vao = vao # Only store after successful bind
logger.debug("GLContext.__init__: VAO created successfully")
except Exception as e:
logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
# OSMesa with older Mesa may not support VAOs
# Clean up if we created but couldn't bind
if vao:
try:
gl.glDeleteVertexArrays(1, [vao])
except Exception:
pass
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
raise RuntimeError("eglMakeCurrent() failed")
self._vao = gl.glGenVertexArrays(1)
gl.glBindVertexArray(self._vao)
except Exception:
self._cleanup()
raise
elapsed = (time.perf_counter() - start) * 1000
# Log device info
renderer = gl.glGetString(gl.GL_RENDERER)
vendor = gl.glGetString(gl.GL_VENDOR)
version = gl.glGetString(gl.GL_VERSION)
renderer = renderer.decode() if renderer else "Unknown"
vendor = vendor.decode() if vendor else "Unknown"
version = version.decode() if version else "Unknown"
renderer = _gl_str(gl.GL_RENDERER)
vendor = _gl_str(gl.GL_VENDOR)
version = _gl_str(gl.GL_VERSION)
GLContext._initialized = True
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
logger.info(f"GLSL context initialized in {elapsed:.1f}ms - EGL {self._egl_major}.{self._egl_minor}, {renderer} ({vendor}), GL {version}")
def make_current(self):
if self._backend == "glfw":
glfw.make_context_current(self._window)
elif self._backend == "egl":
from OpenGL.EGL import eglMakeCurrent
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
elif self._backend == "osmesa":
from OpenGL.osmesa import OSMesaMakeCurrent
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
err = EGL.eglGetError()
raise RuntimeError(f"eglMakeCurrent() failed (EGL error: 0x{err:04X})")
if self._vao is not None:
gl.glBindVertexArray(self._vao)
def _cleanup(self):
if not self._display:
return
try:
if self._vao is not None:
gl.glDeleteVertexArrays(1, [self._vao])
self._vao = None
except Exception:
pass
try:
EGL.eglMakeCurrent(self._display, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
except Exception:
pass
try:
if self._context:
EGL.eglDestroyContext(self._display, self._context)
except Exception:
pass
try:
if self._surface:
EGL.eglDestroySurface(self._display, self._surface)
except Exception:
pass
try:
EGL.eglTerminate(self._display)
except Exception:
pass
self._display = None
def _compile_shader(source: str, shader_type: int) -> int:
"""Compile a shader and return its ID."""
@ -459,8 +346,10 @@ def _compile_shader(source: str, shader_type: int) -> int:
gl.glShaderSource(shader, source)
gl.glCompileShader(shader)
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
error = gl.glGetShaderInfoLog(shader).decode()
if not gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS):
error = gl.glGetShaderInfoLog(shader)
if isinstance(error, bytes):
error = error.decode(errors="replace")
gl.glDeleteShader(shader)
raise RuntimeError(f"Shader compilation failed:\n{error}")
@ -484,8 +373,10 @@ def _create_program(vertex_source: str, fragment_source: str) -> int:
gl.glDeleteShader(vertex_shader)
gl.glDeleteShader(fragment_shader)
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
error = gl.glGetProgramInfoLog(program).decode()
if not gl.glGetProgramiv(program, gl.GL_LINK_STATUS):
error = gl.glGetProgramInfoLog(program)
if isinstance(error, bytes):
error = error.decode(errors="replace")
gl.glDeleteProgram(program)
raise RuntimeError(f"Program linking failed:\n{error}")
@ -530,9 +421,6 @@ def _render_shader_batch(
ctx = GLContext()
ctx.make_current()
# Convert from GLSL ES to desktop GLSL 330
fragment_source = _convert_es_to_desktop(fragment_code)
# Detect how many outputs the shader actually uses
num_outputs = _detect_output_count(fragment_code)
@ -558,9 +446,9 @@ def _render_shader_batch(
try:
# Compile shaders (once for all batches)
try:
program = _create_program(VERTEX_SHADER, fragment_source)
program = _create_program(VERTEX_SHADER, fragment_code)
except RuntimeError:
logger.error(f"Fragment shader:\n{fragment_source}")
logger.error(f"Fragment shader:\n{fragment_code}")
raise
gl.glUseProgram(program)
@ -723,13 +611,13 @@ def _render_shader_batch(
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
# Read back outputs for this batch
# (glGetTexImage is synchronous, implicitly waits for rendering)
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
batch_outputs = []
for tex in output_textures:
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
batch_outputs.append(img[::-1, :, :].copy())
for i in range(num_outputs):
gl.glReadBuffer(gl.GL_COLOR_ATTACHMENT0 + i)
buf = np.empty((height, width, 4), dtype=np.float32)
gl.glReadPixels(0, 0, width, height, gl.GL_RGBA, gl.GL_FLOAT, buf)
batch_outputs.append(buf[::-1, :, :].copy())
# Pad with black images for unused outputs
black_img = np.zeros((height, width, 4), dtype=np.float32)
@ -750,18 +638,18 @@ def _render_shader_batch(
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
gl.glUseProgram(0)
for tex in input_textures:
gl.glDeleteTextures(int(tex))
for tex in curve_textures:
gl.glDeleteTextures(int(tex))
for tex in output_textures:
gl.glDeleteTextures(int(tex))
for tex in ping_pong_textures:
gl.glDeleteTextures(int(tex))
if input_textures:
gl.glDeleteTextures(len(input_textures), input_textures)
if curve_textures:
gl.glDeleteTextures(len(curve_textures), curve_textures)
if output_textures:
gl.glDeleteTextures(len(output_textures), output_textures)
if ping_pong_textures:
gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures)
if fbo is not None:
gl.glDeleteFramebuffers(1, [fbo])
for pp_fbo in ping_pong_fbos:
gl.glDeleteFramebuffers(1, [pp_fbo])
if ping_pong_fbos:
gl.glDeleteFramebuffers(len(ping_pong_fbos), ping_pong_fbos)
if program is not None:
gl.glDeleteProgram(program)

View File

@ -1113,32 +1113,6 @@ def full_type_name(klass):
return klass.__qualname__
return module + '.' + klass.__qualname__
def node_not_executable_reason(class_def, class_type):
"""Return a human-readable reason the node cannot be executed, or None if it's fine.
Catches a node whose declared entry point doesn't resolve to a real method
(e.g. a V1 ``FUNCTION = "invert"`` where the method is misspelled, or a V3 node
missing its ``execute`` override). Running this during validation surfaces the
problem before execution starts, instead of after upstream nodes have run.
Only the class is inspected; the node is never instantiated here, so a node's
``__init__`` side effects cannot run (or fail) during validation.
"""
try:
if issubclass(class_def, _ComfyNodeInternal):
# V3: validates that execute()/define_schema() overrides exist.
class_def.VALIDATE_CLASS()
return None
# V1: FUNCTION names the method to call; it must exist on the class.
function_name = getattr(class_def, "FUNCTION", None)
if function_name is None:
return f"'{class_type}' does not define FUNCTION"
if not callable(getattr(class_def, function_name, None)):
return f"'{class_type}' has no method '{function_name}' (declared in FUNCTION)"
return None
except Exception as ex:
return str(ex)
async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]):
outputs = set()
for x in prompt:
@ -1174,35 +1148,6 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[
}
return (False, error, [], {})
# Make sure the node is actually executable (its FUNCTION/execute entry
# point resolves to a real method) before we touch any schema-derived
# attributes below or start execution. Catches code typos up front and
# attributes the error to the offending node.
not_executable = node_not_executable_reason(class_, class_type)
if not_executable is not None:
node_title = prompt[x].get('_meta', {}).get('title', class_type)
error = {
"type": "invalid_node_definition",
"message": "Node is not executable",
"details": f"{not_executable} (Node ID '#{x}')",
"extra_info": {
"node_id": x,
"class_type": class_type,
"node_title": node_title,
}
}
node_errors = {x: {
"errors": [{
"type": "invalid_node_definition",
"message": "Node is not executable",
"details": not_executable,
"extra_info": {},
}],
"dependent_outputs": [],
"class_type": class_type,
}}
return (False, error, [], node_errors)
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
if partial_execution_list is None or x in partial_execution_list:
outputs.add(x)

View File

@ -230,6 +230,93 @@ components:
- base_version
- workflow_json
type: object
DownloadEnqueueRequest:
description: Request body for enqueuing a server-side model download.
properties:
allow_any_extension:
default: false
description: Permit a non-model file extension (default only allows known model extensions).
type: boolean
credential_id:
description: Explicit per-host credential to use; otherwise auto-resolved by host. Still subject to the per-hop host match.
nullable: true
type: string
expected_sha256:
description: Optional hub-provided SHA256 to verify the completed file against (fail-closed).
nullable: true
type: string
model_id:
description: Destination as "<directory>/<filename>", resolving to a registered model folder (e.g. "loras/my_lora.safetensors").
type: string
priority:
default: 0
description: Scheduling priority; higher is admitted first.
type: integer
url:
description: Source URL; must be on the allowlist (host + scheme + extension).
type: string
required:
- url
- model_id
type: object
DownloadStatus:
description: Current state and live progress of a single download.
properties:
bytes_done:
type: integer
created_at:
type: integer
download_id:
format: uuid
type: string
error:
nullable: true
type: string
eta_seconds:
nullable: true
type: number
model_id:
type: string
priority:
type: integer
progress:
description: Fraction in [0,1]; null until total size is known.
nullable: true
type: number
segments:
description: Per-segment progress (segmented downloads only).
items:
properties:
bytes_done:
type: integer
idx:
type: integer
length:
type: integer
type: object
nullable: true
type: array
speed_bps:
nullable: true
type: number
status:
enum:
- queued
- active
- paused
- verifying
- completed
- failed
- cancelled
type: string
total_bytes:
nullable: true
type: integer
updated_at:
type: integer
url:
type: string
type: object
ErrorResponse:
description: Standard error response with a machine-readable code and human-readable message.
properties:
@ -511,6 +598,78 @@ components:
required:
- history
type: object
HostCredentialUpsert:
description: Request body for upserting a per-host credential. The secret is write-only.
properties:
auth_scheme:
default: bearer
description: How the secret is attached to requests.
enum:
- bearer
- header
- query
type: string
enabled:
default: true
type: boolean
header_name:
description: Header name when auth_scheme=header (defaults to Authorization).
nullable: true
type: string
host:
description: Normalized hostname the key applies to (e.g. "civitai.com").
type: string
label:
description: User-friendly name for display.
nullable: true
type: string
match_subdomains:
default: false
description: Also match label-boundary subdomains of host (off by default; unsafe for hub CDNs).
type: boolean
query_param:
description: Query parameter name when auth_scheme=query.
nullable: true
type: string
secret:
description: The API key. Write-only — never returned by any endpoint.
type: string
required:
- host
- secret
type: object
HostCredentialView:
description: Masked, API-safe view of a stored credential. Never includes the secret.
properties:
auth_scheme:
type: string
created_at:
type: integer
enabled:
type: boolean
header_name:
nullable: true
type: string
host:
type: string
id:
format: uuid
type: string
label:
nullable: true
type: string
match_subdomains:
type: boolean
query_param:
nullable: true
type: string
secret_last4:
description: Last 4 characters of the secret, for masked display only.
nullable: true
type: string
updated_at:
type: integer
type: object
JobCancelResponse:
description: Response for POST /api/jobs/{job_id}/cancel. Returned on both fresh cancels and idempotent no-ops.
properties:
@ -2350,6 +2509,330 @@ paths:
summary: Get tag histogram for filtered assets
tags:
- file
/api/download:
get:
description: List all known downloads (queued, active, paused, and terminal) with live progress.
operationId: listDownloads
responses:
"200":
content:
application/json:
schema:
properties:
downloads:
items:
$ref: '#/components/schemas/DownloadStatus'
type: array
type: object
description: List of downloads
summary: List downloads
tags:
- download
/api/download/availability:
post:
description: |
Bulk per-id availability for a set of model_ids declared in a workflow.
Returns whether each model is available on disk, currently downloading
(with progress), or missing, plus whether its URL is on the allowlist.
operationId: getModelsAvailability
requestBody:
content:
application/json:
schema:
properties:
models:
additionalProperties:
type: string
description: Map of "<directory>/<filename>" model_id to its declared source URL.
type: object
type: object
responses:
"200":
content:
application/json:
schema:
properties:
models:
additionalProperties: true
type: object
type: object
description: Per-id availability map
summary: Bulk model availability + status
tags:
- download
/api/download/credentials:
get:
description: List stored per-host credentials. Secrets are never returned; only masked metadata (last 4 chars, scheme, label).
operationId: listDownloadCredentials
responses:
"200":
content:
application/json:
schema:
properties:
credentials:
items:
$ref: '#/components/schemas/HostCredentialView'
type: array
type: object
description: Masked credential list
summary: List host credentials (masked)
tags:
- download
post:
description: |
Upsert (by host) a per-host API key used to authenticate downloads.
The secret is write-only: it is stored once here and never returned by any endpoint.
operationId: upsertDownloadCredential
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/HostCredentialUpsert'
responses:
"201":
content:
application/json:
schema:
$ref: '#/components/schemas/HostCredentialView'
description: Credential stored (masked view returned)
"400":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Invalid credential
summary: Upsert a host credential
tags:
- download
/api/download/credentials/{id}:
delete:
description: Delete a stored host credential.
operationId: deleteDownloadCredential
parameters:
- in: path
name: id
required: true
schema:
type: string
responses:
"200":
content:
application/json:
schema:
properties:
deleted:
type: boolean
type: object
description: Deleted
"404":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: No such credential
summary: Delete a host credential
tags:
- download
get:
description: Get a single host credential (masked; never includes the secret).
operationId: getDownloadCredential
parameters:
- in: path
name: id
required: true
schema:
type: string
responses:
"200":
content:
application/json:
schema:
$ref: '#/components/schemas/HostCredentialView'
description: Masked credential
"404":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: No such credential
summary: Get a host credential (masked)
tags:
- download
/api/download/enqueue:
post:
description: |
Enqueue a server-side model download. The URL must be on the allowlist
(host + scheme + extension) and the model_id must be "<directory>/<filename>"
resolving to a registered model folder. Returns immediately; track progress
via GET /api/download/{id} or the "download_progress" websocket event.
operationId: enqueueDownload
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/DownloadEnqueueRequest'
responses:
"202":
content:
application/json:
schema:
properties:
accepted:
type: boolean
download_id:
format: uuid
type: string
type: object
description: Download accepted and queued
"400":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Invalid request (bad URL, model_id, or not allowlisted)
"409":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Already on disk or already downloading
summary: Enqueue a model download
tags:
- download
/api/download/{id}:
get:
description: Get the current status + progress of a single download.
operationId: getDownloadStatus
parameters:
- in: path
name: id
required: true
schema:
format: uuid
type: string
responses:
"200":
content:
application/json:
schema:
$ref: '#/components/schemas/DownloadStatus'
description: Download status
"404":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: No such download
summary: Get download status
tags:
- download
/api/download/{id}/cancel:
post:
description: Cancel a download. The partial file is removed.
operationId: cancelDownload
parameters:
- in: path
name: id
required: true
schema:
format: uuid
type: string
responses:
"200":
content:
application/json:
schema:
properties:
ok:
type: boolean
type: object
description: Cancelled
summary: Cancel a download
tags:
- download
/api/download/{id}/pause:
post:
description: Pause a download. The partial file and per-segment offsets are retained for resume.
operationId: pauseDownload
parameters:
- in: path
name: id
required: true
schema:
format: uuid
type: string
responses:
"200":
content:
application/json:
schema:
properties:
ok:
type: boolean
type: object
description: Paused
summary: Pause a download
tags:
- download
/api/download/{id}/priority:
post:
description: Set a download's scheduling priority. Higher priority is admitted first when a slot frees.
operationId: setDownloadPriority
parameters:
- in: path
name: id
required: true
schema:
format: uuid
type: string
requestBody:
content:
application/json:
schema:
properties:
priority:
type: integer
required:
- priority
type: object
responses:
"200":
content:
application/json:
schema:
properties:
ok:
type: boolean
type: object
description: Priority updated
summary: Set download priority
tags:
- download
/api/download/{id}/resume:
post:
description: Resume a paused (or failed) download from its persisted offsets.
operationId: resumeDownload
parameters:
- in: path
name: id
required: true
schema:
format: uuid
type: string
responses:
"200":
content:
application/json:
schema:
properties:
ok:
type: boolean
type: object
description: Resumed
summary: Resume a download
tags:
- download
/api/embeddings:
get:
description: Returns the list of text-encoder embeddings available on disk.

View File

@ -22,7 +22,7 @@ alembic
SQLAlchemy>=2.0.0
filelock
av>=16.0.0
comfy-kitchen==0.2.12
comfy-kitchen==0.2.13
comfy-aimdo==0.4.10
requests
simpleeval>=1.0.0
@ -33,5 +33,5 @@ kornia>=0.7.1
spandrel
pydantic~=2.0
pydantic-settings~=2.0
PyOpenGL
glfw
PyOpenGL>=3.1.8
comfy-angle

View File

@ -45,6 +45,8 @@ from app.frontend_management import FrontendManager, parse_version
from comfy_api.internal import _ComfyNodeInternal
from app.assets.seeder import asset_seeder
from app.assets.api.routes import register_assets_routes
from app.model_downloader.api.routes import register_routes as register_model_downloader_routes
from app.model_downloader.manager import DOWNLOAD_MANAGER
from app.assets.services.ingest import register_file_in_place
from app.assets.services.asset_management import resolve_hash_to_path
@ -256,6 +258,7 @@ class PromptServer():
else:
register_assets_routes(self.app)
asset_seeder.disable()
register_model_downloader_routes(self.app)
routes = web.RouteTableDef()
self.routes = routes
self.last_node_id = None
@ -1182,6 +1185,24 @@ class PromptServer():
async def setup(self):
timeout = aiohttp.ClientTimeout(total=None) # no timeout
self.client_session = aiohttp.ClientSession(timeout=timeout)
await self._setup_model_downloader()
async def _setup_model_downloader(self):
"""Start the download manager: push progress over the websocket and
resume any downloads interrupted by a previous run."""
def _notify(download_id: str) -> None:
try:
view = DOWNLOAD_MANAGER.status_sync(download_id)
if view is not None:
self.send_sync("download_progress", view)
except Exception:
logging.debug("download progress notify failed", exc_info=True)
DOWNLOAD_MANAGER.set_notify(_notify)
try:
await DOWNLOAD_MANAGER.start()
except Exception as e:
logging.warning("Failed to start model download manager: %s", e)
def add_routes(self):
self.user_manager.add_routes(self.routes)

View File

@ -1,137 +0,0 @@
"""Tests for pre-execution validation that a node is actually executable.
validate_prompt rejects a node whose declared entry point does not resolve to a
real method (a V1 FUNCTION typo, or a V3 node missing its execute override) before
any node runs, attributing the error to the offending node.
"""
import asyncio
import nodes
from comfy_api.latest import io
from execution import node_not_executable_reason, validate_prompt
class _GoodV1Node:
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "run"
OUTPUT_NODE = True
CATEGORY = "Test"
def run(self):
return (None,)
class _TypoV1Node:
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "invert" # method below is misspelled
OUTPUT_NODE = True
CATEGORY = "Test"
def invvert(self):
return (None,)
class _SideEffectInitV1Node:
"""Valid class-level method, but a constructor that must never run in validation."""
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "run"
OUTPUT_NODE = True
CATEGORY = "Test"
def __init__(self):
raise RuntimeError("__init__ must not run during validation")
def run(self):
return (None,)
def _v3_schema(node_id):
return io.Schema(
node_id=node_id,
display_name=node_id,
category="Test",
inputs=[],
outputs=[io.Image.Output()],
is_output_node=True,
)
class _GoodV3Node(io.ComfyNode):
@classmethod
def define_schema(cls):
return _v3_schema("GoodV3Node")
@classmethod
def execute(cls):
return io.NodeOutput(None)
class _TypoV3Node(io.ComfyNode):
@classmethod
def define_schema(cls):
return _v3_schema("TypoV3Node")
@classmethod
def exicute(cls): # typo: should be "execute"
return io.NodeOutput(None)
def _register(class_type, class_def):
nodes.NODE_CLASS_MAPPINGS[class_type] = class_def
def _validate(class_type):
prompt = {"1": {"class_type": class_type, "inputs": {}}}
return asyncio.run(validate_prompt("pid", prompt, None))
def test_good_node_passes():
_register("GoodV1Node", _GoodV1Node)
assert node_not_executable_reason(_GoodV1Node, "GoodV1Node") is None
valid, _, _, _ = _validate("GoodV1Node")
assert valid is True
def test_typo_node_rejected_with_node_error():
_register("TypoV1Node", _TypoV1Node)
valid, error, _, node_errors = _validate("TypoV1Node")
assert valid is False
assert error["type"] == "invalid_node_definition"
assert node_errors["1"]["class_type"] == "TypoV1Node"
assert node_errors["1"]["errors"][0]["type"] == "invalid_node_definition"
assert "invert" in node_errors["1"]["errors"][0]["details"]
def test_validation_does_not_instantiate_node():
"""A valid node is not constructed during validation, so __init__ never runs."""
_register("SideEffectInitV1Node", _SideEffectInitV1Node)
assert node_not_executable_reason(_SideEffectInitV1Node, "SideEffectInitV1Node") is None
valid, _, _, _ = _validate("SideEffectInitV1Node")
assert valid is True
def test_good_v3_node_passes():
_register("GoodV3Node", _GoodV3Node)
assert node_not_executable_reason(_GoodV3Node, "GoodV3Node") is None
valid, _, _, _ = _validate("GoodV3Node")
assert valid is True
def test_typo_v3_node_rejected_with_node_error():
_register("TypoV3Node", _TypoV3Node)
valid, error, _, node_errors = _validate("TypoV3Node")
assert valid is False
assert error["type"] == "invalid_node_definition"
assert node_errors["1"]["errors"][0]["type"] == "invalid_node_definition"

View File

@ -0,0 +1,63 @@
"""Shared fixtures for the model download manager tests.
These run in-process (no ComfyUI subprocess): a file-backed SQLite DB is
initialized once, a temp model folder is registered with ``folder_paths``, and
the shared aiohttp session is reset between tests so each async test gets a
session bound to its own event loop.
"""
from __future__ import annotations
import os
import tempfile
import pytest
@pytest.fixture(scope="session", autouse=True)
def _init_db():
import app.database.db as db
from comfy.cli_args import args
db_path = tempfile.mktemp(suffix="-dlmgr-test.sqlite3")
args.database_url = f"sqlite:///{db_path}"
db.init_db()
yield
try:
os.remove(db_path)
except OSError:
pass
@pytest.fixture(autouse=True)
def _reset_runtime():
"""Reset module singletons that hold event-loop-bound or cross-test state."""
import app.model_downloader.net.session as ns
from app.model_downloader.scheduler import SCHEDULER
ns._session = None
SCHEDULER._jobs.clear()
SCHEDULER._tasks.clear()
SCHEDULER._backoff_until.clear()
SCHEDULER._started = False
yield
ns._session = None
@pytest.fixture
def model_root(tmp_path):
"""Register a temp 'loras' model folder and return its absolute path."""
import folder_paths
root = tmp_path / "loras"
root.mkdir(parents=True, exist_ok=True)
saved = folder_paths.folder_names_and_paths.get("loras")
folder_paths.folder_names_and_paths["loras"] = (
[str(root)],
{".safetensors", ".sft", ".ckpt", ".pt", ".pth"},
)
yield str(root)
if saved is not None:
folder_paths.folder_names_and_paths["loras"] = saved
else:
folder_paths.folder_names_and_paths.pop("loras", None)

View File

@ -0,0 +1,110 @@
"""Unit tests for the credential store and the per-hop credential resolver.
Covers the critical rule: a secret is only ever attached when the current
hop's host matches a stored credential, and never over a non-https hop.
"""
from __future__ import annotations
import asyncio
import pytest
from app.model_downloader.credentials import resolver
from app.model_downloader.credentials.store import (
CREDENTIAL_STORE,
CredentialValidationError,
normalize_host,
)
from app.model_downloader.database.models import HostCredential
# ----- pure host normalization + matching -----
@pytest.mark.parametrize(
"raw,expected",
[
("Civitai.com", "civitai.com"),
("HuggingFace.co:443", "huggingface.co"),
(" Example.COM ", "example.com"),
],
)
def test_normalize_host(raw, expected):
assert normalize_host(raw) == expected
def _cred(**kw) -> HostCredential:
base = dict(
id="x", host="civitai.com", match_subdomains=False, auth_scheme="bearer",
secret="SECRET", enabled=True,
)
base.update(kw)
return HostCredential(**base)
def test_matches_exact_only_by_default():
c = _cred(host="civitai.com")
assert resolver._matches(c, "civitai.com") is True
assert resolver._matches(c, "api.civitai.com") is False
assert resolver._matches(c, "evil-civitai.com") is False
def test_matches_subdomain_label_boundary():
c = _cred(host="example.com", match_subdomains=True)
assert resolver._matches(c, "api.example.com") is True
assert resolver._matches(c, "example.com") is True
# not a label boundary -> no match
assert resolver._matches(c, "evil-example.com") is False
def test_build_auth_shapes():
assert resolver._build_auth(_cred(auth_scheme="bearer")).headers == {
"Authorization": "Bearer SECRET"
}
assert resolver._build_auth(
_cred(auth_scheme="header", header_name="X-Api-Key")
).headers == {"X-Api-Key": "SECRET"}
q = resolver._build_auth(_cred(auth_scheme="query", query_param="token"))
assert q.query == {"token": "SECRET"}
assert q.apply_to_url("https://civitai.com/x") == "https://civitai.com/x?token=SECRET"
# ----- DB-backed store + resolver -----
def test_store_upsert_is_write_only_and_masked():
async def _run():
view = await CREDENTIAL_STORE.upsert("civitai.com", "abcd1234", label="my key")
# The view never carries the secret, only the last 4.
assert not hasattr(view, "secret")
assert view.secret_last4 == "1234"
assert view.host == "civitai.com"
listed = await CREDENTIAL_STORE.list()
assert any(v.host == "civitai.com" for v in listed)
await CREDENTIAL_STORE.delete(view.id)
asyncio.run(_run())
def test_query_scheme_requires_param():
async def _run():
with pytest.raises(CredentialValidationError):
await CREDENTIAL_STORE.upsert("civitai.com", "k", auth_scheme="query")
asyncio.run(_run())
def test_resolver_never_crosses_host_boundary():
async def _run():
view = await CREDENTIAL_STORE.upsert("huggingface.co", "hf_secret_key")
try:
# matching host over https -> attached
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
assert auth is not None
assert auth.headers["Authorization"] == "Bearer hf_secret_key"
# CDN redirect host -> dropped
assert await resolver.resolve_auth_for_hop("cdn-lfs.huggingface.co", "https") is None
# non-https hop -> never attached
assert await resolver.resolve_auth_for_hop("huggingface.co", "http") is None
finally:
await CREDENTIAL_STORE.delete(view.id)
asyncio.run(_run())

View File

@ -0,0 +1,270 @@
"""Integration tests for the download engine against a local aiohttp server.
Covers single-stream and segmented transfers, deterministic resume from a
partial file, and cancel rollback. Async tests are driven via ``asyncio.run``
so no pytest-asyncio plugin is required.
"""
from __future__ import annotations
import asyncio
import os
import uuid
import pytest
from aiohttp import web
from comfy.cli_args import args
from app.model_downloader.constants import DownloadStatus
from app.model_downloader.database import queries
from app.model_downloader.engine.job import DownloadJob, JobSpec
from app.model_downloader.net.session import close_session
from app.model_downloader.security import paths
PAYLOAD_ETAG = '"v1"'
def _payload(n: int) -> bytes:
return bytes((i * 37 + 11) % 256 for i in range(n))
def _range_handler(payload: bytes):
async def handler(request: web.Request) -> web.Response:
rng = request.headers.get("Range")
if rng:
spec = rng.split("=", 1)[1]
s, _, e = spec.partition("-")
start = int(s)
end = int(e) if e else len(payload) - 1
chunk = payload[start : end + 1]
return web.Response(
status=206,
body=chunk,
headers={
"Content-Range": f"bytes {start}-{end}/{len(payload)}",
"Accept-Ranges": "bytes",
"ETag": PAYLOAD_ETAG,
},
)
return web.Response(
status=200, body=payload, headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG}
)
return handler
def _noranges_handler(payload: bytes):
async def handler(request: web.Request) -> web.Response:
# Always full body, never advertises Accept-Ranges -> single-stream.
return web.Response(status=200, body=payload)
return handler
def _slow_handler(payload: bytes, chunk: int = 16384, delay: float = 0.01):
async def handler(request: web.Request) -> web.StreamResponse:
resp = web.StreamResponse(
status=200, headers={"Content-Length": str(len(payload))}
)
await resp.prepare(request)
for i in range(0, len(payload), chunk):
await resp.write(payload[i : i + chunk])
await asyncio.sleep(delay)
await resp.write_eof()
return resp
return handler
async def _serve(handler):
app = web.Application()
app.router.add_route("*", "/{name:.*}", handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", 0)
await site.start()
port = site._server.sockets[0].getsockname()[1]
return runner, port
def _insert(model_id: str, url: str, status: str = DownloadStatus.QUEUED) -> tuple[str, str, str]:
final_path, temp_path = paths.resolve_destination(model_id)
download_id = str(uuid.uuid4())
queries.insert_download(
{
"id": download_id,
"url": url,
"model_id": model_id,
"dest_path": final_path,
"temp_path": temp_path,
"status": status,
}
)
return download_id, final_path, temp_path
# ----- single-stream -----
def test_single_stream_download(model_root):
payload = _payload(300_000)
async def _run():
await close_session()
runner, port = await _serve(_noranges_handler(payload))
try:
url = f"http://127.0.0.1:{port}/model.safetensors"
did, final_path, _temp = _insert("loras/single.safetensors", url)
job = DownloadJob(JobSpec(
download_id=did, url=url, model_id="loras/single.safetensors",
dest_path=final_path, temp_path=_temp,
))
status = await job.run()
assert status == DownloadStatus.COMPLETED, queries.get_download(did).error
assert os.path.exists(final_path)
assert open(final_path, "rb").read() == payload
finally:
await runner.cleanup()
await close_session()
asyncio.run(_run())
# ----- segmented -----
def test_segmented_download(model_root):
payload = _payload(4 * 1024 * 1024) # 4 MiB -> multiple segments
async def _run():
await close_session()
runner, port = await _serve(_range_handler(payload))
try:
url = f"http://127.0.0.1:{port}/model.safetensors"
did, final_path, temp = _insert("loras/seg.safetensors", url)
job = DownloadJob(JobSpec(
download_id=did, url=url, model_id="loras/seg.safetensors",
dest_path=final_path, temp_path=temp,
))
status = await job.run()
assert status == DownloadStatus.COMPLETED, queries.get_download(did).error
assert open(final_path, "rb").read() == payload
# More than one segment row was planned.
assert len(queries.list_segments(did)) > 1
finally:
await runner.cleanup()
await close_session()
asyncio.run(_run())
# ----- deterministic resume from a partial file -----
def test_resume_from_partial(model_root):
payload = _payload(512 * 1024) # < 1 MiB -> single segment, but ranges work
async def _run():
await close_session()
runner, port = await _serve(_range_handler(payload))
try:
url = f"http://127.0.0.1:{port}/model.safetensors"
did, final_path, temp = _insert("loras/resume.safetensors", url)
# Simulate a prior partial: first 200 KiB already written, offset persisted.
prefix = 200 * 1024
os.makedirs(os.path.dirname(temp), exist_ok=True)
with open(temp, "wb") as f:
f.write(payload[:prefix])
queries.update_download(did, bytes_done=prefix, etag=PAYLOAD_ETAG)
job = DownloadJob(JobSpec(
download_id=did, url=url, model_id="loras/resume.safetensors",
dest_path=final_path, temp_path=temp, etag=PAYLOAD_ETAG,
))
status = await job.run()
assert status == DownloadStatus.COMPLETED, queries.get_download(did).error
assert open(final_path, "rb").read() == payload
finally:
await runner.cleanup()
await close_session()
asyncio.run(_run())
# ----- cancel rollback -----
def test_cancel_rollback(model_root, monkeypatch):
monkeypatch.setattr(args, "download_chunk_size", 16384, raising=False)
payload = _payload(1024 * 1024)
async def _run():
await close_session()
runner, port = await _serve(_slow_handler(payload))
try:
url = f"http://127.0.0.1:{port}/model.safetensors"
did, final_path, temp = _insert("loras/cancel.safetensors", url)
job = DownloadJob(JobSpec(
download_id=did, url=url, model_id="loras/cancel.safetensors",
dest_path=final_path, temp_path=temp,
))
task = asyncio.ensure_future(job.run())
# Wait until some bytes have been written, then cancel.
for _ in range(200):
await asyncio.sleep(0.01)
if job.state.bytes_done > 0:
break
job.request_cancel()
status = await task
assert status == DownloadStatus.CANCELLED
assert not os.path.exists(temp)
assert not os.path.exists(final_path)
finally:
await runner.cleanup()
await close_session()
asyncio.run(_run())
# ----- manager + scheduler end-to-end -----
def test_manager_enqueue_to_completion(model_root):
payload = _payload(2 * 1024 * 1024)
async def _run():
await close_session()
from app.model_downloader.manager import DOWNLOAD_MANAGER
runner, port = await _serve(_range_handler(payload))
try:
url = f"http://127.0.0.1:{port}/model.safetensors"
did = await DOWNLOAD_MANAGER.enqueue(url, "loras/e2e.safetensors")
# Wait for completion.
final_path, _ = paths.resolve_destination("loras/e2e.safetensors")
for _ in range(500):
await asyncio.sleep(0.02)
row = queries.get_download(did)
if row.status in DownloadStatus.TERMINAL:
break
row = queries.get_download(did)
assert row.status == DownloadStatus.COMPLETED, row.error
assert open(final_path, "rb").read() == payload
finally:
await runner.cleanup()
await close_session()
asyncio.run(_run())
def test_manager_rejects_disallowed_url(model_root):
async def _run():
from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError
with pytest.raises(DownloadError) as ei:
await DOWNLOAD_MANAGER.enqueue(
"https://evil.example.com/x.safetensors", "loras/bad.safetensors"
)
assert ei.value.code == "URL_NOT_ALLOWED"
asyncio.run(_run())

View File

@ -0,0 +1,71 @@
"""Unit tests for the segment planner and structural safetensors validation."""
from __future__ import annotations
import json
import struct
import pytest
from app.model_downloader.engine.planner import (
effective_segment_count,
plan_segments,
)
from app.model_downloader.verify import structural
# ----- planner -----
def test_plan_segments_covers_full_range_contiguously():
total = 1000
plans = plan_segments(total, 4)
assert len(plans) == 4
assert plans[0].start == 0
assert plans[-1].end == total - 1
# contiguous, no gaps/overlaps
for a, b in zip(plans, plans[1:]):
assert b.start == a.end + 1
assert sum(p.length for p in plans) == total
def test_effective_segment_count_falls_back_to_single():
# No range support -> single
assert effective_segment_count(10_000_000, False, 8) == 1
# Unknown size -> single
assert effective_segment_count(None, True, 8) == 1
# Tiny file -> fewer segments than configured
assert effective_segment_count(1024, True, 8) == 1
# Large file with range support -> configured count
assert effective_segment_count(1_000_000_000, True, 8) == 8
# ----- structural -----
def _make_safetensors(tensor_data_len: int, *, corrupt_size: bool = False) -> bytes:
header = {"t": {"dtype": "F32", "shape": [tensor_data_len], "data_offsets": [0, tensor_data_len]}}
header_bytes = json.dumps(header).encode("utf-8")
body = b"\x00" * tensor_data_len
if corrupt_size:
body = body[:-1] # truncate one byte
return struct.pack("<Q", len(header_bytes)) + header_bytes + body
def test_structural_valid_safetensors(tmp_path):
p = tmp_path / "ok.safetensors"
p.write_bytes(_make_safetensors(256))
structural.validate(str(p)) # no raise
def test_structural_detects_truncation(tmp_path):
p = tmp_path / "bad.safetensors"
p.write_bytes(_make_safetensors(256, corrupt_size=True))
with pytest.raises(structural.StructuralError):
structural.validate(str(p))
def test_structural_skips_unknown_extension(tmp_path):
p = tmp_path / "weights.bin"
p.write_bytes(b"anything")
structural.validate(str(p)) # no structural check, no raise

View File

@ -0,0 +1,111 @@
"""Unit tests for the security layer: allowlist, SSRF checks, path safety."""
from __future__ import annotations
import pytest
from app.model_downloader.security import allowlist, paths
from app.model_downloader.security.ssrf import (
SSRFError,
check_redirect_hop,
is_blocked_ip,
)
# ----- allowlist -----
@pytest.mark.parametrize(
"url,allowed",
[
("https://huggingface.co/org/repo/resolve/main/model.safetensors", True),
("https://civitai.com/api/download/x/model.safetensors", True),
("http://localhost/model.safetensors", True),
# off-list host
("https://evil.example.com/model.safetensors", False),
# http to a non-loopback allowlisted host is not permitted (https only)
("http://huggingface.co/org/repo/resolve/main/model.safetensors", False),
# bad extension on an allowed host
("https://huggingface.co/org/repo/resolve/main/config.json", False),
# userinfo trick: real host is the metadata IP, not 127.0.0.1
("http://127.0.0.1@169.254.169.254/x.safetensors", False),
],
)
def test_is_url_allowed(url, allowed):
assert allowlist.is_url_allowed(url) is allowed
def test_allow_any_extension_relaxes_extension_only():
url = "https://huggingface.co/org/repo/resolve/main/weights.bin"
assert allowlist.is_url_allowed(url) is True # .bin is in the known set
odd = "https://huggingface.co/org/repo/resolve/main/weights.zip"
assert allowlist.is_url_allowed(odd) is False
assert allowlist.is_url_allowed(odd, allow_any_extension=True) is True
# ----- SSRF: blocked IPs -----
@pytest.mark.parametrize(
"ip,blocked",
[
("169.254.169.254", True), # cloud metadata / link-local
("127.0.0.1", True),
("10.0.0.5", True),
("192.168.1.1", True),
("172.16.0.1", True),
("::1", True),
("0.0.0.0", True),
("8.8.8.8", False),
("1.1.1.1", False),
("not-an-ip", True), # unparseable -> refuse
],
)
def test_is_blocked_ip(ip, blocked):
assert is_blocked_ip(ip) is blocked
# ----- SSRF: redirect hop validation -----
def test_check_redirect_hop_rejects_bad_scheme_and_userinfo():
with pytest.raises(SSRFError):
check_redirect_hop("ftp://huggingface.co/x.safetensors")
with pytest.raises(SSRFError):
check_redirect_hop("https://user:pass@cdn.example.com/x")
# A CDN host that is NOT on the allowlist is allowed as a redirect target
# (private-IP protection is the resolver's job; credential leak is prevented
# by exact host matching).
assert check_redirect_hop("https://cdn-lfs.huggingface.co/abc") is not None
# ----- path safety -----
def test_parse_model_id_valid(model_root):
directory, filename = paths.parse_model_id("loras/my_lora.safetensors")
assert directory == "loras"
assert filename == "my_lora.safetensors"
@pytest.mark.parametrize(
"model_id",
[
"loras/../etc/passwd.safetensors", # traversal
"loras/sub/dir.safetensors", # nested
"unknownfolder/x.safetensors", # unknown folder
"loras/model.txt", # bad extension
"noslash.safetensors", # missing directory
"loras/", # empty filename
],
)
def test_parse_model_id_rejects(model_root, model_id):
with pytest.raises(paths.InvalidModelId):
paths.parse_model_id(model_id)
def test_resolve_destination_stays_in_root(model_root):
final_path, temp_path = paths.resolve_destination("loras/x.safetensors")
assert final_path.startswith(model_root)
assert temp_path.startswith(model_root)
assert temp_path != final_path