Compare commits

..

14 Commits

Author SHA1 Message Date
f25af71a5e test(assets): lock loader_path matrix (asymmetry, null, persist/read)
Cover the behaviour that has no production change but is easy to regress:
the extra-path asymmetry (loadable but no storage namespace), null
loader_path persistence for orphan files, and the response reading the
stored column with a compute fallback for un-backfilled rows.
2026-07-02 08:32:52 +12:00
dacae94cf4 feat(assets): rename response field to loader_path and persist it
Rename the in-root loader path response field from `file_path` to
`loader_path` (matching compute_loader_path), and persist it on
asset_references so the API reads it directly instead of re-resolving
against every registered model-folder base per request.

- add loader_path column (migration 0006) populated at scan/ingest from
  the already-computed loader path
- response prefers the stored value, falling back to compute for rows
  written before the column existed
2026-07-02 08:31:37 +12:00
92417a7ae4 feat(assets): add in-root loader file_path, rename storage locator to logical_path
Split the Asset response path fields so model-loader consumers get a
category-relative path. The namespaced storage locator moves to
`logical_path`; the new `file_path` is the in-root loader path (model
category dropped), e.g. models/checkpoints/foo/bar.safetensors -> foo/bar.safetensors.
2026-07-02 08:30:16 +12:00
ca5adea2e3 test(assets): make duplicate path normalization portable
Amp-Thread-ID: https://ampcode.com/threads/T-019ecf39-2e6f-747d-ae80-addba6b8e4f5
Co-authored-by: Amp <amp@ampcode.com>
2026-07-02 08:24:30 +12:00
ccc9387298 fix(assets): merge duplicate scan specs
Co-authored-by: Amp <amp@ampcode.com>
Amp-Thread-ID: https://ampcode.com/threads/T-019ecf39-2e6f-747d-ae80-addba6b8e4f5
2026-07-02 08:21:14 +12:00
04198cd192 Merge branch 'master' into synap5e/assets-namespaced-tags 2026-06-29 14:59:13 -07:00
a58473fd9b chore: update embedded docs to v0.5.6 (#14668)
Co-authored-by: Alexis Rolland <alexisrolland@hotmail.com>
2026-06-29 17:08:06 +08:00
79c555ce6b Fix int8 mm being skipped on offloaded lora weights. (#14669) 2026-06-28 23:52:36 -04:00
f19735759e ci: add team-gated Cursor review (thin caller for github-workflows) (#14527) 2026-06-27 23:34:30 -07:00
a95e461916 int8 support on turing GPUs. (#14662) 2026-06-27 15:53:11 -07:00
f74388346e fix(assets): mark path-derived upload tags automatic
Amp-Thread-ID: https://ampcode.com/threads/T-019ecf39-2e6f-747d-ae80-addba6b8e4f5
Co-authored-by: Amp <amp@ampcode.com>
2026-06-27 15:35:10 +12:00
44b3239068 feat(assets): add namespaced model type tags
Amp-Thread-ID: https://ampcode.com/threads/T-019ecf39-2e6f-747d-ae80-addba6b8e4f5
Co-authored-by: Amp <amp@ampcode.com>
2026-06-27 15:14:20 +12:00
603d891eaf Update GLSL node to use ANGLE library (CORE-162) (#13195) 2026-06-27 08:40:31 +08:00
470ac36a0a Fix int8 loras causing lower quality requant with wrong settings. (#14650)
* Update comfy-kitchen

* Support requantizing with same settings as orig quant.
2026-06-26 16:41:29 -07:00
44 changed files with 2089 additions and 750 deletions

38
.github/workflows/ci-cursor-review.yml vendored Normal file
View File

@ -0,0 +1,38 @@
name: CI - Cursor Review
# Thin caller for the shared reusable cursor-review workflow in
# Comfy-Org/github-workflows. The review logic (panel matrix, judge
# consolidation, prompts, extract/post/notify scripts) lives there as the
# single source of truth, so this repo only carries the repo-specific diff
# excludes.
on:
pull_request:
types: [labeled, unlabeled]
concurrency:
group: cursor-review-pr-${{ github.event.pull_request.number }}-${{ github.event.label.name }}
cancel-in-progress: true
jobs:
cursor-review:
if: github.event.label.name == 'cursor-review'
permissions:
contents: read
pull-requests: write
# SHA-pinned per zizmor `unpinned-uses: hash-pin`. Bump this SHA to pick up
# upstream changes; keep `workflows_ref` matching so prompts/scripts load
# from the same commit as the workflow definition.
uses: Comfy-Org/github-workflows/.github/workflows/cursor-review.yml@047ca48febe3a6647608ed2e0c4331b491cb9d6a # github-workflows#9
with:
workflows_ref: 047ca48febe3a6647608ed2e0c4331b491cb9d6a
diff_excludes: >-
:!**/.claude/**
:!**/dist/**
:!**/vendor/**
:!**/*.generated.*
:!**/*.min.js
:!**/*.min.css
secrets:
CURSOR_API_KEY: ${{ secrets.CURSOR_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}

View File

@ -0,0 +1,107 @@
"""
Allow case-sensitive tag names.
Revision ID: 0005_allow_case_sensitive_tags
Revises: 0004_drop_tag_type
Create Date: 2026-06-16
"""
import sqlalchemy as sa
from alembic import op
revision = "0005_allow_case_sensitive_tags"
down_revision = "0004_drop_tag_type"
branch_labels = None
depends_on = None
def upgrade() -> None:
bind = op.get_bind()
if bind.dialect.name == "sqlite":
# SQLite cannot ALTER/DROP CHECK constraints. Recreate the small tag
# vocabulary table without the lowercase constraint while preserving
# existing tag names.
op.execute("PRAGMA foreign_keys=OFF")
try:
op.execute(
"CREATE TABLE tags_new ("
"name VARCHAR(512) NOT NULL, "
"CONSTRAINT pk_tags PRIMARY KEY (name)"
")"
)
op.execute("INSERT INTO tags_new(name) SELECT name FROM tags")
op.execute("DROP TABLE tags")
op.execute("ALTER TABLE tags_new RENAME TO tags")
finally:
op.execute("PRAGMA foreign_keys=ON")
return
op.drop_constraint("ck_tags_ck_tags_lowercase", "tags", type_="check")
def downgrade() -> None:
# Existing mixed-case tags cannot satisfy the old constraint. Lowercase them
# before restoring it, merging duplicate vocabulary/link rows that collide.
bind = op.get_bind()
tag_names = [row[0] for row in bind.execute(sa.text("SELECT name FROM tags"))]
existing_names = set(tag_names)
lowercase_names = sorted({name.lower() for name in tag_names})
missing_lowercase_rows = [
{"name": name} for name in lowercase_names if name not in existing_names
]
if missing_lowercase_rows:
bind.execute(sa.text("INSERT INTO tags(name) VALUES (:name)"), missing_lowercase_rows)
link_rows = bind.execute(
sa.text(
"SELECT asset_reference_id, tag_name, origin, added_at "
"FROM asset_reference_tags "
"ORDER BY asset_reference_id, tag_name"
)
).mappings()
deduped_links = {}
for row in link_rows:
key = (row["asset_reference_id"], row["tag_name"].lower())
deduped_links.setdefault(
key,
{
"asset_reference_id": row["asset_reference_id"],
"tag_name": row["tag_name"].lower(),
"origin": row["origin"],
"added_at": row["added_at"],
},
)
op.execute("DELETE FROM asset_reference_tags")
if deduped_links:
bind.execute(
sa.text(
"INSERT INTO asset_reference_tags "
"(asset_reference_id, tag_name, origin, added_at) "
"VALUES (:asset_reference_id, :tag_name, :origin, :added_at)"
),
list(deduped_links.values()),
)
op.execute("DELETE FROM tags WHERE name != lower(name)")
if bind.dialect.name == "sqlite":
op.execute("PRAGMA foreign_keys=OFF")
try:
op.execute(
"CREATE TABLE tags_new ("
"name VARCHAR(512) NOT NULL, "
"CONSTRAINT pk_tags PRIMARY KEY (name), "
"CONSTRAINT ck_tags_lowercase CHECK (name = lower(name))"
")"
)
op.execute("INSERT INTO tags_new(name) SELECT name FROM tags")
op.execute("DROP TABLE tags")
op.execute("ALTER TABLE tags_new RENAME TO tags")
finally:
op.execute("PRAGMA foreign_keys=ON")
return
op.create_check_constraint(
"ck_tags_ck_tags_lowercase", "tags", "name = lower(name)"
)

View File

@ -0,0 +1,30 @@
"""
Add loader_path column to asset_references.
Stores the in-root loader path (path relative to the storage root with the
top-level model category dropped) derived from file_path at scan/ingest time,
so the assets API can return it without re-resolving against every registered
model-folder base on every request.
Revision ID: 0006_add_loader_path
Revises: 0005_allow_case_sensitive_tags
Create Date: 2026-07-02
"""
from alembic import op
import sqlalchemy as sa
revision = "0006_add_loader_path"
down_revision = "0005_allow_case_sensitive_tags"
branch_labels = None
depends_on = None
def upgrade() -> None:
with op.batch_alter_table("asset_references") as batch_op:
batch_op.add_column(sa.Column("loader_path", sa.Text(), nullable=True))
def downgrade() -> None:
with op.batch_alter_table("asset_references") as batch_op:
batch_op.drop_column("loader_path")

View File

@ -10,7 +10,6 @@ from typing import Any
from aiohttp import web
from pydantic import ValidationError
import folder_paths
from app import user_manager
from app.assets.api import schemas_in, schemas_out
from app.assets.services import schemas
@ -40,6 +39,10 @@ from app.assets.services import (
upload_from_temp_path,
)
from app.assets.services.cursor import InvalidCursorError
from app.assets.services.path_utils import (
compute_asset_response_paths,
compute_loader_path,
)
from app.assets.services.tagging import list_tag_histogram
ROUTES = web.RouteTableDef()
@ -161,11 +164,25 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
preview_url = None
else:
preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata)
if result.ref.file_path:
paths = compute_asset_response_paths(result.ref.file_path)
logical_path, display_name = paths if paths else (None, None)
# In-root loader path (model category dropped): what model loaders consume.
# Persisted at scan/ingest; fall back to computing for rows written
# before the column existed.
loader_path = result.ref.loader_path
if loader_path is None:
loader_path = compute_loader_path(result.ref.file_path)
else:
logical_path, display_name, loader_path = None, None, None
asset_content_hash = result.asset.hash if result.asset else None
return schemas_out.Asset(
id=result.ref.id,
name=result.ref.name,
hash=asset_content_hash,
loader_path=loader_path,
logical_path=logical_path,
display_name=display_name,
asset_hash=asset_content_hash,
size=int(result.asset.size_bytes) if result.asset else None,
mime_type=result.asset.mime_type if result.asset else None,
@ -416,17 +433,6 @@ async def upload_asset(request: web.Request) -> web.Response:
400, "INVALID_BODY", f"Validation failed: {ve.json()}"
)
if spec.tags and spec.tags[0] == "models":
if (
len(spec.tags) < 2
or spec.tags[1] not in folder_paths.folder_names_and_paths
):
delete_temp_file_if_exists(parsed.tmp_path)
category = spec.tags[1] if len(spec.tags) >= 2 else ""
return _build_error_response(
400, "INVALID_BODY", f"unknown models category '{category}'"
)
try:
# Fast path: hash exists, create AssetReference without writing anything
if spec.hash and parsed.provided_hash_exists is True:
@ -470,7 +476,7 @@ async def upload_asset(request: web.Request) -> web.Response:
return _build_error_response(400, e.code, str(e))
except ValueError as e:
delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(400, "BAD_REQUEST", str(e))
return _build_error_response(400, "INVALID_BODY", str(e))
except HashMismatchError as e:
delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(400, "HASH_MISMATCH", str(e))

View File

@ -140,7 +140,7 @@ class CreateFromHashBody(BaseModel):
if v is None:
return []
if isinstance(v, list):
out = [str(t).strip().lower() for t in v if str(t).strip()]
out = [str(t).strip() for t in v if str(t).strip()]
seen = set()
dedup = []
for t in out:
@ -149,7 +149,7 @@ class CreateFromHashBody(BaseModel):
dedup.append(t)
return dedup
if isinstance(v, str):
return [t.strip().lower() for t in v.split(",") if t.strip()]
return list(dict.fromkeys(t.strip() for t in v.split(",") if t.strip()))
return []
@ -206,7 +206,7 @@ class TagsListQuery(BaseModel):
if v is None:
return v
v = v.strip()
return v.lower() or None
return v or None
class TagsAdd(BaseModel):
@ -220,7 +220,7 @@ class TagsAdd(BaseModel):
for t in v:
if not isinstance(t, str):
raise TypeError("tags must be strings")
tnorm = t.strip().lower()
tnorm = t.strip()
if tnorm:
out.append(tnorm)
seen = set()
@ -239,8 +239,8 @@ class TagsRemove(TagsAdd):
class UploadAssetSpec(BaseModel):
"""Upload Asset operation.
- tags: optional list; if provided, first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category
- tags: labels plus one destination role ('models'|'input'|'output') for new bytes;
if role == 'models', exactly one model_type:<folder_name> tag is required
- name: display name
- user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' for validation / fast-path
@ -309,7 +309,7 @@ class UploadAssetSpec(BaseModel):
norm = []
seen = set()
for t in items:
tnorm = str(t).strip().lower()
tnorm = str(t).strip()
if tnorm and tnorm not in seen:
seen.add(tnorm)
norm.append(tnorm)
@ -335,14 +335,4 @@ class UploadAssetSpec(BaseModel):
@model_validator(mode="after")
def _validate_order(self):
if not self.tags:
raise ValueError("at least one tag is required for uploads")
root = self.tags[0]
if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError(
"models uploads require a category tag as the second tag"
)
return self

View File

@ -9,8 +9,24 @@ class Asset(BaseModel):
``id`` here is the AssetReference id, not the content-addressed Asset id."""
id: str
name: str
name: str = Field(
...,
deprecated=True,
description="Reference label, often caller-provided or derived from the filename. Deprecated for storage path/display semantics; use `loader_path`, `logical_path`, and `display_name` when present.",
)
hash: str | None = None
loader_path: str | None = Field(
default=None,
description="In-root loader path for filesystem-backed assets: the path relative to its storage root with the top-level model category dropped (e.g. `models/checkpoints/foo/bar.safetensors` -> `foo/bar.safetensors`). This is the value model loaders consume. `None` when the file is not within a recognized root or model category.",
)
logical_path: str | None = Field(
default=None,
description="Runtime storage locator for filesystem-backed assets, using Comfy storage namespaces such as `input/`, `output/`, `temp/`, or `models/` (e.g. `models/checkpoints/foo/bar.safetensors`). Not an absolute filesystem path, unique identity, or model loader path.",
)
display_name: str | None = Field(
default=None,
description="Human-facing label derived from `logical_path`, usually the path below the top-level storage namespace. Not unique.",
)
asset_hash: str | None = None
size: int | None = None
mime_type: str | None = None

View File

@ -140,7 +140,6 @@ async def parse_multipart_upload(
provided_mime_type = ((await field.text()) or "").strip() or None
elif fname == "preview_id":
provided_preview_id = ((await field.text()) or "").strip() or None
if not file_present and not (provided_hash and provided_hash_exists):
raise UploadError(
400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'."

View File

@ -76,6 +76,10 @@ class AssetReference(Base):
# Cache state fields (from former AssetCacheState)
file_path: Mapped[str | None] = mapped_column(Text, nullable=True)
# In-root loader path derived from file_path at scan/ingest time (model
# category dropped). Persisted so responses read it directly instead of
# re-resolving against every registered model-folder base per request.
loader_path: Mapped[str | None] = mapped_column(Text, nullable=True)
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_missing: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)

View File

@ -650,6 +650,7 @@ def upsert_reference(
name: str,
mtime_ns: int,
owner_id: str = "",
loader_path: str | None = None,
) -> tuple[bool, bool]:
"""Upsert a reference by file_path. Returns (created, updated).
@ -659,6 +660,7 @@ def upsert_reference(
vals = {
"asset_id": asset_id,
"file_path": file_path,
"loader_path": loader_path,
"name": name,
"owner_id": owner_id,
"mtime_ns": int(mtime_ns),

View File

@ -265,6 +265,8 @@ def list_tags_with_usage(
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
prefix_filter = prefix.strip() if prefix else ""
counts_sq = (
select(
AssetReferenceTag.tag_name.label("tag_name"),
@ -293,9 +295,8 @@ def list_tags_with_usage(
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
escaped, esc = escape_sql_like_string(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if prefix_filter:
q = q.where(func.substr(Tag.name, 1, len(prefix_filter)) == prefix_filter)
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
@ -306,9 +307,8 @@ def list_tags_with_usage(
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
escaped, esc = escape_sql_like_string(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if prefix_filter:
total_q = total_q.where(func.substr(Tag.name, 1, len(prefix_filter)) == prefix_filter)
if not include_zero:
visible_tags_sq = (
select(AssetReferenceTag.tag_name)

View File

@ -41,10 +41,10 @@ def get_utc_now() -> datetime:
def normalize_tags(tags: list[str] | None) -> list[str]:
"""
Normalize a list of tags by:
- Stripping whitespace and converting to lowercase.
- Removing duplicates.
- Stripping whitespace.
- Removing exact duplicates while preserving order and case.
"""
return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip()))
return list(dict.fromkeys(t.strip() for t in (tags or []) if (t or "").strip()))
def validate_blake3_hash(s: str) -> str:

View File

@ -36,7 +36,7 @@ from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash
from app.assets.services.image_dimensions import extract_image_dimensions
from app.assets.services.metadata_extract import extract_file_metadata
from app.assets.services.path_utils import (
compute_relative_filename,
compute_loader_path,
get_comfy_models_folders,
get_name_and_tags_from_asset_path,
)
@ -308,7 +308,7 @@ def build_asset_specs(
if not stat_p.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(abs_p)
rel_fname = compute_relative_filename(abs_p)
rel_fname = compute_loader_path(abs_p)
# Extract metadata (tier 1: filesystem, tier 2: safetensors header)
metadata = None
@ -430,7 +430,7 @@ def enrich_asset(
return new_level
initial_mtime_ns = get_mtime_ns(stat_p)
rel_fname = compute_relative_filename(file_path)
rel_fname = compute_loader_path(file_path)
mime_type: str | None = None
metadata = None

View File

@ -38,7 +38,7 @@ from app.assets.database.queries import (
update_reference_updated_at,
)
from app.assets.helpers import select_best_live_path
from app.assets.services.path_utils import compute_relative_filename
from app.assets.services.path_utils import compute_loader_path
from app.assets.services.schemas import (
AssetData,
AssetDetailResult,
@ -91,7 +91,7 @@ def update_asset_metadata(
update_reference_name(session, reference_id=reference_id, name=name)
touched = True
computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None
computed_filename = compute_loader_path(ref.file_path) if ref.file_path else None
new_meta: dict | None = None
if user_metadata is not None:

View File

@ -56,6 +56,7 @@ class ReferenceRow(TypedDict):
id: str
asset_id: str
file_path: str
loader_path: str | None
mtime_ns: int
owner_id: str
name: str
@ -134,6 +135,14 @@ def batch_insert_seed_assets(
for spec in specs:
absolute_path = os.path.abspath(spec["abs_path"])
existing_asset_id = path_to_asset_id.get(absolute_path)
if existing_asset_id is not None:
existing_tags = asset_id_to_ref_data[existing_asset_id]["tags"]
asset_id_to_ref_data[existing_asset_id]["tags"] = list(
dict.fromkeys([*existing_tags, *spec["tags"]])
)
continue
asset_id = str(uuid.uuid4())
reference_id = str(uuid.uuid4())
absolute_path_list.append(absolute_path)
@ -164,6 +173,8 @@ def batch_insert_seed_assets(
"id": reference_id,
"asset_id": asset_id,
"file_path": absolute_path,
# spec["fname"] is compute_loader_path(abs_path) from build_asset_specs.
"loader_path": spec["fname"],
"mtime_ns": spec["mtime_ns"],
"owner_id": owner_id,
"name": spec["info_name"],

View File

@ -33,8 +33,9 @@ from app.assets.services.bulk_ingest import batch_insert_seed_assets
from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.image_dimensions import extract_image_dimensions
from app.assets.services.path_utils import (
compute_relative_filename,
compute_loader_path,
get_name_and_tags_from_asset_path,
get_path_derived_tags_from_path,
resolve_destination_from_tags,
validate_path_within_base,
)
@ -91,6 +92,7 @@ def _ingest_file_from_path(
name=info_name or os.path.basename(locator),
mtime_ns=mtime_ns,
owner_id=owner_id,
loader_path=compute_loader_path(locator),
)
# Get the reference we just created/updated
@ -101,17 +103,32 @@ def _ingest_file_from_path(
if preview_id and ref.preview_id != preview_id:
ref.preview_id = preview_id
norm = normalize_tags(list(tags))
if norm:
try:
backend_tags = get_path_derived_tags_from_path(locator)
except ValueError:
backend_tags = []
caller_tags = normalize_tags(tags)
backend_tags = normalize_tags(backend_tags)
all_tags = normalize_tags([*caller_tags, *backend_tags])
if all_tags:
if require_existing_tags:
validate_tags_exist(session, norm)
add_tags_to_reference(
session,
reference_id=reference_id,
tags=norm,
origin=tag_origin,
create_if_missing=not require_existing_tags,
)
validate_tags_exist(session, all_tags)
if backend_tags:
add_tags_to_reference(
session,
reference_id=reference_id,
tags=backend_tags,
origin="automatic",
create_if_missing=not require_existing_tags,
)
if caller_tags:
add_tags_to_reference(
session,
reference_id=reference_id,
tags=caller_tags,
origin=tag_origin,
create_if_missing=not require_existing_tags,
)
_update_metadata_with_filename(
session,
@ -288,7 +305,7 @@ def _register_existing_asset(
return result
new_meta = dict(user_metadata)
computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None
computed_filename = compute_loader_path(ref.file_path) if ref.file_path else None
if computed_filename:
new_meta["filename"] = computed_filename
@ -335,7 +352,7 @@ def _update_metadata_with_filename(
current_metadata: dict | None,
user_metadata: dict[str, Any],
) -> None:
computed_filename = compute_relative_filename(file_path) if file_path else None
computed_filename = compute_loader_path(file_path) if file_path else None
current_meta = current_metadata or {}
new_meta = dict(current_meta)
@ -474,6 +491,10 @@ def upload_from_temp_path(
existing = get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
# Once content is already known, duplicate byte uploads are treated as
# reference-only creation. Request tags are labels only here: do not
# require upload destination tags, do not move bytes, and do not
# synthesize path-derived classification or uploaded provenance.
with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
@ -535,7 +556,7 @@ def upload_from_temp_path(
owner_id=owner_id,
preview_id=preview_id,
user_metadata=user_metadata or {},
tags=tags,
tags=[*(tags or []), "uploaded"],
tag_origin="manual",
require_existing_tags=False,
)
@ -569,15 +590,19 @@ def register_file_in_place(
) -> UploadResult:
"""Register an already-saved file in the asset database without moving it.
Tags are derived from the filesystem path (root category + subfolder names),
merged with any caller-provided tags, matching the behavior of the scanner.
This helper is used by upload paths that have already written bytes before
registering the file, so it records the same ``uploaded`` tag as the
multipart byte-upload path.
Tags are derived from trusted filesystem classification and merged with any
caller-provided tags, matching the behavior of the scanner.
If the path is not under a known root, only the caller-provided tags are used.
"""
try:
_, path_tags = get_name_and_tags_from_asset_path(abs_path)
except ValueError:
path_tags = []
merged_tags = normalize_tags([*path_tags, *tags])
merged_tags = normalize_tags([*path_tags, *tags, "uploaded"])
try:
digest, _ = hashing.compute_blake3_hash(abs_path)

View File

@ -3,10 +3,10 @@ from pathlib import Path
from typing import Literal
import folder_paths
from app.assets.helpers import normalize_tags
_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"})
_NON_MODEL_FOLDER_NAMES = frozenset({"configs", "custom_nodes"})
_KNOWN_SUBFOLDER_TAGS = frozenset({"3d", "pasted", "painter", "threed", "webcam"})
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
@ -14,7 +14,7 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
Includes every category registered in folder_names_and_paths,
regardless of whether its paths are under the main models_dir,
but excludes non-model entries like custom_nodes.
but excludes non-model entries like configs and custom_nodes.
"""
targets: list[tuple[str, list[str]]] = []
for name, values in folder_paths.folder_names_and_paths.items():
@ -27,35 +27,37 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
if not tags:
raise ValueError("tags must not be empty")
root = tags[0].lower()
"""Validates and maps upload routing tags -> (base_dir, subdirs_for_fs).
The request tags are only used to choose the write destination. Extra tags
remain labels; they do not become path components or trusted classification.
"""
destination_roles = [t for t in tags if t in {"input", "models", "output"}]
if len(destination_roles) != 1:
raise ValueError("uploads require exactly one destination role: input, models, or output")
root = destination_roles[0]
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
model_type_tags = [t for t in tags if t.startswith("model_type:")]
if len(model_type_tags) != 1:
raise ValueError("models uploads require exactly one model_type:<folder_name> tag")
folder_name = model_type_tags[0].split(":", 1)[1]
if not folder_name:
raise ValueError("models uploads require exactly one model_type:<folder_name> tag")
model_folder_paths = dict(get_comfy_models_folders())
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
bases = model_folder_paths[folder_name]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
raise ValueError(f"unknown model category '{folder_name}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
raise ValueError(f"no base path configured for category '{folder_name}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
elif root == "input":
base_dir = os.path.abspath(folder_paths.get_input_directory())
raw_subdirs = tags[1:]
elif root == "output":
base_dir = os.path.abspath(folder_paths.get_output_directory())
raw_subdirs = tags[1:]
else:
raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'")
_sep_chars = frozenset(("/", "\\", os.sep))
for i in raw_subdirs:
if i in (".", "..") or _sep_chars & set(i):
raise ValueError("invalid path component in tags")
base_dir = os.path.abspath(folder_paths.get_output_directory())
return base_dir, raw_subdirs if raw_subdirs else []
return base_dir, []
def validate_path_within_base(candidate: str, base: str) -> None:
@ -65,14 +67,76 @@ def validate_path_within_base(candidate: str, base: str) -> None:
raise ValueError("destination escapes base directory")
def compute_relative_filename(file_path: str) -> str | None:
def _compute_relative_path(child: str, parent: str) -> str:
rel = os.path.relpath(os.path.abspath(child), os.path.abspath(parent))
if rel == ".":
return ""
return rel.replace(os.sep, "/")
def _is_relative_to(child: str, parent: str) -> bool:
return Path(os.path.abspath(child)).is_relative_to(os.path.abspath(parent))
def compute_asset_response_paths(file_path: str) -> tuple[str, str | None] | None:
"""Return public (file_path, display_name) response fields for a file path.
These fields are storage locators, not model-loader namespaces. Registered
model-folder membership is represented by backend tags such as
``model_type:<folder_name>``; response paths only use known storage roots.
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, eg:
fp_abs = os.path.abspath(file_path)
candidates: list[tuple[int, int, str, str]] = []
for order, (namespace, base) in enumerate(
(
("input", folder_paths.get_input_directory()),
("output", folder_paths.get_output_directory()),
("temp", folder_paths.get_temp_directory()),
("models", getattr(folder_paths, "models_dir", "")),
)
):
if not base:
continue
base_abs = os.path.abspath(base)
if _is_relative_to(fp_abs, base_abs):
candidates.append((len(base_abs), -order, namespace, base_abs))
if not candidates:
return None
_base_len, _order, namespace, base = max(candidates)
rel = _compute_relative_path(fp_abs, base)
public_path = f"{namespace}/{rel}" if rel else namespace
return public_path, rel or None
def compute_display_name(file_path: str) -> str | None:
"""Return the asset's `display_name`, or None for unknown paths."""
result = compute_asset_response_paths(file_path)
return result[1] if result else None
def compute_logical_path(file_path: str) -> str | None:
"""Return the asset's namespaced storage `logical_path`, or None for unknown paths."""
result = compute_asset_response_paths(file_path)
return result[0] if result else None
def compute_loader_path(file_path: str) -> str | None:
"""
Return the asset's in-root loader path: the path relative to the last
well-known folder (the model category), using forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
For non-model paths, returns None.
This is the value model loaders consume (the model category is dropped). It
backs the public Asset response `file_path` field and the internal
``computed_filename`` metadata. The namespaced storage locator (`logical_path`)
and human-facing `display_name` come from compute_asset_response_paths().
For input/output/temp paths the full path relative to that root is returned.
For paths outside any known root, returns None.
"""
try:
root_category, rel_path = get_asset_category_and_relative_path(file_path)
@ -116,9 +180,10 @@ def get_asset_category_and_relative_path(
def _compute_relative(child: str, parent: str) -> str:
# Normalize relative path, stripping any leading ".." components
# by anchoring to root (os.sep) then computing relpath back from it.
return os.path.relpath(
rel = os.path.relpath(
os.path.join(os.sep, os.path.relpath(child, parent)), os.sep
)
return "" if rel == "." else rel.replace(os.sep, "/")
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
@ -149,25 +214,99 @@ def get_asset_category_and_relative_path(
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
normalized = os.path.relpath(os.path.join(os.sep, combined), os.sep)
return "models", normalized.replace(os.sep, "/")
raise ValueError(
f"Path is not within input, output, temp, or configured model bases: {file_path}"
)
def get_backend_system_tags_from_path(path: str) -> list[str]:
"""Return trusted backend tags derived from current filesystem facts.
The returned tags are only the backend-generated system tags: ``models``,
``model_type:<folder_name>``, ``input``, ``output``, and ``temp``. Model
type tags are based on registered folder names, not path components.
"""
fp_abs = os.path.abspath(path)
fp_path = Path(fp_abs)
tags: list[str] = []
def _add(tag: str) -> None:
if tag not in tags:
tags.append(tag)
for role, base in (
("input", folder_paths.get_input_directory()),
("output", folder_paths.get_output_directory()),
("temp", folder_paths.get_temp_directory()),
):
if fp_path.is_relative_to(os.path.abspath(base)):
_add(role)
model_types: list[str] = []
for folder_name, bases in get_comfy_models_folders():
for base in bases:
if fp_path.is_relative_to(os.path.abspath(base)):
model_types.append(folder_name)
break
if model_types:
_add("models")
for folder_name in model_types:
_add(f"model_type:{folder_name}")
if not tags:
raise ValueError(
f"Path is not within input, output, temp, or configured model bases: {path}"
)
return tags
def get_known_subfolder_tags(subfolder: str | None) -> list[str]:
"""Return tags for known UI/input subfolder names."""
if subfolder in _KNOWN_SUBFOLDER_TAGS:
return [subfolder]
return []
def get_known_input_subfolder_tags_from_path(path: str) -> list[str]:
"""Return known input-layout tags for files in canonical input subfolders.
These are compatibility tags for current UI-origin input directories such as
``pasted`` and ``webcam``. They are intentionally narrow: only files directly
inside a known top-level input directory receive the matching tag.
"""
fp_abs = os.path.abspath(path)
input_base = os.path.abspath(folder_paths.get_input_directory())
if not Path(fp_abs).is_relative_to(input_base):
return []
rel = os.path.relpath(fp_abs, input_base)
parts = Path(rel).parts
if len(parts) == 2:
return get_known_subfolder_tags(parts[0])
return []
def get_path_derived_tags_from_path(path: str) -> list[str]:
"""Return all backend-derived tags for an asset path."""
tags = get_backend_system_tags_from_path(path)
for tag in get_known_input_subfolder_tags_from_path(path):
if tag not in tags:
tags.append(tag)
return tags
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return (name, tags) derived from a filesystem path.
- name: base filename with extension
- tags: [root_category] + parent folder names in order
- tags: backend-derived tags from root/model classification and known input
subfolder layout conventions
Raises:
ValueError: path does not belong to any known root.
"""
root_category, some_path = get_asset_category_and_relative_path(file_path)
p = Path(some_path)
parent_parts = [
part for part in p.parent.parts if part not in (".", "..", p.anchor)
]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
return Path(file_path).name, get_path_derived_tags_from_path(file_path)

View File

@ -25,6 +25,7 @@ class ReferenceData:
preview_id: str | None
created_at: datetime
updated_at: datetime
loader_path: str | None = None
system_metadata: dict[str, Any] | None = None
job_id: str | None = None
last_access_time: datetime | None = None
@ -93,6 +94,7 @@ def extract_reference_data(ref: AssetReference) -> ReferenceData:
id=ref.id,
name=ref.name,
file_path=ref.file_path,
loader_path=ref.loader_path,
user_metadata=ref.user_metadata,
preview_id=ref.preview_id,
system_metadata=ref.system_metadata,

View File

@ -256,7 +256,7 @@ def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, w
if (want_requant and len(fns) == 0 or update_weight):
seed = comfy.utils.string_to_seed(s.seed_key)
if isinstance(orig, QuantizedTensor):
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
y = orig.requantize_from_float(x, scale="recalculate", stochastic_rounding=seed)
else:
y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
if want_requant and len(fns) == 0:
@ -1216,7 +1216,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
bias_dtype=input.dtype,
offloadable=True,
compute_dtype=compute_dtype,
want_requant=want_requant,
want_requant=True,
)
weight = weight.to(dtype=input.dtype)
else:
@ -1306,8 +1306,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None:
# dtype is now implicit in the layout class
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
weight = self.weight.requantize_from_float(weight, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
else:
weight = weight.to(self.weight.dtype)
if return_weight:

View File

@ -100,6 +100,7 @@ def _parse_cli_feature_flags() -> dict[str, Any]:
# Default server capabilities
_CORE_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True,
"supports_model_type_tags": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
"node_replacements": True,

View File

@ -1,85 +1,68 @@
import os
import sys
import re
import ctypes
import logging
import ctypes.util
import importlib.util
from typing import TypedDict
import numpy as np
import torch
import nodes
import comfy_angle
from comfy_api.latest import ComfyExtension, io, ui
from typing_extensions import override
from utils.install_util import get_missing_requirements_message
logger = logging.getLogger(__name__)
def _check_opengl_availability():
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
logger.debug("_check_opengl_availability: starting")
missing = []
def _preload_angle():
egl_path = comfy_angle.get_egl_path()
gles_path = comfy_angle.get_glesv2_path()
# Check Python packages (using find_spec to avoid importing)
logger.debug("_check_opengl_availability: checking for glfw package")
if importlib.util.find_spec("glfw") is None:
missing.append("glfw")
if sys.platform == "win32":
angle_dir = comfy_angle.get_lib_dir()
os.add_dll_directory(angle_dir)
os.environ["PATH"] = angle_dir + os.pathsep + os.environ.get("PATH", "")
logger.debug("_check_opengl_availability: checking for OpenGL package")
if importlib.util.find_spec("OpenGL") is None:
missing.append("PyOpenGL")
if missing:
raise RuntimeError(
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
)
# On Linux without display, check if headless backends are available
logger.debug(f"_check_opengl_availability: platform={sys.platform}")
if sys.platform.startswith("linux"):
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
if not has_display:
# Check for EGL or OSMesa libraries
logger.debug("_check_opengl_availability: checking for EGL library")
has_egl = ctypes.util.find_library("EGL")
logger.debug("_check_opengl_availability: checking for OSMesa library")
has_osmesa = ctypes.util.find_library("OSMesa")
# Error disabled for CI as it fails this check
# if not has_egl and not has_osmesa:
# raise RuntimeError(
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
# "See error below for installation instructions."
# )
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
logger.debug("_check_opengl_availability: completed")
mode = 0 if sys.platform == "win32" else ctypes.RTLD_GLOBAL
ctypes.CDLL(str(egl_path), mode=mode)
ctypes.CDLL(str(gles_path), mode=mode)
# Run early check at import time
logger.debug("nodes_glsl: running _check_opengl_availability at import time")
_check_opengl_availability()
# OpenGL modules - initialized lazily when context is created
gl = None
glfw = None
EGL = None
# Pre-load ANGLE *before* any PyOpenGL import so that the EGL platform
# plugin picks up ANGLE's libEGL / libGLESv2 instead of system libs.
_preload_angle()
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
def _import_opengl():
"""Import OpenGL module. Called after context is created."""
global gl
if gl is None:
logger.debug("_import_opengl: importing OpenGL.GL")
import OpenGL.GL as _gl
gl = _gl
logger.debug("_import_opengl: import completed")
return gl
import OpenGL
OpenGL.USE_ACCELERATE = False
def _patch_find_library():
"""PyOpenGL's EGL platform looks for 'EGL' and 'GLESv2' by short name
via ctypes.util.find_library, but ANGLE ships as 'libEGL' and
'libGLESv2'. Patch find_library to return the full ANGLE paths so
PyOpenGL loads the same libraries we pre-loaded."""
if sys.platform == "linux":
return
import ctypes.util
_orig = ctypes.util.find_library
def _patched(name):
if name == 'EGL':
return comfy_angle.get_egl_path()
if name == 'GLESv2':
return comfy_angle.get_glesv2_path()
return _orig(name)
ctypes.util.find_library = _patched
_patch_find_library()
from OpenGL import EGL
from OpenGL import GLES3 as gl
class SizeModeInput(TypedDict):
size_mode: str
width: int
@ -102,7 +85,7 @@ MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
# (-1,-1)---(3,-1)
#
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
VERTEX_SHADER = """#version 330 core
VERTEX_SHADER = """#version 300 es
out vec2 v_texCoord;
void main() {
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
@ -126,14 +109,99 @@ void main() {
"""
def _convert_es_to_desktop(source: str) -> str:
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
# Remove any existing #version directive
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
# Remove precision qualifiers (not needed in desktop GLSL)
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
# Prepend desktop GLSL version
return "#version 330 core\n" + source
def _egl_attribs(*values):
"""Build an EGL_NONE-terminated EGLint attribute array."""
vals = list(values) + [EGL.EGL_NONE]
return (ctypes.c_int32 * len(vals))(*vals)
# EGL platform extension constants
EGL_PLATFORM_ANGLE_ANGLE = 0x3202
EGL_PLATFORM_ANGLE_TYPE_ANGLE = 0x3203
EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE = 0x3450
EGL_MESA_PLATFORM_SURFACELESS = 0x31DD
_eglGetPlatformDisplayEXT = None
def _get_egl_platform_display_ext(platform, native_display, attribs):
"""Call eglGetPlatformDisplayEXT via ctypes (extension, not in PyOpenGL)."""
global _eglGetPlatformDisplayEXT
if _eglGetPlatformDisplayEXT is None:
from OpenGL import platform as _plat
egl_lib = _plat.PLATFORM.EGL
_get_proc = egl_lib.eglGetProcAddress
_get_proc.restype = ctypes.c_void_p
_get_proc.argtypes = [ctypes.c_char_p]
ptr = _get_proc(b"eglGetPlatformDisplayEXT")
if not ptr:
return None
func_type = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_uint32, ctypes.c_void_p, ctypes.c_void_p)
_eglGetPlatformDisplayEXT = func_type(ptr)
raw = _eglGetPlatformDisplayEXT(platform, native_display, attribs)
if not raw:
return None
return ctypes.cast(raw, EGL.EGLDisplay)
def _get_egl_display():
"""Get an EGL display, trying the default first then ANGLE's Vulkan
platform for headless environments without a display server."""
failures = []
# Try the default display first (works when X11/Wayland is available)
display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY)
if display:
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
try:
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
return display, major.value, minor.value
except Exception as e:
failures.append(f"default: {e}")
logger.info("Default EGL display unavailable, trying headless fallbacks")
# Headless fallback strategies, tried in order:
headless_strategies = [
("surfaceless", EGL_MESA_PLATFORM_SURFACELESS, None, None),
("ANGLE Vulkan", EGL_PLATFORM_ANGLE_ANGLE, None,
_egl_attribs(EGL_PLATFORM_ANGLE_TYPE_ANGLE, EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE)),
]
for name, platform, native_display, attribs in headless_strategies:
display = _get_egl_platform_display_ext(platform, native_display, attribs)
if not display:
failures.append(f"{name}: eglGetPlatformDisplayEXT returned no display")
continue
major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
try:
if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)):
logger.info(f"Using EGL {name} platform (headless)")
return display, major.value, minor.value
failures.append(f"{name}: eglInitialize returned false")
except Exception as e:
failures.append(f"{name}: {e}")
continue
details = "\n".join(f" - {f}" for f in failures)
raise RuntimeError(
"Failed to initialize EGL display.\n"
"No display server and no headless EGL platform available.\n"
f"Tried:\n{details}\n"
"Ensure GPU drivers are installed or set DISPLAY for a virtual framebuffer."
)
def _gl_str(name):
"""Get an OpenGL string parameter."""
v = gl.glGetString(name)
if not v:
return "Unknown"
if isinstance(v, bytes):
return v.decode(errors="replace")
return ctypes.string_at(v).decode(errors="replace")
def _detect_output_count(source: str) -> int:
@ -159,163 +227,8 @@ def _detect_pass_count(source: str) -> int:
return 1
def _init_glfw():
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
logger.debug("_init_glfw: starting")
# On macOS, glfw.init() must be called from main thread or it hangs forever
if sys.platform == "darwin":
logger.debug("_init_glfw: skipping on macOS")
raise RuntimeError("GLFW backend not supported on macOS")
logger.debug("_init_glfw: importing glfw module")
import glfw as _glfw
logger.debug("_init_glfw: calling glfw.init()")
if not _glfw.init():
raise RuntimeError("glfw.init() failed")
try:
logger.debug("_init_glfw: setting window hints")
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
logger.debug("_init_glfw: calling create_window()")
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
if not window:
raise RuntimeError("glfw.create_window() failed")
logger.debug("_init_glfw: calling make_context_current()")
_glfw.make_context_current(window)
logger.debug("_init_glfw: completed successfully")
return window, _glfw
except Exception:
logger.debug("_init_glfw: failed, terminating glfw")
_glfw.terminate()
raise
def _init_egl():
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
logger.debug("_init_egl: starting")
from OpenGL import EGL as _EGL
from OpenGL.EGL import (
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
eglTerminate, eglDestroyContext, eglDestroySurface,
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
)
logger.debug("_init_egl: imports completed")
display = None
context = None
surface = None
try:
logger.debug("_init_egl: calling eglGetDisplay()")
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
if display == _EGL.EGL_NO_DISPLAY:
raise RuntimeError("eglGetDisplay() failed")
logger.debug("_init_egl: calling eglInitialize()")
major, minor = _EGL.EGLint(), _EGL.EGLint()
if not eglInitialize(display, major, minor):
display = None # Not initialized, don't terminate
raise RuntimeError("eglInitialize() failed")
logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
config_attribs = [
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
EGL_DEPTH_SIZE, 0, EGL_NONE
]
configs = (_EGL.EGLConfig * 1)()
num_configs = _EGL.EGLint()
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
raise RuntimeError("eglChooseConfig() failed")
config = configs[0]
logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
if not eglBindAPI(EGL_OPENGL_API):
raise RuntimeError("eglBindAPI() failed")
logger.debug("_init_egl: calling eglCreateContext()")
context_attribs = [
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
EGL_NONE
]
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
if context == EGL_NO_CONTEXT:
raise RuntimeError("eglCreateContext() failed")
logger.debug("_init_egl: calling eglCreatePbufferSurface()")
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
if surface == _EGL.EGL_NO_SURFACE:
raise RuntimeError("eglCreatePbufferSurface() failed")
logger.debug("_init_egl: calling eglMakeCurrent()")
if not eglMakeCurrent(display, surface, surface, context):
raise RuntimeError("eglMakeCurrent() failed")
logger.debug("_init_egl: completed successfully")
return display, context, surface, _EGL
except Exception:
logger.debug("_init_egl: failed, cleaning up")
# Clean up any resources on failure
if surface is not None:
eglDestroySurface(display, surface)
if context is not None:
eglDestroyContext(display, context)
if display is not None:
eglTerminate(display)
raise
def _init_osmesa():
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
import ctypes
logger.debug("_init_osmesa: starting")
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
logger.debug("_init_osmesa: importing OpenGL.osmesa")
from OpenGL import GL as _gl
from OpenGL.osmesa import (
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
OSMESA_RGBA,
)
logger.debug("_init_osmesa: imports completed")
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
if not ctx:
raise RuntimeError("OSMesaCreateContextExt() failed")
width, height = 64, 64
buffer = (ctypes.c_ubyte * (width * height * 4))()
logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
OSMesaDestroyContext(ctx)
raise RuntimeError("OSMesaMakeCurrent() failed")
logger.debug("_init_osmesa: completed successfully")
return ctx, buffer
class GLContext:
"""Manages OpenGL context and resources for shader execution.
Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software).
"""
"""Manages an OpenGL ES 3.0 context via EGL/ANGLE (singleton)."""
_instance = None
_initialized = False
@ -327,131 +240,105 @@ class GLContext:
def __init__(self):
if GLContext._initialized:
logger.debug("GLContext.__init__: already initialized, skipping")
return
logger.debug("GLContext.__init__: starting initialization")
global glfw, EGL
import time
start = time.perf_counter()
self._backend = None
self._window = None
self._egl_display = None
self._egl_context = None
self._egl_surface = None
self._osmesa_ctx = None
self._osmesa_buffer = None
self._display = None
self._surface = None
self._context = None
self._vao = None
# Try backends in order: GLFW → EGL → OSMesa
errors = []
logger.debug("GLContext.__init__: trying GLFW backend")
try:
self._window, glfw = _init_glfw()
self._backend = "glfw"
logger.debug("GLContext.__init__: GLFW backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
errors.append(("GLFW", e))
self._display, self._egl_major, self._egl_minor = _get_egl_display()
if self._backend is None:
logger.debug("GLContext.__init__: trying EGL backend")
try:
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
self._backend = "egl"
logger.debug("GLContext.__init__: EGL backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
errors.append(("EGL", e))
if not EGL.eglBindAPI(EGL.EGL_OPENGL_ES_API):
raise RuntimeError("eglBindAPI(EGL_OPENGL_ES_API) failed")
if self._backend is None:
logger.debug("GLContext.__init__: trying OSMesa backend")
try:
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
self._backend = "osmesa"
logger.debug("GLContext.__init__: OSMesa backend succeeded")
except Exception as e:
logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
errors.append(("OSMesa", e))
config = EGL.EGLConfig()
n_configs = ctypes.c_int32(0)
if not EGL.eglChooseConfig(
self._display,
_egl_attribs(
EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_ES3_BIT,
EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT,
EGL.EGL_RED_SIZE, 8, EGL.EGL_GREEN_SIZE, 8,
EGL.EGL_BLUE_SIZE, 8, EGL.EGL_ALPHA_SIZE, 8,
),
ctypes.byref(config), 1, ctypes.byref(n_configs),
) or n_configs.value == 0:
raise RuntimeError("eglChooseConfig() failed")
if self._backend is None:
if sys.platform == "win32":
platform_help = (
"Windows: Ensure GPU drivers are installed and display is available.\n"
" CPU-only/headless mode is not supported on Windows."
)
elif sys.platform == "darwin":
platform_help = (
"macOS: GLFW is not supported.\n"
" Install OSMesa via Homebrew: brew install mesa\n"
" Then: pip install PyOpenGL PyOpenGL-accelerate"
)
else:
platform_help = (
"Linux: Install one of these backends:\n"
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
" Headless (CPU): sudo apt install libosmesa6"
)
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
raise RuntimeError(
f"Failed to create OpenGL context.\n\n"
f"Backend errors:\n{error_details}\n\n"
f"{platform_help}"
self._surface = EGL.eglCreatePbufferSurface(
self._display, config,
_egl_attribs(EGL.EGL_WIDTH, 64, EGL.EGL_HEIGHT, 64),
)
if not self._surface:
raise RuntimeError("eglCreatePbufferSurface() failed")
# Now import OpenGL.GL (after context is current)
logger.debug("GLContext.__init__: importing OpenGL.GL")
_import_opengl()
self._context = EGL.eglCreateContext(
self._display, config, EGL.EGL_NO_CONTEXT,
_egl_attribs(EGL.EGL_CONTEXT_CLIENT_VERSION, 3),
)
if not self._context:
raise RuntimeError("eglCreateContext() failed")
# Create VAO (required for core profile, but OSMesa may use compat profile)
logger.debug("GLContext.__init__: creating VAO")
try:
vao = gl.glGenVertexArrays(1)
gl.glBindVertexArray(vao)
self._vao = vao # Only store after successful bind
logger.debug("GLContext.__init__: VAO created successfully")
except Exception as e:
logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
# OSMesa with older Mesa may not support VAOs
# Clean up if we created but couldn't bind
if vao:
try:
gl.glDeleteVertexArrays(1, [vao])
except Exception:
pass
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
raise RuntimeError("eglMakeCurrent() failed")
self._vao = gl.glGenVertexArrays(1)
gl.glBindVertexArray(self._vao)
except Exception:
self._cleanup()
raise
elapsed = (time.perf_counter() - start) * 1000
# Log device info
renderer = gl.glGetString(gl.GL_RENDERER)
vendor = gl.glGetString(gl.GL_VENDOR)
version = gl.glGetString(gl.GL_VERSION)
renderer = renderer.decode() if renderer else "Unknown"
vendor = vendor.decode() if vendor else "Unknown"
version = version.decode() if version else "Unknown"
renderer = _gl_str(gl.GL_RENDERER)
vendor = _gl_str(gl.GL_VENDOR)
version = _gl_str(gl.GL_VERSION)
GLContext._initialized = True
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
logger.info(f"GLSL context initialized in {elapsed:.1f}ms - EGL {self._egl_major}.{self._egl_minor}, {renderer} ({vendor}), GL {version}")
def make_current(self):
if self._backend == "glfw":
glfw.make_context_current(self._window)
elif self._backend == "egl":
from OpenGL.EGL import eglMakeCurrent
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
elif self._backend == "osmesa":
from OpenGL.osmesa import OSMesaMakeCurrent
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
err = EGL.eglGetError()
raise RuntimeError(f"eglMakeCurrent() failed (EGL error: 0x{err:04X})")
if self._vao is not None:
gl.glBindVertexArray(self._vao)
def _cleanup(self):
if not self._display:
return
try:
if self._vao is not None:
gl.glDeleteVertexArrays(1, [self._vao])
self._vao = None
except Exception:
pass
try:
EGL.eglMakeCurrent(self._display, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
except Exception:
pass
try:
if self._context:
EGL.eglDestroyContext(self._display, self._context)
except Exception:
pass
try:
if self._surface:
EGL.eglDestroySurface(self._display, self._surface)
except Exception:
pass
try:
EGL.eglTerminate(self._display)
except Exception:
pass
self._display = None
def _compile_shader(source: str, shader_type: int) -> int:
"""Compile a shader and return its ID."""
@ -459,8 +346,10 @@ def _compile_shader(source: str, shader_type: int) -> int:
gl.glShaderSource(shader, source)
gl.glCompileShader(shader)
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
error = gl.glGetShaderInfoLog(shader).decode()
if not gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS):
error = gl.glGetShaderInfoLog(shader)
if isinstance(error, bytes):
error = error.decode(errors="replace")
gl.glDeleteShader(shader)
raise RuntimeError(f"Shader compilation failed:\n{error}")
@ -484,8 +373,10 @@ def _create_program(vertex_source: str, fragment_source: str) -> int:
gl.glDeleteShader(vertex_shader)
gl.glDeleteShader(fragment_shader)
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
error = gl.glGetProgramInfoLog(program).decode()
if not gl.glGetProgramiv(program, gl.GL_LINK_STATUS):
error = gl.glGetProgramInfoLog(program)
if isinstance(error, bytes):
error = error.decode(errors="replace")
gl.glDeleteProgram(program)
raise RuntimeError(f"Program linking failed:\n{error}")
@ -530,9 +421,6 @@ def _render_shader_batch(
ctx = GLContext()
ctx.make_current()
# Convert from GLSL ES to desktop GLSL 330
fragment_source = _convert_es_to_desktop(fragment_code)
# Detect how many outputs the shader actually uses
num_outputs = _detect_output_count(fragment_code)
@ -558,9 +446,9 @@ def _render_shader_batch(
try:
# Compile shaders (once for all batches)
try:
program = _create_program(VERTEX_SHADER, fragment_source)
program = _create_program(VERTEX_SHADER, fragment_code)
except RuntimeError:
logger.error(f"Fragment shader:\n{fragment_source}")
logger.error(f"Fragment shader:\n{fragment_code}")
raise
gl.glUseProgram(program)
@ -723,13 +611,13 @@ def _render_shader_batch(
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
# Read back outputs for this batch
# (glGetTexImage is synchronous, implicitly waits for rendering)
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
batch_outputs = []
for tex in output_textures:
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
batch_outputs.append(img[::-1, :, :].copy())
for i in range(num_outputs):
gl.glReadBuffer(gl.GL_COLOR_ATTACHMENT0 + i)
buf = np.empty((height, width, 4), dtype=np.float32)
gl.glReadPixels(0, 0, width, height, gl.GL_RGBA, gl.GL_FLOAT, buf)
batch_outputs.append(buf[::-1, :, :].copy())
# Pad with black images for unused outputs
black_img = np.zeros((height, width, 4), dtype=np.float32)
@ -750,18 +638,18 @@ def _render_shader_batch(
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
gl.glUseProgram(0)
for tex in input_textures:
gl.glDeleteTextures(int(tex))
for tex in curve_textures:
gl.glDeleteTextures(int(tex))
for tex in output_textures:
gl.glDeleteTextures(int(tex))
for tex in ping_pong_textures:
gl.glDeleteTextures(int(tex))
if input_textures:
gl.glDeleteTextures(len(input_textures), input_textures)
if curve_textures:
gl.glDeleteTextures(len(curve_textures), curve_textures)
if output_textures:
gl.glDeleteTextures(len(output_textures), output_textures)
if ping_pong_textures:
gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures)
if fbo is not None:
gl.glDeleteFramebuffers(1, [fbo])
for pp_fbo in ping_pong_fbos:
gl.glDeleteFramebuffers(1, [pp_fbo])
if ping_pong_fbos:
gl.glDeleteFramebuffers(len(ping_pong_fbos), ping_pong_fbos)
if program is not None:
gl.glDeleteProgram(program)

View File

@ -1113,32 +1113,6 @@ def full_type_name(klass):
return klass.__qualname__
return module + '.' + klass.__qualname__
def node_not_executable_reason(class_def, class_type):
"""Return a human-readable reason the node cannot be executed, or None if it's fine.
Catches a node whose declared entry point doesn't resolve to a real method
(e.g. a V1 ``FUNCTION = "invert"`` where the method is misspelled, or a V3 node
missing its ``execute`` override). Running this during validation surfaces the
problem before execution starts, instead of after upstream nodes have run.
Only the class is inspected; the node is never instantiated here, so a node's
``__init__`` side effects cannot run (or fail) during validation.
"""
try:
if issubclass(class_def, _ComfyNodeInternal):
# V3: validates that execute()/define_schema() overrides exist.
class_def.VALIDATE_CLASS()
return None
# V1: FUNCTION names the method to call; it must exist on the class.
function_name = getattr(class_def, "FUNCTION", None)
if function_name is None:
return f"'{class_type}' does not define FUNCTION"
if not callable(getattr(class_def, function_name, None)):
return f"'{class_type}' has no method '{function_name}' (declared in FUNCTION)"
return None
except Exception as ex:
return str(ex)
async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]):
outputs = set()
for x in prompt:
@ -1174,35 +1148,6 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[
}
return (False, error, [], {})
# Make sure the node is actually executable (its FUNCTION/execute entry
# point resolves to a real method) before we touch any schema-derived
# attributes below or start execution. Catches code typos up front and
# attributes the error to the offending node.
not_executable = node_not_executable_reason(class_, class_type)
if not_executable is not None:
node_title = prompt[x].get('_meta', {}).get('title', class_type)
error = {
"type": "invalid_node_definition",
"message": "Node is not executable",
"details": f"{not_executable} (Node ID '#{x}')",
"extra_info": {
"node_id": x,
"class_type": class_type,
"node_title": node_title,
}
}
node_errors = {x: {
"errors": [{
"type": "invalid_node_definition",
"message": "Node is not executable",
"details": not_executable,
"extra_info": {},
}],
"dependent_outputs": [],
"class_type": class_type,
}}
return (False, error, [], node_errors)
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
if partial_execution_list is None or x in partial_execution_list:
outputs.add(x)

View File

@ -7,18 +7,22 @@ components:
description: Timestamp when the asset was created
format: date-time
type: string
display_name:
description: Display name of the asset. Mirrors name for backwards compatibility.
nullable: true
type: string
file_path:
description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors")
nullable: true
type: string
hash:
description: Blake3 hash of the asset content.
pattern: ^blake3:[a-f0-9]{64}$
type: string
loader_path:
description: 'In-root loader path for filesystem-backed assets: the path relative to its storage root with the top-level model category dropped (e.g. `models/checkpoints/foo/bar.safetensors` -> `foo/bar.safetensors`). This is the value model loaders consume. `None` when the file is not within a recognized root or model category.'
nullable: true
type: string
logical_path:
description: Runtime storage locator for filesystem-backed assets, using Comfy storage namespaces such as `input/`, `output/`, `temp/`, or `models/` (e.g. `models/checkpoints/foo/bar.safetensors`). Not an absolute filesystem path, unique identity, or model loader path.
nullable: true
type: string
display_name:
description: Human-facing label derived from `logical_path`, usually the path below the top-level storage namespace. Not unique.
nullable: true
type: string
id:
description: Unique identifier for the asset
format: uuid
@ -144,14 +148,6 @@ components:
AssetUpdated:
description: Response returned when an existing asset is successfully updated.
properties:
display_name:
description: Display name of the asset. Mirrors name for backwards compatibility.
nullable: true
type: string
file_path:
description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors")
nullable: true
type: string
hash:
description: Blake3 hash of the asset content.
pattern: ^blake3:[a-f0-9]{64}$
@ -1644,7 +1640,7 @@ paths:
format: uuid
type: string
tags:
description: JSON-encoded array of freeform tag strings, e.g. '["models","checkpoint"]'. Common types include "models", "input", "output", and "temp", but any tag can be used in any order.
description: JSON-encoded array of tag strings. For new byte uploads, include exactly one destination role (`input`, `output`, or `models`); `models` uploads also require exactly one `model_type:<folder_name>` tag. Extra tags are stored as labels and do not create path components.
type: string
user_metadata:
description: Custom JSON metadata as a string
@ -1829,7 +1825,7 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/AssetUpdated'
$ref: '#/components/schemas/Asset'
description: Asset updated successfully
"400":
content:
@ -2470,6 +2466,9 @@ paths:
supports_preview_metadata:
description: Whether the server supports preview metadata
type: boolean
supports_model_type_tags:
description: Whether the server supports namespaced model type asset tags
type: boolean
type: object
description: Success
headers:

View File

@ -1,6 +1,6 @@
comfyui-frontend-package==1.45.19
comfyui-workflow-templates==0.10.7
comfyui-embedded-docs==0.5.5
comfyui-embedded-docs==0.5.6
torch
torchsde
torchvision
@ -22,7 +22,7 @@ alembic
SQLAlchemy>=2.0.0
filelock
av>=16.0.0
comfy-kitchen==0.2.12
comfy-kitchen==0.2.14
comfy-aimdo==0.4.10
requests
simpleeval>=1.0.0
@ -33,5 +33,5 @@ kornia>=0.7.1
spandrel
pydantic~=2.0
pydantic-settings~=2.0
PyOpenGL
glfw
PyOpenGL>=3.1.8
comfy-angle

View File

@ -46,6 +46,7 @@ from comfy_api.internal import _ComfyNodeInternal
from app.assets.seeder import asset_seeder
from app.assets.api.routes import register_assets_routes
from app.assets.services.ingest import register_file_in_place
from app.assets.services.path_utils import get_known_subfolder_tags
from app.assets.services.asset_management import resolve_hash_to_path
from app.user_manager import UserManager
@ -440,7 +441,9 @@ class PromptServer():
if args.enable_assets:
try:
tag = image_upload_type if image_upload_type in ("input", "output") else "input"
result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag])
tags = [tag]
tags.extend(get_known_subfolder_tags(subfolder))
result = register_file_in_place(abs_path=filepath, name=filename, tags=tags)
resp["asset"] = {
"id": result.ref.id,
"name": result.ref.name,

View File

@ -8,6 +8,7 @@ upgrade/downgrade for 0003+.
"""
import os
import sqlite3
import pytest
from alembic import command
@ -30,6 +31,12 @@ def _make_config(db_path: str) -> Config:
return cfg
def _sqlite_path(cfg: Config) -> str:
url = cfg.get_main_option("sqlalchemy.url")
assert url is not None and url.startswith("sqlite:///")
return url.removeprefix("sqlite:///")
@pytest.fixture
def migration_db(tmp_path):
"""Yield an alembic Config pre-upgraded to the baseline revision."""
@ -55,3 +62,26 @@ def test_upgrade_downgrade_cycle(migration_db):
command.upgrade(migration_db, "head")
command.downgrade(migration_db, _BASELINE)
command.upgrade(migration_db, "head")
def test_case_sensitive_tags_downgrade_normalizes_existing_tags(migration_db):
"""Downgrading 0005 folds mixed-case tag vocabulary before restoring CHECK."""
command.upgrade(migration_db, "0005_allow_case_sensitive_tags")
db_path = _sqlite_path(migration_db)
with sqlite3.connect(db_path) as conn:
conn.execute("INSERT INTO tags(name) VALUES (?)", ("NewTag",))
conn.execute("INSERT INTO tags(name) VALUES (?)", ("newtag",))
conn.execute("INSERT INTO tags(name) VALUES (?)", ("model_type:LLM",))
command.downgrade(migration_db, "0004_drop_tag_type")
with sqlite3.connect(db_path) as conn:
tags = {row[0] for row in conn.execute("SELECT name FROM tags")}
assert "newtag" in tags
assert "model_type:llm" in tags
assert "NewTag" not in tags
assert "model_type:LLM" not in tags
with pytest.raises(sqlite3.IntegrityError):
conn.execute("INSERT INTO tags(name) VALUES (?)", ("Upper",))

View File

@ -234,7 +234,7 @@ def seeded_asset(request: pytest.FixtureRequest, http: requests.Session, api_bas
p = getattr(request, "param", {}) or {}
tags: Optional[list[str]] = p.get("tags")
if tags is None:
tags = ["models", "checkpoints", "unit-tests", "alpha"]
tags = ["models", "model_type:checkpoints", "unit-tests", "alpha"]
meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None}
# Unique content per test so the seed always creates a fresh asset (201).
# Delete is now always a soft delete, so content from a prior test survives

View File

@ -133,6 +133,66 @@ class TestListReferencesPage:
assert total == 1
assert refs[0].name == "tagged"
def test_include_tags_filter_ands_persisted_model_tags(self, session: Session):
asset = _make_asset(session, "hash-model-tags")
checkpoint = _make_reference(session, asset, name="checkpoint")
lora = _make_reference(session, asset, name="lora")
input_ref = _make_reference(session, asset, name="input")
ensure_tags_exist(
session,
["models", "model_type:checkpoints", "model_type:loras", "unit-tests"],
)
add_tags_to_reference(
session,
reference_id=checkpoint.id,
tags=["models", "model_type:checkpoints", "unit-tests"],
origin="automatic",
)
add_tags_to_reference(
session,
reference_id=lora.id,
tags=["models", "model_type:loras", "unit-tests"],
origin="automatic",
)
add_tags_to_reference(
session,
reference_id=input_ref.id,
tags=["unit-tests"],
)
session.commit()
refs, _, total = list_references_page(
session,
include_tags=["models", "model_type:checkpoints", "unit-tests"],
)
assert total == 1
assert refs[0].id == checkpoint.id
def test_include_tags_filter_preserves_model_type_case(self, session: Session):
asset = _make_asset(session, "hash-model-case")
ref = _make_reference(session, asset, name="llm")
ensure_tags_exist(session, ["models", "model_type:LLM"])
add_tags_to_reference(
session,
reference_id=ref.id,
tags=["models", "model_type:LLM"],
origin="automatic",
)
session.commit()
refs, _, total = list_references_page(
session, include_tags=["models", "model_type:LLM"]
)
refs_lower, _, total_lower = list_references_page(
session, include_tags=["models", "model_type:llm"]
)
assert total == 1
assert refs[0].id == ref.id
assert total_lower == 0
assert refs_lower == []
def test_exclude_tags_filter(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, name="keep")

View File

@ -58,7 +58,7 @@ class TestEnsureTagsExist:
session.commit()
tags = session.query(Tag).all()
assert {t.name for t in tags} == {"alpha", "beta"}
assert {t.name for t in tags} == {"ALPHA", "Beta", "alpha"}
def test_empty_list_is_noop(self, session: Session):
ensure_tags_exist(session, [])
@ -258,6 +258,16 @@ class TestListTagsWithUsage:
tag_names = {name for name, _ in rows}
assert tag_names == {"alpha", "alphabet"}
def test_prefix_filter_is_case_sensitive(self, session: Session):
ensure_tags_exist(session, ["model_type:LLM", "model_type:llm"])
session.commit()
rows, total = list_tags_with_usage(session, prefix="model_type:L")
tag_names = {name for name, _ in rows}
assert tag_names == {"model_type:LLM"}
assert total == 1
def test_order_by_name(self, session: Session):
ensure_tags_exist(session, ["zebra", "alpha", "middle"])
session.commit()

View File

@ -0,0 +1,84 @@
"""Tests for how _build_asset_response derives the response `loader_path`.
Guards the persist-and-read contract: the response reads the stored
`loader_path` directly, and only recomputes when the column is NULL (rows
written before the column existed).
"""
from datetime import datetime
from pathlib import Path
from unittest.mock import patch
from app.assets.api.routes import _build_asset_response
from app.assets.services.schemas import AssetDetailResult, ReferenceData
_TS = datetime(2024, 1, 1, 0, 0, 0)
def _make_result(
*, file_path: str | None, loader_path: str | None
) -> AssetDetailResult:
ref = ReferenceData(
id="ref-1",
name="model.safetensors",
file_path=file_path,
loader_path=loader_path,
user_metadata=None,
preview_id=None,
created_at=_TS,
updated_at=_TS,
last_access_time=_TS,
)
return AssetDetailResult(ref=ref, asset=None, tags=[])
def test_uses_persisted_loader_path_without_recomputing():
"""A stored loader_path is returned verbatim, not re-derived from file_path.
The sentinel value could never be produced by compute_loader_path for this
file_path, so seeing it in the response proves the stored column is read.
"""
result = _make_result(
file_path="/unmatched/root/model.safetensors",
loader_path="SENTINEL/stored.safetensors",
)
resp = _build_asset_response(result)
assert resp.loader_path == "SENTINEL/stored.safetensors"
def test_falls_back_to_compute_when_stored_loader_path_is_null(tmp_path: Path):
"""A NULL column (pre-migration row) is backfilled at read time."""
models = tmp_path / "models"
ckpt = models / "checkpoints"
ckpt.mkdir(parents=True)
f = ckpt / "bar.safetensors"
f.touch()
with patch("app.assets.services.path_utils.folder_paths") as mock_fp, patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("checkpoints", [str(ckpt)])],
):
mock_fp.get_input_directory.return_value = str(tmp_path / "in")
mock_fp.get_output_directory.return_value = str(tmp_path / "out")
mock_fp.get_temp_directory.return_value = str(tmp_path / "tmp")
mock_fp.models_dir = str(models)
result = _make_result(file_path=str(f), loader_path=None)
resp = _build_asset_response(result)
assert resp.loader_path == "bar.safetensors"
assert resp.logical_path == "models/checkpoints/bar.safetensors"
assert resp.display_name == "checkpoints/bar.safetensors"
def test_all_path_fields_null_without_file_path():
"""API-created / hash-only references (no file_path) expose no paths."""
result = _make_result(file_path=None, loader_path=None)
resp = _build_asset_response(result)
assert resp.loader_path is None
assert resp.logical_path is None
assert resp.display_name is None

View File

@ -1,10 +1,14 @@
"""Tests for bulk ingest services."""
import os
from pathlib import Path
from unittest.mock import patch
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference
from app.assets.database.queries import get_reference_tags
from app.assets.scanner import build_asset_specs
from app.assets.services.bulk_ingest import SeedAssetSpec, batch_insert_seed_assets
@ -101,6 +105,184 @@ class TestBatchInsertSeedAssets:
asset = session.query(Asset).filter_by(id=ref.asset_id).first()
assert asset.mime_type == expected_mime, f"Expected {expected_mime} for {filename}, got {asset.mime_type}"
def test_duplicate_paths_merge_tags_before_insert(
self, session: Session, temp_dir: Path
):
"""Overlapping model-folder registrations can emit the same path twice."""
file_path = temp_dir / "shared.safetensors"
file_path.write_bytes(b"shared model")
specs: list[SeedAssetSpec] = [
{
"abs_path": str(file_path),
"size_bytes": 12,
"mtime_ns": 1234567890000000000,
"info_name": "Shared Model",
"tags": ["models", "model_type:checkpoints"],
"fname": "shared.safetensors",
"metadata": None,
"hash": None,
"mime_type": "application/safetensors",
},
{
"abs_path": str(file_path),
"size_bytes": 12,
"mtime_ns": 1234567890000000000,
"info_name": "Shared Model",
"tags": ["models", "model_type:diffusion_models"],
"fname": "shared.safetensors",
"metadata": None,
"hash": None,
"mime_type": "application/safetensors",
},
]
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
assert result.inserted_refs == 1
assert result.won_paths == 1
refs = session.query(AssetReference).all()
assert len(refs) == 1
assert set(get_reference_tags(session, reference_id=refs[0].id)) == {
"models",
"model_type:checkpoints",
"model_type:diffusion_models",
}
def test_duplicate_paths_are_merged_after_abspath_normalization(
self, session: Session, temp_dir: Path, monkeypatch
):
"""The scanner may emit equivalent paths with different spelling."""
file_path = temp_dir / "same-file.safetensors"
file_path.write_bytes(b"shared model")
monkeypatch.chdir(temp_dir)
relative_path = file_path.name
absolute_path = os.path.abspath(relative_path)
specs: list[SeedAssetSpec] = [
{
"abs_path": relative_path,
"size_bytes": 12,
"mtime_ns": 1234567890000000000,
"info_name": "Shared Model",
"tags": ["models", "model_type:checkpoints"],
"fname": "same-file.safetensors",
"metadata": None,
"hash": None,
"mime_type": "application/safetensors",
},
{
"abs_path": absolute_path,
"size_bytes": 12,
"mtime_ns": 1234567890000000000,
"info_name": "Shared Model",
"tags": ["models", "model_type:diffusion_models"],
"fname": "same-file.safetensors",
"metadata": None,
"hash": None,
"mime_type": "application/safetensors",
},
]
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
assert result.inserted_refs == 1
assert result.won_paths == 1
refs = session.query(AssetReference).all()
assert len(refs) == 1
assert refs[0].file_path == absolute_path
# loader_path is persisted from the spec's fname (compute_loader_path).
assert refs[0].loader_path == "same-file.safetensors"
assert set(get_reference_tags(session, reference_id=refs[0].id)) == {
"models",
"model_type:checkpoints",
"model_type:diffusion_models",
}
def test_scanner_duplicate_shared_model_paths_keep_all_model_type_tags(
self, session: Session, temp_dir: Path
):
"""Shared extra model roots make scanner collection emit duplicate paths."""
shared_root = temp_dir / "shared"
input_dir = temp_dir / "input"
output_dir = temp_dir / "output"
temp_root = temp_dir / "temp"
for directory in (shared_root, input_dir, output_dir, temp_root):
directory.mkdir()
file_path = shared_root / "dual_use_model.safetensors"
file_path.write_bytes(b"shared model")
with (
patch("app.assets.services.path_utils.folder_paths") as mock_fp,
patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[
("checkpoints", [str(shared_root)]),
("diffusion_models", [str(shared_root)]),
],
),
):
mock_fp.get_input_directory.return_value = str(input_dir)
mock_fp.get_output_directory.return_value = str(output_dir)
mock_fp.get_temp_directory.return_value = str(temp_root)
specs, tag_pool, skipped = build_asset_specs(
paths=[str(file_path), str(file_path)],
existing_paths=set(),
enable_metadata_extraction=False,
compute_hashes=False,
)
assert skipped == 0
assert len(specs) == 2
assert tag_pool == {
"models",
"model_type:checkpoints",
"model_type:diffusion_models",
}
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
assert result.inserted_refs == 1
assert result.won_paths == 1
refs = session.query(AssetReference).all()
assert len(refs) == 1
assert set(get_reference_tags(session, reference_id=refs[0].id)) == {
"models",
"model_type:checkpoints",
"model_type:diffusion_models",
}
def test_loader_path_persisted_as_null_when_fname_is_none(
self, session: Session, temp_dir: Path
):
"""A file with no in-root loader path (fname=None, e.g. an orphan under
models_root) persists loader_path as NULL rather than a synthesized value."""
file_path = temp_dir / "orphan.bin"
file_path.write_bytes(b"x")
specs: list[SeedAssetSpec] = [
{
"abs_path": str(file_path),
"size_bytes": 1,
"mtime_ns": 1234567890000000000,
"info_name": "orphan.bin",
"tags": [],
"fname": None,
"metadata": None,
"hash": None,
"mime_type": None,
}
]
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
assert result.inserted_refs == 1
refs = session.query(AssetReference).all()
assert len(refs) == 1
assert refs[0].file_path == str(file_path)
assert refs[0].loader_path is None
class TestMetadataExtraction:
def test_extracts_mime_type_for_model_files(self, temp_dir: Path):

View File

@ -94,6 +94,47 @@ class TestIngestFileFromPath:
ref_tags = get_reference_tags(session, reference_id=result.reference_id)
assert set(ref_tags) == {"models", "checkpoints"}
def test_path_derived_tags_use_automatic_origin(
self, mock_create_session, temp_dir: Path, session: Session
):
input_dir = temp_dir / "input"
output_dir = temp_dir / "output"
temp_root = temp_dir / "temp"
for directory in (input_dir, output_dir, temp_root):
directory.mkdir()
file_path = input_dir / "pasted" / "tagged.png"
file_path.parent.mkdir()
file_path.write_bytes(b"data")
with (
patch("app.assets.services.path_utils.folder_paths") as mock_fp,
patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[],
),
):
mock_fp.get_input_directory.return_value = str(input_dir)
mock_fp.get_output_directory.return_value = str(output_dir)
mock_fp.get_temp_directory.return_value = str(temp_root)
result = _ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:pathorigin",
size_bytes=4,
mtime_ns=1234567890000000000,
info_name="Tagged Asset",
tags=["input", "manual-label"],
)
assert result.reference_id is not None
links = session.query(AssetReferenceTag).filter_by(
asset_reference_id=result.reference_id
)
origin_by_tag = {link.tag_name: link.origin for link in links}
assert origin_by_tag["input"] == "automatic"
assert origin_by_tag["pasted"] == "automatic"
assert origin_by_tag["manual-label"] == "manual"
def test_idempotent_upsert(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "dup.bin"
file_path.write_bytes(b"content")

View File

@ -6,7 +6,16 @@ from unittest.mock import patch
import pytest
from app.assets.services.path_utils import get_asset_category_and_relative_path
from app.assets.services.path_utils import (
compute_display_name,
compute_loader_path,
compute_logical_path,
get_asset_category_and_relative_path,
get_known_input_subfolder_tags_from_path,
get_known_subfolder_tags,
get_name_and_tags_from_asset_path,
resolve_destination_from_tags,
)
@pytest.fixture
@ -17,7 +26,8 @@ def fake_dirs():
input_dir = root_path / "input"
output_dir = root_path / "output"
temp_dir = root_path / "temp"
models_dir = root_path / "models" / "checkpoints"
models_root = root_path / "models"
models_dir = models_root / "checkpoints"
for d in (input_dir, output_dir, temp_dir, models_dir):
d.mkdir(parents=True)
@ -25,6 +35,7 @@ def fake_dirs():
mock_fp.get_input_directory.return_value = str(input_dir)
mock_fp.get_output_directory.return_value = str(output_dir)
mock_fp.get_temp_directory.return_value = str(temp_dir)
mock_fp.models_dir = str(models_root)
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
@ -34,6 +45,7 @@ def fake_dirs():
"input": input_dir,
"output": output_dir,
"temp": temp_dir,
"models_root": models_root,
"models": models_dir,
}
@ -76,6 +88,449 @@ class TestGetAssetCategoryAndRelativePath:
cat, rel = get_asset_category_and_relative_path(str(f))
assert cat == "models"
def test_model_path_tags_include_registered_model_type_only(self, fake_dirs):
f = fake_dirs["models"] / "subdir" / "model.safetensors"
f.parent.mkdir()
f.touch()
_name, tags = get_name_and_tags_from_asset_path(str(f))
assert "models" in tags
assert "model_type:checkpoints" in tags
assert "checkpoints" not in tags
assert "subdir" not in tags
def test_model_type_preserves_registered_folder_case(self, fake_dirs):
llm_dir = fake_dirs["models"].parent / "LLM"
llm_dir.mkdir()
f = llm_dir / "model.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("LLM", [str(llm_dir)])],
):
_name, tags = get_name_and_tags_from_asset_path(str(f))
assert "models" in tags
assert "model_type:LLM" in tags
assert "model_type:llm" not in tags
def test_path_components_do_not_create_model_type_tags(self, fake_dirs):
f = fake_dirs["models"] / "loras" / "model.safetensors"
f.parent.mkdir()
f.touch()
_name, tags = get_name_and_tags_from_asset_path(str(f))
assert "models" in tags
assert "model_type:checkpoints" in tags
assert "loras" not in tags
assert "model_type:loras" not in tags
def test_shared_root_returns_all_matching_model_type_tags(self, fake_dirs):
shared_root = fake_dirs["models"].parent / "shared"
shared_root.mkdir()
f = shared_root / "foo.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[
("checkpoints", [str(shared_root)]),
("loras", [str(shared_root)]),
],
):
_name, tags = get_name_and_tags_from_asset_path(str(f))
assert "models" in tags
assert "model_type:checkpoints" in tags
assert "model_type:loras" in tags
def test_output_backed_registered_folder_gets_model_and_output_tags(self, fake_dirs):
output_checkpoints_dir = fake_dirs["output"] / "checkpoints"
output_checkpoints_dir.mkdir()
f = output_checkpoints_dir / "saved.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("checkpoints", [str(output_checkpoints_dir)])],
):
_name, tags = get_name_and_tags_from_asset_path(str(f))
assert "models" in tags
assert "model_type:checkpoints" in tags
assert "output" in tags
def test_temp_path_tags_include_temp_not_output_or_preview(self, fake_dirs):
f = fake_dirs["temp"] / "preview.png"
f.touch()
_name, tags = get_name_and_tags_from_asset_path(str(f))
assert "temp" in tags
assert "output" not in tags
assert "preview:true" not in tags
def test_known_subfolder_tags_are_centralized(self):
assert get_known_subfolder_tags("pasted") == ["pasted"]
assert get_known_subfolder_tags("arbitrary") == []
def test_known_input_subfolder_tags_are_path_derived_for_direct_children(self, fake_dirs):
f = fake_dirs["input"] / "pasted" / "image.png"
f.parent.mkdir()
f.touch()
assert get_known_input_subfolder_tags_from_path(str(f)) == ["pasted"]
_name, tags = get_name_and_tags_from_asset_path(str(f))
assert "input" in tags
assert "pasted" in tags
def test_known_input_subfolder_tags_do_not_apply_to_nested_or_other_roots(self, fake_dirs):
nested = fake_dirs["input"] / "pasted" / "session" / "image.png"
output = fake_dirs["output"] / "pasted" / "image.png"
for path in (nested, output):
path.parent.mkdir(parents=True)
path.touch()
assert get_known_input_subfolder_tags_from_path(str(nested)) == []
assert get_known_input_subfolder_tags_from_path(str(output)) == []
def test_unknown_path_raises(self, fake_dirs):
with pytest.raises(ValueError, match="not within"):
get_asset_category_and_relative_path("/some/random/path.png")
class TestResponseStoragePaths:
def test_input_file_path_and_display_name_include_subfolder(self, fake_dirs):
sub = fake_dirs["input"] / "some" / "folder"
sub.mkdir(parents=True)
f = sub / "image.png"
f.touch()
assert compute_logical_path(str(f)) == "input/some/folder/image.png"
assert compute_display_name(str(f)) == "some/folder/image.png"
def test_output_file_path_and_display_name_include_subfolder(self, fake_dirs):
sub = fake_dirs["output"] / "renders"
sub.mkdir()
f = sub / "ComfyUI_00001_.png"
f.touch()
assert compute_logical_path(str(f)) == "output/renders/ComfyUI_00001_.png"
assert compute_display_name(str(f)) == "renders/ComfyUI_00001_.png"
def test_temp_file_path_and_display_name(self, fake_dirs):
f = fake_dirs["temp"] / "preview.png"
f.touch()
assert compute_logical_path(str(f)) == "temp/preview.png"
assert compute_display_name(str(f)) == "preview.png"
def test_exact_storage_root_has_no_display_name(self, fake_dirs):
assert compute_logical_path(str(fake_dirs["input"])) == "input"
assert compute_display_name(str(fake_dirs["input"])) is None
def test_longest_matching_builtin_root_wins(self, fake_dirs, tmp_path: Path):
nested_output = fake_dirs["input"] / "nested-output"
nested_output.mkdir()
f = nested_output / "image.png"
f.touch()
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
mock_fp.get_input_directory.return_value = str(fake_dirs["input"])
mock_fp.get_output_directory.return_value = str(nested_output)
mock_fp.get_temp_directory.return_value = str(tmp_path / "temp")
mock_fp.models_dir = str(fake_dirs["models_root"])
assert compute_logical_path(str(f)) == "output/image.png"
assert compute_display_name(str(f)) == "image.png"
def test_model_file_path_is_relative_to_physical_models_root(self, fake_dirs):
sub = fake_dirs["models"] / "flux"
sub.mkdir()
f = sub / "model.safetensors"
f.touch()
assert compute_logical_path(str(f)) == "models/checkpoints/flux/model.safetensors"
assert compute_display_name(str(f)) == "checkpoints/flux/model.safetensors"
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "model.safetensors"
assert "models" in tags
assert "model_type:checkpoints" in tags
assert "checkpoints" not in tags
assert "flux" not in tags
@pytest.mark.parametrize(
"folder_name",
["checkpoints", "clip", "vae", "diffusion_models", "loras"],
)
def test_output_model_folder_uses_output_storage_file_path(self, fake_dirs, folder_name):
output_model_dir = fake_dirs["output"] / folder_name
output_model_dir.mkdir(exist_ok=True)
default_model_dir = fake_dirs["models_root"] / folder_name
default_model_dir.mkdir(exist_ok=True)
f = output_model_dir / "saved.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[(folder_name, [str(default_model_dir), str(output_model_dir)])],
):
assert compute_logical_path(str(f)) == f"output/{folder_name}/saved.safetensors"
assert compute_display_name(str(f)) == f"{folder_name}/saved.safetensors"
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "saved.safetensors"
assert "output" in tags
assert "models" in tags
assert f"model_type:{folder_name}" in tags
assert folder_name not in tags
def test_output_model_subfolder_uses_output_storage_file_path(self, fake_dirs):
folder_name = "loras"
output_model_dir = fake_dirs["output"] / folder_name
subdir = output_model_dir / "experiments"
subdir.mkdir(parents=True)
f = subdir / "my_lora.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[(folder_name, [str(output_model_dir)])],
):
assert (
compute_logical_path(str(f))
== "output/loras/experiments/my_lora.safetensors"
)
assert compute_display_name(str(f)) == "loras/experiments/my_lora.safetensors"
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "my_lora.safetensors"
assert "output" in tags
assert "models" in tags
assert "model_type:loras" in tags
assert "loras" not in tags
assert "experiments" not in tags
def test_external_model_folder_without_provenance_has_no_file_path(self, tmp_path: Path):
external_checkpoints_dir = tmp_path / "external" / "not_named_like_category"
external_checkpoints_dir.mkdir(parents=True)
f = external_checkpoints_dir / "external.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("checkpoints", [str(external_checkpoints_dir)])],
):
assert compute_logical_path(str(f)) is None
assert compute_display_name(str(f)) is None
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "external.safetensors"
assert "models" in tags
assert "model_type:checkpoints" in tags
def test_same_relative_model_file_under_multiple_external_roots_has_no_storage_file_path(
self, tmp_path: Path
):
foo_dir = tmp_path / "foo"
bar_dir = tmp_path / "bar"
foo_dir.mkdir()
bar_dir.mkdir()
foo_file = foo_dir / "baz.safetensors"
bar_file = bar_dir / "baz.safetensors"
foo_file.touch()
bar_file.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("checkpoints", [str(foo_dir), str(bar_dir)])],
):
assert compute_logical_path(str(foo_file)) is None
assert compute_logical_path(str(bar_file)) is None
assert compute_display_name(str(foo_file)) is None
assert compute_display_name(str(bar_file)) is None
def test_output_clip_folder_uses_output_storage_and_text_encoder_tag(self, fake_dirs):
output_clip_dir = fake_dirs["output"] / "clip"
output_clip_dir.mkdir()
f = output_clip_dir / "clip_l.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("text_encoders", [str(output_clip_dir)])],
):
assert compute_logical_path(str(f)) == "output/clip/clip_l.safetensors"
assert compute_display_name(str(f)) == "clip/clip_l.safetensors"
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "clip_l.safetensors"
assert "output" in tags
assert "models" in tags
assert "model_type:text_encoders" in tags
assert "clip" not in tags
def test_physical_unet_folder_uses_storage_path_and_diffusion_models_tag(self, fake_dirs):
unet_dir = fake_dirs["models_root"] / "unet"
diffusion_models_dir = fake_dirs["models_root"] / "diffusion_models"
unet_dir.mkdir()
diffusion_models_dir.mkdir()
f = unet_dir / "wan.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("diffusion_models", [str(unet_dir), str(diffusion_models_dir)])],
):
assert compute_logical_path(str(f)) == "models/unet/wan.safetensors"
assert compute_display_name(str(f)) == "unet/wan.safetensors"
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "wan.safetensors"
assert "models" in tags
assert "model_type:diffusion_models" in tags
assert "unet" not in tags
def test_unregistered_file_under_physical_models_root_still_has_storage_file_path(self, fake_dirs):
f = fake_dirs["models_root"] / "not_registered" / "orphan.bin"
f.parent.mkdir()
f.touch()
assert compute_logical_path(str(f)) == "models/not_registered/orphan.bin"
assert compute_display_name(str(f)) == "not_registered/orphan.bin"
def test_output_checkpoint_folder_without_registration_has_only_output_tag(self, fake_dirs):
f = fake_dirs["output"] / "checkpoints" / "saved.safetensors"
f.parent.mkdir(exist_ok=True)
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[],
):
assert compute_logical_path(str(f)) == "output/checkpoints/saved.safetensors"
assert compute_display_name(str(f)) == "checkpoints/saved.safetensors"
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "saved.safetensors"
assert "output" in tags
assert "models" not in tags
assert not any(tag.startswith("model_type:") for tag in tags)
def test_unknown_path_returns_none(self):
assert compute_logical_path("/some/random/path.png") is None
assert compute_display_name("/some/random/path.png") is None
class TestLoaderPath:
"""In-root loader path: relative to the storage root, model category dropped."""
def test_model_loader_path_drops_category(self, fake_dirs):
sub = fake_dirs["models"] / "flux"
sub.mkdir()
f = sub / "model.safetensors"
f.touch()
# logical_path keeps the category, file_path (loader) drops it
assert compute_logical_path(str(f)) == "models/checkpoints/flux/model.safetensors"
assert compute_loader_path(str(f)) == "flux/model.safetensors"
def test_model_loader_path_flat_file(self, fake_dirs):
f = fake_dirs["models"] / "model.safetensors"
f.touch()
assert compute_loader_path(str(f)) == "model.safetensors"
def test_input_loader_path_keeps_subfolders(self, fake_dirs):
sub = fake_dirs["input"] / "some" / "folder"
sub.mkdir(parents=True)
f = sub / "image.png"
f.touch()
assert compute_loader_path(str(f)) == "some/folder/image.png"
def test_temp_loader_path(self, fake_dirs):
f = fake_dirs["temp"] / "preview.png"
f.touch()
assert compute_loader_path(str(f)) == "preview.png"
def test_unregistered_file_under_models_root_has_no_loader_path(self, fake_dirs):
# Under models_root but not within any registered category base.
f = fake_dirs["models_root"] / "not_registered" / "orphan.bin"
f.parent.mkdir()
f.touch()
# It still has a namespaced logical_path, but no loader path.
assert compute_logical_path(str(f)) == "models/not_registered/orphan.bin"
assert compute_loader_path(str(f)) is None
def test_extra_path_model_has_loader_path_but_no_logical_path(self, tmp_path: Path):
"""Registered category base outside models_dir (extra_model_paths style).
Loadable, so loader_path resolves; but it is not under any canonical
storage root, so logical_path/display_name are None. This asymmetry is
intentional: loader_path resolves every registered model-folder base,
logical_path only resolves the canonical storage roots.
"""
extra = tmp_path / "extra_ckpts"
extra.mkdir()
f = extra / "foo.safetensors"
f.touch()
with patch("app.assets.services.path_utils.folder_paths") as mock_fp, patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("checkpoints", [str(extra)])],
):
mock_fp.get_input_directory.return_value = str(tmp_path / "in")
mock_fp.get_output_directory.return_value = str(tmp_path / "out")
mock_fp.get_temp_directory.return_value = str(tmp_path / "tmp")
mock_fp.models_dir = str(tmp_path / "models") # extra is NOT under this
assert compute_loader_path(str(f)) == "foo.safetensors"
assert compute_logical_path(str(f)) is None
assert compute_display_name(str(f)) is None
def test_unknown_path_returns_none(self):
assert compute_loader_path("/some/random/path.png") is None
class TestResolveDestinationFromTags:
def test_extra_tags_are_not_path_components(self, fake_dirs):
base_dir, subdirs = resolve_destination_from_tags(["input", "unit-tests", "foo"])
assert base_dir == os.path.abspath(fake_dirs["input"])
assert subdirs == []
def test_model_upload_rejects_non_writable_registered_folders(self):
with tempfile.TemporaryDirectory() as root:
root_path = Path(root)
checkpoints_dir = root_path / "models" / "checkpoints"
configs_dir = root_path / "models" / "configs"
custom_nodes_dir = root_path / "custom_nodes"
for path in (checkpoints_dir, configs_dir, custom_nodes_dir):
path.mkdir(parents=True)
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
mock_fp.folder_names_and_paths = {
"checkpoints": ([str(checkpoints_dir)], set()),
"configs": ([str(configs_dir)], set()),
"custom_nodes": ([str(custom_nodes_dir)], set()),
}
base_dir, subdirs = resolve_destination_from_tags(
["models", "model_type:checkpoints"]
)
assert base_dir == os.path.abspath(checkpoints_dir)
assert subdirs == []
for folder_name in ("configs", "custom_nodes"):
with pytest.raises(ValueError, match="unknown model category"):
resolve_destination_from_tags(
["models", f"model_type:{folder_name}"]
)

View File

@ -19,7 +19,8 @@ def test_seed_asset_removed_when_file_is_deleted(
"""Asset without hash (seed) whose file disappears:
after triggering sync_seed_assets, Asset + AssetInfo disappear.
"""
# Create a file directly under input/unit-tests/<case> so tags include "unit-tests"
# Create a file directly under input/unit-tests/<case>. Backend tags only
# classify the root; nested path components are not exposed as tags.
case_dir = comfy_tmp_base_dir / root / "unit-tests" / "syncseed"
case_dir.mkdir(parents=True, exist_ok=True)
name = f"seed_{uuid.uuid4().hex[:8]}.bin"
@ -32,7 +33,7 @@ def test_seed_asset_removed_when_file_is_deleted(
# Verify it is visible via API and carries no hash (seed)
r1 = http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
params={"include_tags": root, "name_contains": name},
timeout=120,
)
body1 = r1.json()
@ -54,7 +55,7 @@ def test_seed_asset_removed_when_file_is_deleted(
# It should disappear (AssetInfo and seed Asset gone)
r2 = http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
params={"include_tags": root, "name_contains": name},
timeout=120,
)
body2 = r2.json()
@ -132,7 +133,7 @@ def test_hashed_asset_two_asset_infos_both_get_missing(
second_id = b2["id"]
# Remove the single underlying file
p = comfy_tmp_base_dir / "input" / "unit-tests" / "multiinfo" / get_asset_filename(b2["asset_hash"], ".png")
p = comfy_tmp_base_dir / "input" / get_asset_filename(created["asset_hash"], ".png")
assert p.exists()
p.unlink()
@ -250,8 +251,7 @@ def test_missing_tag_clears_on_fastpass_when_mtime_and_size_match(
a = asset_factory(name, [root, "unit-tests", scope], {}, data)
aid = a["id"]
base = comfy_tmp_base_dir / root / "unit-tests" / scope
p = base / get_asset_filename(a["asset_hash"], ".bin")
p = comfy_tmp_base_dir / root / get_asset_filename(a["asset_hash"], ".bin")
st0 = p.stat()
orig_mtime_ns = getattr(st0, "st_mtime_ns", int(st0.st_mtime * 1_000_000_000))

View File

@ -290,7 +290,7 @@ def test_metadata_filename_is_set_for_seed_asset_without_hash(
r1 = http.get(
api_base + "/api/assets",
params={"include_tags": f"unit-tests,{scope}", "name_contains": name},
params={"include_tags": root, "name_contains": name},
timeout=120,
)
body = r1.json()

View File

@ -95,7 +95,7 @@ def test_download_chooses_existing_state_and_updates_access_time(
assert t1 > t0
@pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "checkpoints"]}], indirect=True)
@pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "model_type:checkpoints"]}], indirect=True)
def test_download_missing_file_returns_404(
http: requests.Session, api_base: str, comfy_tmp_base_dir: Path, seeded_asset: dict
):

View File

@ -13,7 +13,7 @@ def _seed(asset_factory, make_asset_bytes, count: int, tag: str) -> list[str]:
for n in names:
asset_factory(
n,
["models", "checkpoints", "unit-tests", tag],
["models", "model_type:checkpoints", "unit-tests", tag],
{},
make_asset_bytes(n, size=2048),
)
@ -208,7 +208,7 @@ def test_cursor_walks_for_non_name_sorts(sort_field, http: requests.Session, api
names = []
for i in range(4):
n = f"cursor_{sort_field}_{i:02d}.safetensors"
asset_factory(n, ["models", "checkpoints", "unit-tests", f"cursor-{sort_field}"], {}, make_asset_bytes(n, size=2048 + i))
asset_factory(n, ["models", "model_type:checkpoints", "unit-tests", f"cursor-{sort_field}"], {}, make_asset_bytes(n, size=2048 + i))
names.append(n)
params = {

View File

@ -11,7 +11,7 @@ def test_list_assets_paging_and_sort(http: requests.Session, api_base: str, asse
for n in names:
asset_factory(
n,
["models", "checkpoints", "unit-tests", "paging"],
["models", "model_type:checkpoints", "unit-tests", "paging"],
{"epoch": 1},
make_asset_bytes(n, size=2048),
)
@ -45,8 +45,8 @@ def test_list_assets_paging_and_sort(http: requests.Session, api_base: str, asse
def test_list_assets_include_exclude_and_name_contains(http: requests.Session, api_base: str, asset_factory):
a = asset_factory("inc_a.safetensors", ["models", "checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024)
b = asset_factory("inc_b.safetensors", ["models", "checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024)
a = asset_factory("inc_a.safetensors", ["models", "model_type:checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024)
b = asset_factory("inc_b.safetensors", ["models", "model_type:checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024)
r = http.get(
api_base + "/api/assets",
@ -81,7 +81,7 @@ def test_list_assets_include_exclude_and_name_contains(http: requests.Session, a
def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-size"]
t = ["models", "model_type:checkpoints", "unit-tests", "lf-size"]
n1, n2, n3 = "sz1.safetensors", "sz2.safetensors", "sz3.safetensors"
asset_factory(n1, t, {}, make_asset_bytes(n1, 1024))
asset_factory(n2, t, {}, make_asset_bytes(n2, 2048))
@ -108,7 +108,7 @@ def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, mak
def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-upd"]
t = ["models", "model_type:checkpoints", "unit-tests", "lf-upd"]
a1 = asset_factory("upd_a.safetensors", t, {}, make_asset_bytes("upd_a", 1200))
a2 = asset_factory("upd_b.safetensors", t, {}, make_asset_bytes("upd_b", 1200))
@ -131,7 +131,7 @@ def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make
def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-access"]
t = ["models", "model_type:checkpoints", "unit-tests", "lf-access"]
asset_factory("acc_a.safetensors", t, {}, make_asset_bytes("acc_a", 1100))
time.sleep(0.02)
a2 = asset_factory("acc_b.safetensors", t, {}, make_asset_bytes("acc_b", 1100))
@ -154,14 +154,14 @@ def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory
def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-include"]
t = ["models", "model_type:checkpoints", "unit-tests", "lf-include"]
a = asset_factory("incvar_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("iva"))
asset_factory("incvar_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("ivb"))
# CSV + case-insensitive
# CSV tag filters are whitespace-trimmed and case-sensitive.
r1 = http.get(
api_base + "/api/assets",
params={"include_tags": "UNIT-TESTS,LF-INCLUDE,alpha"},
params={"include_tags": "unit-tests,lf-include,alpha"},
timeout=120,
)
b1 = r1.json()
@ -196,14 +196,14 @@ def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factor
def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-exclude"]
t = ["models", "model_type:checkpoints", "unit-tests", "lf-exclude"]
a = asset_factory("ex_a_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("exa", 900))
asset_factory("ex_b_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("exb", 900))
# Exclude uppercase should work
# Exclude filters are case-sensitive.
r1 = http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "BETA"},
params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "beta"},
timeout=120,
)
b1 = r1.json()
@ -225,7 +225,7 @@ def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory,
def test_list_assets_name_contains_case_and_specials(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-name"]
t = ["models", "model_type:checkpoints", "unit-tests", "lf-name"]
a1 = asset_factory("CaseMix.SAFE", t, {}, make_asset_bytes("cm", 800))
a2 = asset_factory("case-other.safetensors", t, {}, make_asset_bytes("co", 800))
@ -261,7 +261,7 @@ def test_list_assets_name_contains_case_and_specials(http, api_base, asset_facto
def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-pagelimits"]
t = ["models", "model_type:checkpoints", "unit-tests", "lf-pagelimits"]
asset_factory("pl1.safetensors", t, {}, make_asset_bytes("pl1", 600))
asset_factory("pl2.safetensors", t, {}, make_asset_bytes("pl2", 600))
asset_factory("pl3.safetensors", t, {}, make_asset_bytes("pl3", 600))
@ -319,7 +319,7 @@ def test_list_assets_name_contains_literal_underscore(
- foobar.safetensors (must NOT match)
"""
scope = f"lf-underscore-{uuid.uuid4().hex[:6]}"
tags = ["models", "checkpoints", "unit-tests", scope]
tags = ["models", "model_type:checkpoints", "unit-tests", scope]
a = asset_factory("foo_bar.safetensors", tags, {}, make_asset_bytes("a", 700))
b = asset_factory("fooxbar.safetensors", tags, {}, make_asset_bytes("b", 700))

View File

@ -5,7 +5,7 @@ def test_meta_and_across_keys_and_types(
http, api_base: str, asset_factory, make_asset_bytes
):
name = "mf_and_mix.safetensors"
tags = ["models", "checkpoints", "unit-tests", "mf-and"]
tags = ["models", "model_type:checkpoints", "unit-tests", "mf-and"]
meta = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23}
asset_factory(name, tags, meta, make_asset_bytes(name, 4096))
@ -41,7 +41,7 @@ def test_meta_and_across_keys_and_types(
def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory, make_asset_bytes):
name = "mf_types.safetensors"
tags = ["models", "checkpoints", "unit-tests", "mf-types"]
tags = ["models", "model_type:checkpoints", "unit-tests", "mf-types"]
meta = {"epoch": 1, "active": True}
asset_factory(name, tags, meta, make_asset_bytes(name))
@ -95,7 +95,7 @@ def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory,
def test_meta_any_of_list_of_scalars(http, api_base, asset_factory, make_asset_bytes):
name = "mf_list_scalars.safetensors"
tags = ["models", "checkpoints", "unit-tests", "mf-list"]
tags = ["models", "model_type:checkpoints", "unit-tests", "mf-list"]
meta = {"flags": ["red", "green"]}
asset_factory(name, tags, meta, make_asset_bytes(name, 3000))
@ -134,7 +134,7 @@ def test_meta_none_semantics_missing_or_null_and_any_of_with_none(
http, api_base, asset_factory, make_asset_bytes
):
# a1: key missing; a2: explicit null; a3: concrete value
t = ["models", "checkpoints", "unit-tests", "mf-none"]
t = ["models", "model_type:checkpoints", "unit-tests", "mf-none"]
a1 = asset_factory("mf_none_missing.safetensors", t, {"x": 1}, make_asset_bytes("a1"))
a2 = asset_factory("mf_none_null.safetensors", t, {"maybe": None}, make_asset_bytes("a2"))
a3 = asset_factory("mf_none_value.safetensors", t, {"maybe": "x"}, make_asset_bytes("a3"))
@ -166,7 +166,7 @@ def test_meta_none_semantics_missing_or_null_and_any_of_with_none(
def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_asset_bytes):
name = "mf_nested_json.safetensors"
tags = ["models", "checkpoints", "unit-tests", "mf-nested"]
tags = ["models", "model_type:checkpoints", "unit-tests", "mf-nested"]
cfg = {"optimizer": "adam", "lr": 0.001, "schedule": {"type": "cosine", "warmup": 100}}
asset_factory(name, tags, {"config": cfg}, make_asset_bytes(name, 2200))
@ -197,7 +197,7 @@ def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_as
def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_bytes):
name = "mf_list_objects.safetensors"
tags = ["models", "checkpoints", "unit-tests", "mf-objlist"]
tags = ["models", "model_type:checkpoints", "unit-tests", "mf-objlist"]
transforms = [{"type": "crop", "size": 128}, {"type": "flip", "p": 0.5}]
asset_factory(name, tags, {"transforms": transforms}, make_asset_bytes(name, 2048))
@ -228,7 +228,7 @@ def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_b
def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_asset_bytes):
name = "mf_keys_unicode.safetensors"
tags = ["models", "checkpoints", "unit-tests", "mf-keys"]
tags = ["models", "model_type:checkpoints", "unit-tests", "mf-keys"]
meta = {
"weird.key": "v1",
"path/like": 7,
@ -259,7 +259,7 @@ def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_
def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "mf-zero-bool"]
t = ["models", "model_type:checkpoints", "unit-tests", "mf-zero-bool"]
a0 = asset_factory("mf_zero_count.safetensors", t, {"count": 0}, make_asset_bytes("z", 1025))
a1 = asset_factory("mf_bool_list.safetensors", t, {"choices": [True, False]}, make_asset_bytes("b", 1026))
@ -286,7 +286,7 @@ def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_as
def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, make_asset_bytes):
name = "mf_mixed_list.safetensors"
tags = ["models", "checkpoints", "unit-tests", "mf-mixed"]
tags = ["models", "model_type:checkpoints", "unit-tests", "mf-mixed"]
meta = {"mix": ["1", 1, True, None]}
asset_factory(name, tags, meta, make_asset_bytes(name, 1999))
@ -311,7 +311,7 @@ def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, mak
def test_meta_unknown_key_and_none_behavior_with_scope_tags(http, api_base, asset_factory, make_asset_bytes):
# Use a unique scope tag to avoid interference
t = ["models", "checkpoints", "unit-tests", "mf-unknown-scope"]
t = ["models", "model_type:checkpoints", "unit-tests", "mf-unknown-scope"]
x = asset_factory("mf_unknown_a.safetensors", t, {"k1": 1}, make_asset_bytes("ua"))
y = asset_factory("mf_unknown_b.safetensors", t, {"k2": 2}, make_asset_bytes("ub"))
@ -340,13 +340,13 @@ def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_
# alpha matches epoch=1; beta has epoch=2
a = asset_factory(
"mf_tag_alpha.safetensors",
["models", "checkpoints", "unit-tests", "mf-tag", "alpha"],
["models", "model_type:checkpoints", "unit-tests", "mf-tag", "alpha"],
{"epoch": 1},
make_asset_bytes("alpha"),
)
b = asset_factory(
"mf_tag_beta.safetensors",
["models", "checkpoints", "unit-tests", "mf-tag", "beta"],
["models", "model_type:checkpoints", "unit-tests", "mf-tag", "beta"],
{"epoch": 2},
make_asset_bytes("beta"),
)
@ -367,7 +367,7 @@ def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_
def test_meta_sort_and_paging_under_filter(http, api_base, asset_factory, make_asset_bytes):
# Three assets in same scope with different sizes and a common filter key
t = ["models", "checkpoints", "unit-tests", "mf-sort"]
t = ["models", "model_type:checkpoints", "unit-tests", "mf-sort"]
n1, n2, n3 = "mf_sort_1.safetensors", "mf_sort_2.safetensors", "mf_sort_3.safetensors"
asset_factory(n1, t, {"group": "g"}, make_asset_bytes(n1, 1024))
asset_factory(n2, t, {"group": "g"}, make_asset_bytes(n2, 2048))

View File

@ -29,7 +29,7 @@ def create_seed_file(comfy_tmp_base_dir: Path):
def find_asset(http: requests.Session, api_base: str):
"""Query API for assets matching scope and optional name."""
def _find(scope: str, name: str | None = None) -> list[dict]:
params = {"include_tags": f"unit-tests,{scope}"}
params = {"limit": "500"}
if name:
params["name_contains"] = name
r = http.get(f"{api_base}/api/assets", params=params, timeout=120)
@ -91,7 +91,7 @@ def test_hashed_asset_not_pruned_when_file_missing(
data = make_asset_bytes("test", 2048)
a = asset_factory("test.bin", ["input", "unit-tests", scope], {}, data)
path = comfy_tmp_base_dir / "input" / "unit-tests" / scope / get_asset_filename(a["asset_hash"], ".bin")
path = comfy_tmp_base_dir / "input" / get_asset_filename(a["asset_hash"], ".bin")
path.unlink()
trigger_sync_seed_assets(http, api_base)
@ -108,18 +108,20 @@ def test_prune_across_multiple_roots(
):
"""Prune correctly handles assets across input and output roots."""
scope = f"multi-{uuid.uuid4().hex[:6]}"
input_fp = create_seed_file("input", scope, "input.bin")
create_seed_file("output", scope, "output.bin")
input_name = f"{scope}-input.bin"
output_name = f"{scope}-output.bin"
input_fp = create_seed_file("input", scope, input_name)
create_seed_file("output", scope, output_name)
trigger_sync_seed_assets(http, api_base)
assert len(find_asset(scope)) == 2
assert find_asset(scope, input_name)
assert find_asset(scope, output_name)
input_fp.unlink()
trigger_sync_seed_assets(http, api_base)
remaining = find_asset(scope)
assert len(remaining) == 1
assert remaining[0]["name"] == "output.bin"
assert not find_asset(scope, input_name)
assert find_asset(scope, output_name)
@pytest.mark.parametrize("dirname", ["100%_done", "my_folder_name", "has spaces"])

View File

@ -10,9 +10,9 @@ def test_tags_present(http: requests.Session, api_base: str, seeded_asset: dict)
body1 = r1.json()
assert r1.status_code == 200
names = [t["name"] for t in body1["tags"]]
# A few system tags from migration should exist:
# A few selected contract tags should exist.
assert "models" in names
assert "checkpoints" in names
assert "model_type:checkpoints" in names
# Only used tags before we add anything new from this test cycle
r2 = http.get(api_base + "/api/tags", params={"include_zero": "false"}, timeout=120)
@ -21,7 +21,7 @@ def test_tags_present(http: requests.Session, api_base: str, seeded_asset: dict)
# We already seeded one asset via fixture, so used tags must be non-empty
used_names = [t["name"] for t in body2["tags"]]
assert "models" in used_names
assert "checkpoints" in used_names
assert "model_type:checkpoints" in used_names
# Prefix filter should refine the list
r3 = http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}, timeout=120)
@ -45,7 +45,7 @@ def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory,
body1 = r1.json()
assert r1.status_code == 200
names = [t["name"] for t in body1["tags"]]
assert "models" in names and "checkpoints" in names
assert "models" in names and "model_type:checkpoints" in names
# Create a short-lived asset under input with a unique custom tag
scope = f"tags-empty-usage-{uuid.uuid4().hex[:6]}"
@ -89,28 +89,28 @@ def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory,
def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset: dict):
aid = seeded_asset["id"]
# Add tags with duplicates and mixed case
payload_add = {"tags": ["NewTag", "unit-tests", "newtag", "BETA"]}
# Add tags with duplicates while preserving source case.
payload_add = {"tags": ["NewTag", "unit-tests", "NewTag", "BETA"]}
r1 = http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add, timeout=120)
b1 = r1.json()
assert r1.status_code == 200, b1
# normalized, deduplicated; 'unit-tests' was already present from the seed
assert set(b1["added"]) == {"newtag", "beta"}
# stripped, deduplicated; 'unit-tests' was already present from the seed
assert set(b1["added"]) == {"NewTag", "BETA"}
assert set(b1["already_present"]) == {"unit-tests"}
assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"]
assert "NewTag" in b1["total_tags"] and "BETA" in b1["total_tags"]
rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
g = rg.json()
assert rg.status_code == 200
tags_now = set(g["tags"])
assert {"newtag", "beta"}.issubset(tags_now)
assert {"NewTag", "BETA"}.issubset(tags_now)
# Remove a tag and a non-existent tag
payload_del = {"tags": ["newtag", "does-not-exist"]}
payload_del = {"tags": ["NewTag", "does-not-exist"]}
r2 = http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del, timeout=120)
b2 = r2.json()
assert r2.status_code == 200
assert set(b2["removed"]) == {"newtag"}
assert set(b2["removed"]) == {"NewTag"}
assert set(b2["not_present"]) == {"does-not-exist"}
# Verify remaining tags after deletion
@ -118,8 +118,44 @@ def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset
g2 = rg2.json()
assert rg2.status_code == 200
tags_later = set(g2["tags"])
assert "newtag" not in tags_later
assert "beta" in tags_later # still present
assert "NewTag" not in tags_later
assert "BETA" in tags_later # still present
def test_add_system_looking_tags_allowed_as_labels(
http: requests.Session, api_base: str, seeded_asset: dict
):
aid = seeded_asset["id"]
response = http.post(
f"{api_base}/api/assets/{aid}/tags",
json={
"tags": [
"models",
"model_type:manual",
"model:true",
"models:foo",
"input:true",
"output:true",
"uploaded:true",
"temp:true",
"temporary",
]
},
timeout=120,
)
body = response.json()
assert response.status_code == 200, body
assert "models" in body["total_tags"]
assert "model_type:manual" in body["total_tags"]
assert "model:true" in body["total_tags"]
assert "models:foo" in body["total_tags"]
assert "input:true" in body["total_tags"]
assert "output:true" in body["total_tags"]
assert "uploaded:true" in body["total_tags"]
assert "temp:true" in body["total_tags"]
assert "temporary" in body["total_tags"]
def test_tags_list_order_and_prefix(http: requests.Session, api_base: str, seeded_asset: dict):

View File

@ -1,11 +1,14 @@
import json
import uuid
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import requests
import pytest
from app.assets.api.schemas_in import UploadAssetSpec
from app.assets.api.schemas_out import Asset, AssetCreated
from helpers import get_asset_filename
def test_asset_created_inherits_hash_field():
@ -20,9 +23,18 @@ def test_asset_created_inherits_hash_field():
assert AssetCreated.model_fields["hash"].annotation == Asset.model_fields["hash"].annotation
def test_upload_asset_spec_ignores_subfolder_field():
spec = UploadAssetSpec.model_validate(
{"tags": ["input"], "subfolder": "pasted", "name": "image.png"}
)
assert "subfolder" not in UploadAssetSpec.model_fields
assert not hasattr(spec, "subfolder")
def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, make_asset_bytes):
name = "dup_a.safetensors"
tags = ["models", "checkpoints", "unit-tests", "alpha"]
tags = ["models", "model_type:checkpoints", "unit-tests", "alpha"]
meta = {"purpose": "dup"}
data = make_asset_bytes(name)
files = {"file": (name, data, "application/octet-stream")}
@ -43,6 +55,8 @@ def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, ma
assert a2["asset_hash"] == a1["asset_hash"]
assert a2["hash"] == a1["hash"]
assert a2["id"] != a1["id"] # new reference with same content
assert a2.get("loader_path") is None
assert a2.get("display_name") is None
# Third upload with the same data but different name also creates new AssetReference
files = {"file": (name, data, "application/octet-stream")}
@ -53,12 +67,14 @@ def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, ma
assert a3["asset_hash"] == a1["asset_hash"]
assert a3["id"] != a1["id"]
assert a3["id"] != a2["id"]
assert a3.get("loader_path") is None
assert a3.get("display_name") is None
def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_base: str):
# Seed a small file first
name = "fastpath_seed.safetensors"
tags = ["models", "checkpoints", "unit-tests"]
tags = ["input", "unit-tests"]
meta = {}
files = {"file": (name, b"B" * 1024, "application/octet-stream")}
form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)}
@ -69,9 +85,10 @@ def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_
assert b1["hash"] == h
# Now POST /api/assets with only hash and no file
hash_only_tags = ["models", "checkpoints", "unit-tests", "hash-labels"]
files = [
("hash", (None, h)),
("tags", (None, json.dumps(tags))),
("tags", (None, json.dumps(hash_only_tags))),
("name", (None, "fastpath_copy.safetensors")),
("user_metadata", (None, json.dumps({"purpose": "copy"}))),
]
@ -81,6 +98,53 @@ def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_
assert b2["created_new"] is False
assert b2["asset_hash"] == h
assert b2["hash"] == h
assert "models" in b2["tags"]
assert "checkpoints" in b2["tags"]
assert "uploaded" not in b2["tags"]
assert not any(tag.startswith("model_type:") for tag in b2["tags"])
assert b2.get("loader_path") is None
assert b2.get("display_name") is None
rg = http.get(f"{api_base}/api/assets/{b2['id']}", timeout=120)
detail = rg.json()
assert rg.status_code == 200, detail
assert detail.get("loader_path") is None
assert detail.get("display_name") is None
def test_create_from_hash_with_model_tags_does_not_synthesize_loader_path(
http: requests.Session, api_base: str
):
seed_name = "from_hash_seed.safetensors"
seed_tags = ["models", "model_type:checkpoints", "unit-tests"]
files = {"file": (seed_name, b"D" * 1024, "application/octet-stream")}
form = {
"tags": json.dumps(seed_tags),
"name": seed_name,
"user_metadata": json.dumps({}),
}
seed_r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
seed = seed_r.json()
assert seed_r.status_code == 201, seed
payload = {
"hash": seed["asset_hash"],
"name": "from_hash_copy.safetensors",
"tags": ["models", "model_type:checkpoints", "unit-tests", "spoofed"],
}
created_r = http.post(api_base + "/api/assets/from-hash", json=payload, timeout=120)
created = created_r.json()
assert created_r.status_code == 201, created
assert created["created_new"] is False
assert created["asset_hash"] == seed["asset_hash"]
assert created.get("loader_path") is None
assert created.get("display_name") is None
detail_r = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120)
detail = detail_r.json()
assert detail_r.status_code == 200, detail
assert detail.get("loader_path") is None
assert detail.get("display_name") is None
def test_upload_fastpath_with_known_hash_and_file(
@ -88,7 +152,7 @@ def test_upload_fastpath_with_known_hash_and_file(
):
# Seed
files = {"file": ("seed.safetensors", b"C" * 128, "application/octet-stream")}
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "fp"]), "name": "seed.safetensors", "user_metadata": json.dumps({})}
form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "fp"]), "name": "seed.safetensors", "user_metadata": json.dumps({})}
r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
b1 = r1.json()
assert r1.status_code == 201, b1
@ -104,11 +168,49 @@ def test_upload_fastpath_with_known_hash_and_file(
assert b2["created_new"] is False
assert b2["asset_hash"] == h
assert b2["hash"] == h
assert "checkpoints" in b2["tags"]
assert "uploaded" not in b2["tags"]
assert not any(tag == "model_type:checkpoints" for tag in b2["tags"])
def test_duplicate_byte_upload_is_reference_only_and_does_not_need_destination(
http: requests.Session, api_base: str
):
data = b"duplicate-reference-only" * 64
seed_files = {"file": ("duplicate-seed.bin", data, "application/octet-stream")}
seed_form = {
"tags": json.dumps(["input", "unit-tests", "duplicate-seed"]),
"name": "duplicate-seed.bin",
"user_metadata": json.dumps({}),
}
seed_response = http.post(api_base + "/api/assets", data=seed_form, files=seed_files, timeout=120)
seed = seed_response.json()
assert seed_response.status_code == 201, seed
duplicate_files = {"file": ("duplicate-copy.bin", data, "application/octet-stream")}
duplicate_form = {
"tags": json.dumps(["not-a-destination", "unit-tests", "duplicate-copy"]),
"name": "duplicate-copy.bin",
"user_metadata": json.dumps({}),
}
duplicate_response = http.post(
api_base + "/api/assets", data=duplicate_form, files=duplicate_files, timeout=120
)
duplicate = duplicate_response.json()
assert duplicate_response.status_code == 200, duplicate
assert duplicate["created_new"] is False
assert duplicate["asset_hash"] == seed["asset_hash"]
assert "not-a-destination" in duplicate["tags"]
assert "uploaded" not in duplicate["tags"]
assert "input" not in duplicate["tags"]
assert duplicate.get("loader_path") is None
assert duplicate.get("display_name") is None
def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base: str):
data = [
("tags", "models,checkpoints"),
("tags", "models,model_type:checkpoints"),
("tags", json.dumps(["unit-tests", "alpha"])),
("name", "merge.safetensors"),
("user_metadata", json.dumps({"u": 1})),
@ -124,7 +226,77 @@ def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base
detail = rg.json()
assert rg.status_code == 200, detail
tags = set(detail["tags"])
assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags)
assert {"models", "model_type:checkpoints", "unit-tests", "alpha"}.issubset(tags)
@pytest.mark.parametrize(
(
"tags",
"extension",
"expected_prefix",
"expected_display_prefix",
),
[
(["input", "unit-tests"], ".png", "input", ""),
(
["models", "model_type:checkpoints", "unit-tests"],
".safetensors",
"models/checkpoints",
"checkpoints/",
),
],
)
def test_upload_response_includes_loader_path_and_display_name(
tags: list[str],
extension: str,
expected_prefix: str,
expected_display_prefix: str,
http: requests.Session,
api_base: str,
make_asset_bytes,
):
scope = f"response-paths-{uuid.uuid4().hex[:6]}"
scoped_tags = [*tags, scope]
name = f"asset_response_path{extension}"
files = {"file": (name, make_asset_bytes(name, 1024), "application/octet-stream")}
form = {
"tags": json.dumps(scoped_tags),
"name": name,
"user_metadata": json.dumps({}),
}
created_r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
created = created_r.json()
assert created_r.status_code in (200, 201), created
stored_filename = get_asset_filename(created["asset_hash"], extension)
expected_suffix = stored_filename
expected_logical_path = f"{expected_prefix}/{expected_suffix}"
expected_display_name = f"{expected_display_prefix}{expected_suffix}"
# In-root loader path: model category dropped, no subfolders here -> just the filename.
expected_loader_path = expected_suffix
assert created["loader_path"] == expected_loader_path
assert created["logical_path"] == expected_logical_path
assert created["display_name"] == expected_display_name
detail_r = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120)
detail = detail_r.json()
assert detail_r.status_code == 200, detail
assert detail["loader_path"] == expected_loader_path
assert detail["logical_path"] == expected_logical_path
assert detail["display_name"] == expected_display_name
list_r = http.get(
api_base + "/api/assets",
params={"include_tags": f"unit-tests,{scope}", "limit": "50"},
timeout=120,
)
listed = list_r.json()
assert list_r.status_code == 200, listed
match = next(a for a in listed["assets"] if a["id"] == created["id"])
assert match["loader_path"] == expected_loader_path
assert match["logical_path"] == expected_logical_path
assert match["display_name"] == expected_display_name
@pytest.mark.parametrize("root", ["input", "output"])
@ -192,16 +364,55 @@ def test_create_from_hash_endpoint_404(http: requests.Session, api_base: str):
assert body["error"]["code"] == "ASSET_NOT_FOUND"
def test_create_from_hash_accepts_arbitrary_system_looking_tags(
http: requests.Session, api_base: str
):
files = {"file": ("hash-seed.bin", b"hash-seed" * 64, "application/octet-stream")}
form = {
"tags": json.dumps(["input", "unit-tests", "hash-seed"]),
"name": "hash-seed.bin",
"user_metadata": json.dumps({}),
}
seed_response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
seed = seed_response.json()
assert seed_response.status_code == 201, seed
response = http.post(
api_base + "/api/assets/from-hash",
json={
"hash": seed["asset_hash"],
"name": "hash-copy.bin",
"tags": [
"models",
"model:true",
"models:foo",
"temporary:true",
"unit-tests",
"hash-copy",
],
},
timeout=120,
)
body = response.json()
assert response.status_code == 201, body
assert "models" in body["tags"]
assert "model:true" in body["tags"]
assert "models:foo" in body["tags"]
assert "temporary:true" in body["tags"]
assert "uploaded" not in body["tags"]
def test_upload_zero_byte_rejected(http: requests.Session, api_base: str):
files = {"file": ("empty.safetensors", b"", "application/octet-stream")}
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "empty.safetensors", "user_metadata": json.dumps({})}
form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "edge"]), "name": "empty.safetensors", "user_metadata": json.dumps({})}
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = r.json()
assert r.status_code == 400
assert body["error"]["code"] == "EMPTY_UPLOAD"
def test_upload_invalid_root_tag_rejected(http: requests.Session, api_base: str):
def test_upload_rejects_arbitrary_labels_without_required_destination_role(http: requests.Session, api_base: str):
files = {"file": ("badroot.bin", b"A" * 64, "application/octet-stream")}
form = {"tags": json.dumps(["not-a-root", "whatever"]), "name": "badroot.bin", "user_metadata": json.dumps({})}
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
@ -212,7 +423,7 @@ def test_upload_invalid_root_tag_rejected(http: requests.Session, api_base: str)
def test_upload_user_metadata_must_be_json(http: requests.Session, api_base: str):
files = {"file": ("badmeta.bin", b"A" * 128, "application/octet-stream")}
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "badmeta.bin", "user_metadata": "{not json}"}
form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "edge"]), "name": "badmeta.bin", "user_metadata": "{not json}"}
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = r.json()
assert r.status_code == 400
@ -228,7 +439,7 @@ def test_upload_requires_multipart(http: requests.Session, api_base: str):
def test_upload_missing_file_and_hash(http: requests.Session, api_base: str):
files = [
("tags", (None, json.dumps(["models", "checkpoints", "unit-tests"]))),
("tags", (None, json.dumps(["models", "model_type:checkpoints", "unit-tests"]))),
("name", (None, "x.safetensors")),
]
r = http.post(api_base + "/api/assets", files=files, timeout=120)
@ -237,17 +448,33 @@ def test_upload_missing_file_and_hash(http: requests.Session, api_base: str):
assert body["error"]["code"] == "MISSING_FILE"
def test_upload_models_unknown_category(http: requests.Session, api_base: str):
def test_upload_models_unknown_model_type(http: requests.Session, api_base: str):
files = {"file": ("m.safetensors", b"A" * 128, "application/octet-stream")}
form = {"tags": json.dumps(["models", "no_such_category", "unit-tests"]), "name": "m.safetensors"}
form = {"tags": json.dumps(["models", "model_type:no_such_category", "unit-tests"]), "name": "m.safetensors"}
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = r.json()
assert r.status_code == 400
assert r.status_code == 400, body
assert body["error"]["code"] == "INVALID_BODY"
assert body["error"]["message"].startswith("unknown models category")
def test_upload_models_requires_category(http: requests.Session, api_base: str):
@pytest.mark.parametrize("model_type", ["configs", "custom_nodes"])
def test_upload_models_rejects_non_model_registered_folder(
model_type: str, http: requests.Session, api_base: str
):
files = {"file": ("not-a-model.py", b"A" * 128, "application/octet-stream")}
form = {
"tags": json.dumps(["models", f"model_type:{model_type}", "unit-tests"]),
"name": "not-a-model.py",
}
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = response.json()
assert response.status_code == 400, body
assert body["error"]["code"] == "INVALID_BODY"
def test_upload_models_requires_model_type(http: requests.Session, api_base: str):
files = {"file": ("nocat.safetensors", b"A" * 64, "application/octet-stream")}
form = {"tags": json.dumps(["models"]), "name": "nocat.safetensors", "user_metadata": json.dumps({})}
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
@ -256,13 +483,152 @@ def test_upload_models_requires_category(http: requests.Session, api_base: str):
assert body["error"]["code"] == "INVALID_BODY"
def test_upload_tags_traversal_guard(http: requests.Session, api_base: str):
def test_upload_extra_tags_are_labels_not_path_components(http: requests.Session, api_base: str):
files = {"file": ("evil.safetensors", b"A" * 256, "application/octet-stream")}
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"]), "name": "evil.safetensors"}
form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "..", "zzz"]), "name": "evil.safetensors"}
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = r.json()
assert r.status_code == 400
assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
assert r.status_code == 201, body
assert ".." in body["tags"]
assert "zzz" in body["tags"]
assert "models" in body["tags"]
assert "model_type:checkpoints" in body["tags"]
@pytest.mark.parametrize(
("subfolder", "expected_tag", "unexpected_tags"),
[
("custom/session", None, {"custom", "session"}),
("pasted", "pasted", set()),
],
)
def test_upload_image_accepts_arbitrary_subfolder_but_only_known_values_become_tags(
http: requests.Session,
api_base: str,
comfy_tmp_base_dir: Path,
subfolder: str,
expected_tag: str | None,
unexpected_tags: set[str],
):
name = f"upload-image-{uuid.uuid4().hex}.png"
files = {"image": (name, b"image-upload" * 64, "image/png")}
form = {"type": "input", "subfolder": subfolder}
response = http.post(api_base + "/upload/image", data=form, files=files, timeout=120)
body = response.json()
assert response.status_code == 200, body
assert body["subfolder"] == subfolder
assert (comfy_tmp_base_dir / "input" / subfolder / body["name"]).exists()
asset = body["asset"]
tags = set(asset["tags"])
assert "input" in tags
assert "uploaded" in tags
if expected_tag:
assert expected_tag in tags
assert tags.isdisjoint(unexpected_tags)
def test_multipart_upload_accepts_system_looking_extra_labels(
http: requests.Session, api_base: str
):
files = {"file": ("relaxed-labels.bin", b"relaxed" * 64, "application/octet-stream")}
form = {
"tags": json.dumps(
[
"input",
"unit-tests",
"model:true",
"models:foo",
"temporary",
"uploaded:true",
]
),
"name": "relaxed-labels.bin",
"user_metadata": json.dumps({}),
}
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = response.json()
assert response.status_code == 201, body
assert "input" in body["tags"]
assert "model:true" in body["tags"]
assert "models:foo" in body["tags"]
assert "temporary" in body["tags"]
assert "uploaded:true" in body["tags"]
def test_multipart_upload_rejects_ambiguous_destination_roles(
http: requests.Session, api_base: str
):
files = {"file": ("ambiguous.bin", b"ambiguous" * 64, "application/octet-stream")}
form = {
"tags": json.dumps(["input", "output", "unit-tests"]),
"name": "ambiguous.bin",
"user_metadata": json.dumps({}),
}
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = response.json()
assert response.status_code == 400, body
assert body["error"]["code"] == "INVALID_BODY"
def test_multipart_upload_rejects_multiple_model_types_for_models_destination(
http: requests.Session, api_base: str
):
files = {"file": ("ambiguous-model.safetensors", b"ambiguous-model" * 64, "application/octet-stream")}
form = {
"tags": json.dumps(
["models", "model_type:checkpoints", "model_type:loras", "unit-tests"]
),
"name": "ambiguous-model.safetensors",
"user_metadata": json.dumps({}),
}
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = response.json()
assert response.status_code == 400, body
assert body["error"]["code"] == "INVALID_BODY"
@pytest.mark.parametrize(
("tags", "expected_root", "extension"),
[
(["input", "unit-tests", "upload-location-input"], "input", ".bin"),
(["output", "unit-tests", "upload-location-output"], "output", ".bin"),
(
["models", "model_type:checkpoints", "unit-tests", "upload-location-model"],
"models/checkpoints",
".safetensors",
),
],
)
def test_multipart_upload_role_selects_write_location(
http: requests.Session,
api_base: str,
comfy_tmp_base_dir: Path,
tags: list[str],
expected_root: str,
extension: str,
):
role = next(tag for tag in tags if tag in {"input", "models", "output"})
name = f"{role}-role-upload{extension}"
files = {"file": (name, f"{role}-role-bytes".encode() * 64, "application/octet-stream")}
form = {
"tags": json.dumps(tags),
"name": name,
"user_metadata": json.dumps({}),
}
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = response.json()
assert response.status_code == 201, body
stored_name = get_asset_filename(body["asset_hash"], extension)
expected_disk_path = comfy_tmp_base_dir / expected_root / stored_name
assert expected_disk_path.exists()
def test_upload_empty_tags_rejected(http: requests.Session, api_base: str):

View File

@ -1,137 +0,0 @@
"""Tests for pre-execution validation that a node is actually executable.
validate_prompt rejects a node whose declared entry point does not resolve to a
real method (a V1 FUNCTION typo, or a V3 node missing its execute override) before
any node runs, attributing the error to the offending node.
"""
import asyncio
import nodes
from comfy_api.latest import io
from execution import node_not_executable_reason, validate_prompt
class _GoodV1Node:
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "run"
OUTPUT_NODE = True
CATEGORY = "Test"
def run(self):
return (None,)
class _TypoV1Node:
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "invert" # method below is misspelled
OUTPUT_NODE = True
CATEGORY = "Test"
def invvert(self):
return (None,)
class _SideEffectInitV1Node:
"""Valid class-level method, but a constructor that must never run in validation."""
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "run"
OUTPUT_NODE = True
CATEGORY = "Test"
def __init__(self):
raise RuntimeError("__init__ must not run during validation")
def run(self):
return (None,)
def _v3_schema(node_id):
return io.Schema(
node_id=node_id,
display_name=node_id,
category="Test",
inputs=[],
outputs=[io.Image.Output()],
is_output_node=True,
)
class _GoodV3Node(io.ComfyNode):
@classmethod
def define_schema(cls):
return _v3_schema("GoodV3Node")
@classmethod
def execute(cls):
return io.NodeOutput(None)
class _TypoV3Node(io.ComfyNode):
@classmethod
def define_schema(cls):
return _v3_schema("TypoV3Node")
@classmethod
def exicute(cls): # typo: should be "execute"
return io.NodeOutput(None)
def _register(class_type, class_def):
nodes.NODE_CLASS_MAPPINGS[class_type] = class_def
def _validate(class_type):
prompt = {"1": {"class_type": class_type, "inputs": {}}}
return asyncio.run(validate_prompt("pid", prompt, None))
def test_good_node_passes():
_register("GoodV1Node", _GoodV1Node)
assert node_not_executable_reason(_GoodV1Node, "GoodV1Node") is None
valid, _, _, _ = _validate("GoodV1Node")
assert valid is True
def test_typo_node_rejected_with_node_error():
_register("TypoV1Node", _TypoV1Node)
valid, error, _, node_errors = _validate("TypoV1Node")
assert valid is False
assert error["type"] == "invalid_node_definition"
assert node_errors["1"]["class_type"] == "TypoV1Node"
assert node_errors["1"]["errors"][0]["type"] == "invalid_node_definition"
assert "invert" in node_errors["1"]["errors"][0]["details"]
def test_validation_does_not_instantiate_node():
"""A valid node is not constructed during validation, so __init__ never runs."""
_register("SideEffectInitV1Node", _SideEffectInitV1Node)
assert node_not_executable_reason(_SideEffectInitV1Node, "SideEffectInitV1Node") is None
valid, _, _, _ = _validate("SideEffectInitV1Node")
assert valid is True
def test_good_v3_node_passes():
_register("GoodV3Node", _GoodV3Node)
assert node_not_executable_reason(_GoodV3Node, "GoodV3Node") is None
valid, _, _, _ = _validate("GoodV3Node")
assert valid is True
def test_typo_v3_node_rejected_with_node_error():
_register("TypoV3Node", _TypoV3Node)
valid, error, _, node_errors = _validate("TypoV3Node")
assert valid is False
assert error["type"] == "invalid_node_definition"
assert node_errors["1"]["errors"][0]["type"] == "invalid_node_definition"

View File

@ -29,6 +29,8 @@ class TestFeatureFlags:
features = get_server_features()
assert "supports_preview_metadata" in features
assert features["supports_preview_metadata"] is True
assert "supports_model_type_tags" in features
assert features["supports_model_type_tags"] is True
assert "max_upload_size" in features
assert isinstance(features["max_upload_size"], (int, float))

View File

@ -12,6 +12,8 @@ class TestWebSocketFeatureFlags:
# Check expected server features
assert "supports_preview_metadata" in features
assert features["supports_preview_metadata"] is True
assert "supports_model_type_tags" in features
assert features["supports_model_type_tags"] is True
assert "max_upload_size" in features
assert isinstance(features["max_upload_size"], (int, float))
@ -75,3 +77,5 @@ class TestWebSocketFeatureFlags:
assert server_message["type"] == "feature_flags"
assert "supports_preview_metadata" in server_message["data"]
assert server_message["data"]["supports_preview_metadata"] is True
assert "supports_model_type_tags" in server_message["data"]
assert server_message["data"]["supports_model_type_tags"] is True