mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 11:17:06 +08:00
Compare commits
5 Commits
synap5e/as
...
ListInput
| Author | SHA1 | Date | |
|---|---|---|---|
| 330a37db94 | |||
| 30b19c6872 | |||
| 2dd281d8a6 | |||
| 911e0b2acf | |||
| 46c7e8055c |
38
.github/workflows/ci-cursor-review.yml
vendored
38
.github/workflows/ci-cursor-review.yml
vendored
@ -1,38 +0,0 @@
|
||||
name: CI - Cursor Review
|
||||
|
||||
# Thin caller for the shared reusable cursor-review workflow in
|
||||
# Comfy-Org/github-workflows. The review logic (panel matrix, judge
|
||||
# consolidation, prompts, extract/post/notify scripts) lives there as the
|
||||
# single source of truth, so this repo only carries the repo-specific diff
|
||||
# excludes.
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [labeled, unlabeled]
|
||||
|
||||
concurrency:
|
||||
group: cursor-review-pr-${{ github.event.pull_request.number }}-${{ github.event.label.name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
cursor-review:
|
||||
if: github.event.label.name == 'cursor-review'
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
# SHA-pinned per zizmor `unpinned-uses: hash-pin`. Bump this SHA to pick up
|
||||
# upstream changes; keep `workflows_ref` matching so prompts/scripts load
|
||||
# from the same commit as the workflow definition.
|
||||
uses: Comfy-Org/github-workflows/.github/workflows/cursor-review.yml@047ca48febe3a6647608ed2e0c4331b491cb9d6a # github-workflows#9
|
||||
with:
|
||||
workflows_ref: 047ca48febe3a6647608ed2e0c4331b491cb9d6a
|
||||
diff_excludes: >-
|
||||
:!**/.claude/**
|
||||
:!**/dist/**
|
||||
:!**/vendor/**
|
||||
:!**/*.generated.*
|
||||
:!**/*.min.js
|
||||
:!**/*.min.css
|
||||
secrets:
|
||||
CURSOR_API_KEY: ${{ secrets.CURSOR_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
@ -1,107 +0,0 @@
|
||||
"""
|
||||
Allow case-sensitive tag names.
|
||||
|
||||
Revision ID: 0005_allow_case_sensitive_tags
|
||||
Revises: 0004_drop_tag_type
|
||||
Create Date: 2026-06-16
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = "0005_allow_case_sensitive_tags"
|
||||
down_revision = "0004_drop_tag_type"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
if bind.dialect.name == "sqlite":
|
||||
# SQLite cannot ALTER/DROP CHECK constraints. Recreate the small tag
|
||||
# vocabulary table without the lowercase constraint while preserving
|
||||
# existing tag names.
|
||||
op.execute("PRAGMA foreign_keys=OFF")
|
||||
try:
|
||||
op.execute(
|
||||
"CREATE TABLE tags_new ("
|
||||
"name VARCHAR(512) NOT NULL, "
|
||||
"CONSTRAINT pk_tags PRIMARY KEY (name)"
|
||||
")"
|
||||
)
|
||||
op.execute("INSERT INTO tags_new(name) SELECT name FROM tags")
|
||||
op.execute("DROP TABLE tags")
|
||||
op.execute("ALTER TABLE tags_new RENAME TO tags")
|
||||
finally:
|
||||
op.execute("PRAGMA foreign_keys=ON")
|
||||
return
|
||||
|
||||
op.drop_constraint("ck_tags_ck_tags_lowercase", "tags", type_="check")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Existing mixed-case tags cannot satisfy the old constraint. Lowercase them
|
||||
# before restoring it, merging duplicate vocabulary/link rows that collide.
|
||||
bind = op.get_bind()
|
||||
|
||||
tag_names = [row[0] for row in bind.execute(sa.text("SELECT name FROM tags"))]
|
||||
existing_names = set(tag_names)
|
||||
lowercase_names = sorted({name.lower() for name in tag_names})
|
||||
missing_lowercase_rows = [
|
||||
{"name": name} for name in lowercase_names if name not in existing_names
|
||||
]
|
||||
if missing_lowercase_rows:
|
||||
bind.execute(sa.text("INSERT INTO tags(name) VALUES (:name)"), missing_lowercase_rows)
|
||||
|
||||
link_rows = bind.execute(
|
||||
sa.text(
|
||||
"SELECT asset_reference_id, tag_name, origin, added_at "
|
||||
"FROM asset_reference_tags "
|
||||
"ORDER BY asset_reference_id, tag_name"
|
||||
)
|
||||
).mappings()
|
||||
deduped_links = {}
|
||||
for row in link_rows:
|
||||
key = (row["asset_reference_id"], row["tag_name"].lower())
|
||||
deduped_links.setdefault(
|
||||
key,
|
||||
{
|
||||
"asset_reference_id": row["asset_reference_id"],
|
||||
"tag_name": row["tag_name"].lower(),
|
||||
"origin": row["origin"],
|
||||
"added_at": row["added_at"],
|
||||
},
|
||||
)
|
||||
|
||||
op.execute("DELETE FROM asset_reference_tags")
|
||||
if deduped_links:
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"INSERT INTO asset_reference_tags "
|
||||
"(asset_reference_id, tag_name, origin, added_at) "
|
||||
"VALUES (:asset_reference_id, :tag_name, :origin, :added_at)"
|
||||
),
|
||||
list(deduped_links.values()),
|
||||
)
|
||||
op.execute("DELETE FROM tags WHERE name != lower(name)")
|
||||
|
||||
if bind.dialect.name == "sqlite":
|
||||
op.execute("PRAGMA foreign_keys=OFF")
|
||||
try:
|
||||
op.execute(
|
||||
"CREATE TABLE tags_new ("
|
||||
"name VARCHAR(512) NOT NULL, "
|
||||
"CONSTRAINT pk_tags PRIMARY KEY (name), "
|
||||
"CONSTRAINT ck_tags_lowercase CHECK (name = lower(name))"
|
||||
")"
|
||||
)
|
||||
op.execute("INSERT INTO tags_new(name) SELECT name FROM tags")
|
||||
op.execute("DROP TABLE tags")
|
||||
op.execute("ALTER TABLE tags_new RENAME TO tags")
|
||||
finally:
|
||||
op.execute("PRAGMA foreign_keys=ON")
|
||||
return
|
||||
|
||||
op.create_check_constraint(
|
||||
"ck_tags_ck_tags_lowercase", "tags", "name = lower(name)"
|
||||
)
|
||||
@ -1,30 +0,0 @@
|
||||
"""
|
||||
Add loader_path column to asset_references.
|
||||
|
||||
Stores the in-root loader path (path relative to the storage root with the
|
||||
top-level model category dropped) derived from file_path at scan/ingest time,
|
||||
so the assets API can return it without re-resolving against every registered
|
||||
model-folder base on every request.
|
||||
|
||||
Revision ID: 0006_add_loader_path
|
||||
Revises: 0005_allow_case_sensitive_tags
|
||||
Create Date: 2026-07-02
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "0006_add_loader_path"
|
||||
down_revision = "0005_allow_case_sensitive_tags"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
with op.batch_alter_table("asset_references") as batch_op:
|
||||
batch_op.add_column(sa.Column("loader_path", sa.Text(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
with op.batch_alter_table("asset_references") as batch_op:
|
||||
batch_op.drop_column("loader_path")
|
||||
@ -10,6 +10,7 @@ from typing import Any
|
||||
from aiohttp import web
|
||||
from pydantic import ValidationError
|
||||
|
||||
import folder_paths
|
||||
from app import user_manager
|
||||
from app.assets.api import schemas_in, schemas_out
|
||||
from app.assets.services import schemas
|
||||
@ -39,10 +40,6 @@ from app.assets.services import (
|
||||
upload_from_temp_path,
|
||||
)
|
||||
from app.assets.services.cursor import InvalidCursorError
|
||||
from app.assets.services.path_utils import (
|
||||
compute_display_name,
|
||||
compute_loader_path,
|
||||
)
|
||||
from app.assets.services.tagging import list_tag_histogram
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
@ -164,23 +161,11 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
|
||||
preview_url = None
|
||||
else:
|
||||
preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata)
|
||||
if result.ref.file_path:
|
||||
display_name = compute_display_name(result.ref.file_path)
|
||||
# In-root loader path (model category dropped): what model loaders consume.
|
||||
# Persisted at scan/ingest; fall back to computing for rows written
|
||||
# before the column existed.
|
||||
loader_path = result.ref.loader_path
|
||||
if loader_path is None:
|
||||
loader_path = compute_loader_path(result.ref.file_path)
|
||||
else:
|
||||
display_name, loader_path = None, None
|
||||
asset_content_hash = result.asset.hash if result.asset else None
|
||||
return schemas_out.Asset(
|
||||
id=result.ref.id,
|
||||
name=result.ref.name,
|
||||
hash=asset_content_hash,
|
||||
loader_path=loader_path,
|
||||
display_name=display_name,
|
||||
asset_hash=asset_content_hash,
|
||||
size=int(result.asset.size_bytes) if result.asset else None,
|
||||
mime_type=result.asset.mime_type if result.asset else None,
|
||||
@ -431,6 +416,17 @@ async def upload_asset(request: web.Request) -> web.Response:
|
||||
400, "INVALID_BODY", f"Validation failed: {ve.json()}"
|
||||
)
|
||||
|
||||
if spec.tags and spec.tags[0] == "models":
|
||||
if (
|
||||
len(spec.tags) < 2
|
||||
or spec.tags[1] not in folder_paths.folder_names_and_paths
|
||||
):
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
category = spec.tags[1] if len(spec.tags) >= 2 else ""
|
||||
return _build_error_response(
|
||||
400, "INVALID_BODY", f"unknown models category '{category}'"
|
||||
)
|
||||
|
||||
try:
|
||||
# Fast path: hash exists, create AssetReference without writing anything
|
||||
if spec.hash and parsed.provided_hash_exists is True:
|
||||
@ -474,7 +470,7 @@ async def upload_asset(request: web.Request) -> web.Response:
|
||||
return _build_error_response(400, e.code, str(e))
|
||||
except ValueError as e:
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
return _build_error_response(400, "INVALID_BODY", str(e))
|
||||
return _build_error_response(400, "BAD_REQUEST", str(e))
|
||||
except HashMismatchError as e:
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
return _build_error_response(400, "HASH_MISMATCH", str(e))
|
||||
|
||||
@ -140,7 +140,7 @@ class CreateFromHashBody(BaseModel):
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, list):
|
||||
out = [str(t).strip() for t in v if str(t).strip()]
|
||||
out = [str(t).strip().lower() for t in v if str(t).strip()]
|
||||
seen = set()
|
||||
dedup = []
|
||||
for t in out:
|
||||
@ -149,7 +149,7 @@ class CreateFromHashBody(BaseModel):
|
||||
dedup.append(t)
|
||||
return dedup
|
||||
if isinstance(v, str):
|
||||
return list(dict.fromkeys(t.strip() for t in v.split(",") if t.strip()))
|
||||
return [t.strip().lower() for t in v.split(",") if t.strip()]
|
||||
return []
|
||||
|
||||
|
||||
@ -206,7 +206,7 @@ class TagsListQuery(BaseModel):
|
||||
if v is None:
|
||||
return v
|
||||
v = v.strip()
|
||||
return v or None
|
||||
return v.lower() or None
|
||||
|
||||
|
||||
class TagsAdd(BaseModel):
|
||||
@ -220,7 +220,7 @@ class TagsAdd(BaseModel):
|
||||
for t in v:
|
||||
if not isinstance(t, str):
|
||||
raise TypeError("tags must be strings")
|
||||
tnorm = t.strip()
|
||||
tnorm = t.strip().lower()
|
||||
if tnorm:
|
||||
out.append(tnorm)
|
||||
seen = set()
|
||||
@ -239,8 +239,8 @@ class TagsRemove(TagsAdd):
|
||||
class UploadAssetSpec(BaseModel):
|
||||
"""Upload Asset operation.
|
||||
|
||||
- tags: labels plus one destination role ('models'|'input'|'output') for new bytes;
|
||||
if role == 'models', exactly one model_type:<folder_name> tag is required
|
||||
- tags: optional list; if provided, first is root ('models'|'input'|'output');
|
||||
if root == 'models', second must be a valid category
|
||||
- name: display name
|
||||
- user_metadata: arbitrary JSON object (optional)
|
||||
- hash: optional canonical 'blake3:<hex>' for validation / fast-path
|
||||
@ -309,7 +309,7 @@ class UploadAssetSpec(BaseModel):
|
||||
norm = []
|
||||
seen = set()
|
||||
for t in items:
|
||||
tnorm = str(t).strip()
|
||||
tnorm = str(t).strip().lower()
|
||||
if tnorm and tnorm not in seen:
|
||||
seen.add(tnorm)
|
||||
norm.append(tnorm)
|
||||
@ -335,4 +335,14 @@ class UploadAssetSpec(BaseModel):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_order(self):
|
||||
if not self.tags:
|
||||
raise ValueError("at least one tag is required for uploads")
|
||||
root = self.tags[0]
|
||||
if root not in {"models", "input", "output"}:
|
||||
raise ValueError("first tag must be one of: models, input, output")
|
||||
if root == "models":
|
||||
if len(self.tags) < 2:
|
||||
raise ValueError(
|
||||
"models uploads require a category tag as the second tag"
|
||||
)
|
||||
return self
|
||||
|
||||
@ -9,20 +9,8 @@ class Asset(BaseModel):
|
||||
``id`` here is the AssetReference id, not the content-addressed Asset id."""
|
||||
|
||||
id: str
|
||||
name: str = Field(
|
||||
...,
|
||||
deprecated=True,
|
||||
description="Reference label, often caller-provided or derived from the filename. Deprecated for storage path/display semantics; use `loader_path` and `display_name` when present.",
|
||||
)
|
||||
name: str
|
||||
hash: str | None = None
|
||||
loader_path: str | None = Field(
|
||||
default=None,
|
||||
description="In-root loader path for filesystem-backed assets: the path relative to its storage root with the top-level model category dropped (e.g. `models/checkpoints/foo/bar.safetensors` -> `foo/bar.safetensors`). This is the value model loaders consume. `None` when the file is not within a recognized root or model category.",
|
||||
)
|
||||
display_name: str | None = Field(
|
||||
default=None,
|
||||
description="Human-facing label for filesystem-backed assets: the path below the top-level storage namespace (e.g. `checkpoints/foo/bar.safetensors` under `models/`). Not unique.",
|
||||
)
|
||||
asset_hash: str | None = None
|
||||
size: int | None = None
|
||||
mime_type: str | None = None
|
||||
|
||||
@ -140,6 +140,7 @@ async def parse_multipart_upload(
|
||||
provided_mime_type = ((await field.text()) or "").strip() or None
|
||||
elif fname == "preview_id":
|
||||
provided_preview_id = ((await field.text()) or "").strip() or None
|
||||
|
||||
if not file_present and not (provided_hash and provided_hash_exists):
|
||||
raise UploadError(
|
||||
400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'."
|
||||
|
||||
@ -76,10 +76,6 @@ class AssetReference(Base):
|
||||
|
||||
# Cache state fields (from former AssetCacheState)
|
||||
file_path: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
# In-root loader path derived from file_path at scan/ingest time (model
|
||||
# category dropped). Persisted so responses read it directly instead of
|
||||
# re-resolving against every registered model-folder base per request.
|
||||
loader_path: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
is_missing: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
@ -650,7 +650,6 @@ def upsert_reference(
|
||||
name: str,
|
||||
mtime_ns: int,
|
||||
owner_id: str = "",
|
||||
loader_path: str | None = None,
|
||||
) -> tuple[bool, bool]:
|
||||
"""Upsert a reference by file_path. Returns (created, updated).
|
||||
|
||||
@ -660,7 +659,6 @@ def upsert_reference(
|
||||
vals = {
|
||||
"asset_id": asset_id,
|
||||
"file_path": file_path,
|
||||
"loader_path": loader_path,
|
||||
"name": name,
|
||||
"owner_id": owner_id,
|
||||
"mtime_ns": int(mtime_ns),
|
||||
|
||||
@ -265,8 +265,6 @@ def list_tags_with_usage(
|
||||
order: str = "count_desc",
|
||||
owner_id: str = "",
|
||||
) -> tuple[list[tuple[str, str, int]], int]:
|
||||
prefix_filter = prefix.strip() if prefix else ""
|
||||
|
||||
counts_sq = (
|
||||
select(
|
||||
AssetReferenceTag.tag_name.label("tag_name"),
|
||||
@ -295,8 +293,9 @@ def list_tags_with_usage(
|
||||
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
|
||||
)
|
||||
|
||||
if prefix_filter:
|
||||
q = q.where(func.substr(Tag.name, 1, len(prefix_filter)) == prefix_filter)
|
||||
if prefix:
|
||||
escaped, esc = escape_sql_like_string(prefix.strip().lower())
|
||||
q = q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
|
||||
if not include_zero:
|
||||
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
||||
@ -307,8 +306,9 @@ def list_tags_with_usage(
|
||||
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
|
||||
|
||||
total_q = select(func.count()).select_from(Tag)
|
||||
if prefix_filter:
|
||||
total_q = total_q.where(func.substr(Tag.name, 1, len(prefix_filter)) == prefix_filter)
|
||||
if prefix:
|
||||
escaped, esc = escape_sql_like_string(prefix.strip().lower())
|
||||
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
if not include_zero:
|
||||
visible_tags_sq = (
|
||||
select(AssetReferenceTag.tag_name)
|
||||
|
||||
@ -41,10 +41,10 @@ def get_utc_now() -> datetime:
|
||||
def normalize_tags(tags: list[str] | None) -> list[str]:
|
||||
"""
|
||||
Normalize a list of tags by:
|
||||
- Stripping whitespace.
|
||||
- Removing exact duplicates while preserving order and case.
|
||||
- Stripping whitespace and converting to lowercase.
|
||||
- Removing duplicates.
|
||||
"""
|
||||
return list(dict.fromkeys(t.strip() for t in (tags or []) if (t or "").strip()))
|
||||
return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip()))
|
||||
|
||||
|
||||
def validate_blake3_hash(s: str) -> str:
|
||||
|
||||
@ -36,7 +36,7 @@ from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash
|
||||
from app.assets.services.image_dimensions import extract_image_dimensions
|
||||
from app.assets.services.metadata_extract import extract_file_metadata
|
||||
from app.assets.services.path_utils import (
|
||||
compute_loader_path,
|
||||
compute_relative_filename,
|
||||
get_comfy_models_folders,
|
||||
get_name_and_tags_from_asset_path,
|
||||
)
|
||||
@ -63,7 +63,7 @@ RootType = Literal["models", "input", "output"]
|
||||
def get_prefixes_for_root(root: RootType) -> list[str]:
|
||||
if root == "models":
|
||||
bases: list[str] = []
|
||||
for _bucket, paths, _exts in get_comfy_models_folders():
|
||||
for _bucket, paths in get_comfy_models_folders():
|
||||
bases.extend(paths)
|
||||
return [os.path.abspath(p) for p in bases]
|
||||
if root == "input":
|
||||
@ -81,7 +81,7 @@ def get_all_known_prefixes() -> list[str]:
|
||||
|
||||
def collect_models_files() -> list[str]:
|
||||
out: list[str] = []
|
||||
for folder_name, bases, _exts in get_comfy_models_folders():
|
||||
for folder_name, bases in get_comfy_models_folders():
|
||||
rel_files = folder_paths.get_filename_list(folder_name) or []
|
||||
for rel_path in rel_files:
|
||||
if not all(is_visible(part) for part in Path(rel_path).parts):
|
||||
@ -308,7 +308,7 @@ def build_asset_specs(
|
||||
if not stat_p.st_size:
|
||||
continue
|
||||
name, tags = get_name_and_tags_from_asset_path(abs_p)
|
||||
rel_fname = compute_loader_path(abs_p)
|
||||
rel_fname = compute_relative_filename(abs_p)
|
||||
|
||||
# Extract metadata (tier 1: filesystem, tier 2: safetensors header)
|
||||
metadata = None
|
||||
@ -430,7 +430,7 @@ def enrich_asset(
|
||||
return new_level
|
||||
|
||||
initial_mtime_ns = get_mtime_ns(stat_p)
|
||||
rel_fname = compute_loader_path(file_path)
|
||||
rel_fname = compute_relative_filename(file_path)
|
||||
mime_type: str | None = None
|
||||
metadata = None
|
||||
|
||||
|
||||
@ -38,7 +38,7 @@ from app.assets.database.queries import (
|
||||
update_reference_updated_at,
|
||||
)
|
||||
from app.assets.helpers import select_best_live_path
|
||||
from app.assets.services.path_utils import compute_loader_path
|
||||
from app.assets.services.path_utils import compute_relative_filename
|
||||
from app.assets.services.schemas import (
|
||||
AssetData,
|
||||
AssetDetailResult,
|
||||
@ -91,7 +91,7 @@ def update_asset_metadata(
|
||||
update_reference_name(session, reference_id=reference_id, name=name)
|
||||
touched = True
|
||||
|
||||
computed_filename = compute_loader_path(ref.file_path) if ref.file_path else None
|
||||
computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None
|
||||
|
||||
new_meta: dict | None = None
|
||||
if user_metadata is not None:
|
||||
|
||||
@ -56,7 +56,6 @@ class ReferenceRow(TypedDict):
|
||||
id: str
|
||||
asset_id: str
|
||||
file_path: str
|
||||
loader_path: str | None
|
||||
mtime_ns: int
|
||||
owner_id: str
|
||||
name: str
|
||||
@ -135,14 +134,6 @@ def batch_insert_seed_assets(
|
||||
|
||||
for spec in specs:
|
||||
absolute_path = os.path.abspath(spec["abs_path"])
|
||||
existing_asset_id = path_to_asset_id.get(absolute_path)
|
||||
if existing_asset_id is not None:
|
||||
existing_tags = asset_id_to_ref_data[existing_asset_id]["tags"]
|
||||
asset_id_to_ref_data[existing_asset_id]["tags"] = list(
|
||||
dict.fromkeys([*existing_tags, *spec["tags"]])
|
||||
)
|
||||
continue
|
||||
|
||||
asset_id = str(uuid.uuid4())
|
||||
reference_id = str(uuid.uuid4())
|
||||
absolute_path_list.append(absolute_path)
|
||||
@ -173,8 +164,6 @@ def batch_insert_seed_assets(
|
||||
"id": reference_id,
|
||||
"asset_id": asset_id,
|
||||
"file_path": absolute_path,
|
||||
# spec["fname"] is compute_loader_path(abs_path) from build_asset_specs.
|
||||
"loader_path": spec["fname"],
|
||||
"mtime_ns": spec["mtime_ns"],
|
||||
"owner_id": owner_id,
|
||||
"name": spec["info_name"],
|
||||
|
||||
@ -33,9 +33,8 @@ from app.assets.services.bulk_ingest import batch_insert_seed_assets
|
||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||
from app.assets.services.image_dimensions import extract_image_dimensions
|
||||
from app.assets.services.path_utils import (
|
||||
compute_loader_path,
|
||||
compute_relative_filename,
|
||||
get_name_and_tags_from_asset_path,
|
||||
get_path_derived_tags_from_path,
|
||||
resolve_destination_from_tags,
|
||||
validate_path_within_base,
|
||||
)
|
||||
@ -92,7 +91,6 @@ def _ingest_file_from_path(
|
||||
name=info_name or os.path.basename(locator),
|
||||
mtime_ns=mtime_ns,
|
||||
owner_id=owner_id,
|
||||
loader_path=compute_loader_path(locator),
|
||||
)
|
||||
|
||||
# Get the reference we just created/updated
|
||||
@ -103,32 +101,17 @@ def _ingest_file_from_path(
|
||||
if preview_id and ref.preview_id != preview_id:
|
||||
ref.preview_id = preview_id
|
||||
|
||||
try:
|
||||
backend_tags = get_path_derived_tags_from_path(locator)
|
||||
except ValueError:
|
||||
backend_tags = []
|
||||
caller_tags = normalize_tags(tags)
|
||||
backend_tags = normalize_tags(backend_tags)
|
||||
all_tags = normalize_tags([*caller_tags, *backend_tags])
|
||||
if all_tags:
|
||||
norm = normalize_tags(list(tags))
|
||||
if norm:
|
||||
if require_existing_tags:
|
||||
validate_tags_exist(session, all_tags)
|
||||
if backend_tags:
|
||||
add_tags_to_reference(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
tags=backend_tags,
|
||||
origin="automatic",
|
||||
create_if_missing=not require_existing_tags,
|
||||
)
|
||||
if caller_tags:
|
||||
add_tags_to_reference(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
tags=caller_tags,
|
||||
origin=tag_origin,
|
||||
create_if_missing=not require_existing_tags,
|
||||
)
|
||||
validate_tags_exist(session, norm)
|
||||
add_tags_to_reference(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
tags=norm,
|
||||
origin=tag_origin,
|
||||
create_if_missing=not require_existing_tags,
|
||||
)
|
||||
|
||||
_update_metadata_with_filename(
|
||||
session,
|
||||
@ -305,7 +288,7 @@ def _register_existing_asset(
|
||||
return result
|
||||
|
||||
new_meta = dict(user_metadata)
|
||||
computed_filename = compute_loader_path(ref.file_path) if ref.file_path else None
|
||||
computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
|
||||
@ -352,7 +335,7 @@ def _update_metadata_with_filename(
|
||||
current_metadata: dict | None,
|
||||
user_metadata: dict[str, Any],
|
||||
) -> None:
|
||||
computed_filename = compute_loader_path(file_path) if file_path else None
|
||||
computed_filename = compute_relative_filename(file_path) if file_path else None
|
||||
|
||||
current_meta = current_metadata or {}
|
||||
new_meta = dict(current_meta)
|
||||
@ -491,10 +474,6 @@ def upload_from_temp_path(
|
||||
existing = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
if existing is not None:
|
||||
# Once content is already known, duplicate byte uploads are treated as
|
||||
# reference-only creation. Request tags are labels only here: do not
|
||||
# require upload destination tags, do not move bytes, and do not
|
||||
# synthesize path-derived classification or uploaded provenance.
|
||||
with contextlib.suppress(Exception):
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
@ -556,7 +535,7 @@ def upload_from_temp_path(
|
||||
owner_id=owner_id,
|
||||
preview_id=preview_id,
|
||||
user_metadata=user_metadata or {},
|
||||
tags=[*(tags or []), "uploaded"],
|
||||
tags=tags,
|
||||
tag_origin="manual",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
@ -590,19 +569,15 @@ def register_file_in_place(
|
||||
) -> UploadResult:
|
||||
"""Register an already-saved file in the asset database without moving it.
|
||||
|
||||
This helper is used by upload paths that have already written bytes before
|
||||
registering the file, so it records the same ``uploaded`` tag as the
|
||||
multipart byte-upload path.
|
||||
|
||||
Tags are derived from trusted filesystem classification and merged with any
|
||||
caller-provided tags, matching the behavior of the scanner.
|
||||
Tags are derived from the filesystem path (root category + subfolder names),
|
||||
merged with any caller-provided tags, matching the behavior of the scanner.
|
||||
If the path is not under a known root, only the caller-provided tags are used.
|
||||
"""
|
||||
try:
|
||||
_, path_tags = get_name_and_tags_from_asset_path(abs_path)
|
||||
except ValueError:
|
||||
path_tags = []
|
||||
merged_tags = normalize_tags([*path_tags, *tags, "uploaded"])
|
||||
merged_tags = normalize_tags([*path_tags, *tags])
|
||||
|
||||
try:
|
||||
digest, _ = hashing.compute_blake3_hash(abs_path)
|
||||
|
||||
@ -3,66 +3,59 @@ from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import folder_paths
|
||||
from app.assets.helpers import normalize_tags
|
||||
|
||||
|
||||
_NON_MODEL_FOLDER_NAMES = frozenset({"configs", "custom_nodes"})
|
||||
_KNOWN_SUBFOLDER_TAGS = frozenset({"3d", "pasted", "painter", "threed", "webcam"})
|
||||
_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"})
|
||||
|
||||
|
||||
def get_comfy_models_folders() -> list[tuple[str, list[str], set[str]]]:
|
||||
"""Build list of (folder_name, base_paths[], extensions) for all model locations.
|
||||
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
"""Build list of (folder_name, base_paths[]) for all model locations.
|
||||
|
||||
Includes every category registered in folder_names_and_paths,
|
||||
regardless of whether its paths are under the main models_dir,
|
||||
but excludes non-model entries like configs and custom_nodes.
|
||||
|
||||
An empty extensions set means the category accepts any extension,
|
||||
matching folder_paths.filter_files_extensions semantics.
|
||||
but excludes non-model entries like custom_nodes.
|
||||
"""
|
||||
targets: list[tuple[str, list[str], set[str]]] = []
|
||||
targets: list[tuple[str, list[str]]] = []
|
||||
for name, values in folder_paths.folder_names_and_paths.items():
|
||||
if name in _NON_MODEL_FOLDER_NAMES:
|
||||
continue
|
||||
paths, exts = values[0], values[1]
|
||||
paths, _exts = values[0], values[1]
|
||||
if paths:
|
||||
targets.append((name, paths, set(exts)))
|
||||
targets.append((name, paths))
|
||||
return targets
|
||||
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
"""Validates and maps upload routing tags -> (base_dir, subdirs_for_fs).
|
||||
|
||||
The request tags are only used to choose the write destination. Extra tags
|
||||
remain labels; they do not become path components or trusted classification.
|
||||
"""
|
||||
destination_roles = [t for t in tags if t in {"input", "models", "output"}]
|
||||
if len(destination_roles) != 1:
|
||||
raise ValueError("uploads require exactly one destination role: input, models, or output")
|
||||
|
||||
root = destination_roles[0]
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||
if not tags:
|
||||
raise ValueError("tags must not be empty")
|
||||
root = tags[0].lower()
|
||||
if root == "models":
|
||||
model_type_tags = [t for t in tags if t.startswith("model_type:")]
|
||||
if len(model_type_tags) != 1:
|
||||
raise ValueError("models uploads require exactly one model_type:<folder_name> tag")
|
||||
folder_name = model_type_tags[0].split(":", 1)[1]
|
||||
if not folder_name:
|
||||
raise ValueError("models uploads require exactly one model_type:<folder_name> tag")
|
||||
model_folder_paths = {
|
||||
name: paths for name, paths, _exts in get_comfy_models_folders()
|
||||
}
|
||||
if len(tags) < 2:
|
||||
raise ValueError("at least two tags required for model asset")
|
||||
try:
|
||||
bases = model_folder_paths[folder_name]
|
||||
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown model category '{folder_name}'")
|
||||
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||
if not bases:
|
||||
raise ValueError(f"no base path configured for category '{folder_name}'")
|
||||
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||
base_dir = os.path.abspath(bases[0])
|
||||
raw_subdirs = tags[2:]
|
||||
elif root == "input":
|
||||
base_dir = os.path.abspath(folder_paths.get_input_directory())
|
||||
else:
|
||||
raw_subdirs = tags[1:]
|
||||
elif root == "output":
|
||||
base_dir = os.path.abspath(folder_paths.get_output_directory())
|
||||
raw_subdirs = tags[1:]
|
||||
else:
|
||||
raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'")
|
||||
_sep_chars = frozenset(("/", "\\", os.sep))
|
||||
for i in raw_subdirs:
|
||||
if i in (".", "..") or _sep_chars & set(i):
|
||||
raise ValueError("invalid path component in tags")
|
||||
|
||||
return base_dir, []
|
||||
return base_dir, raw_subdirs if raw_subdirs else []
|
||||
|
||||
|
||||
def validate_path_within_base(candidate: str, base: str) -> None:
|
||||
@ -72,79 +65,14 @@ def validate_path_within_base(candidate: str, base: str) -> None:
|
||||
raise ValueError("destination escapes base directory")
|
||||
|
||||
|
||||
def _compute_relative_path(child: str, parent: str) -> str:
|
||||
rel = os.path.relpath(os.path.abspath(child), os.path.abspath(parent))
|
||||
if rel == ".":
|
||||
return ""
|
||||
return rel.replace(os.sep, "/")
|
||||
|
||||
|
||||
def _is_relative_to(child: str, parent: str) -> bool:
|
||||
return Path(os.path.abspath(child)).is_relative_to(os.path.abspath(parent))
|
||||
|
||||
|
||||
def compute_asset_response_paths(file_path: str) -> tuple[str, str | None] | None:
|
||||
"""Return (logical_path, display_name) for a file path.
|
||||
|
||||
``logical_path`` is the internal namespaced storage locator (e.g.
|
||||
``models/checkpoints/foo/bar.safetensors``); ``display_name`` is the
|
||||
human-facing label below that namespace, served on Asset responses. These
|
||||
are storage locators, not model-loader namespaces. Registered model-folder
|
||||
membership is represented by backend tags such as
|
||||
``model_type:<folder_name>``; these paths only use known storage roots.
|
||||
def compute_relative_filename(file_path: str) -> str | None:
|
||||
"""
|
||||
fp_abs = os.path.abspath(file_path)
|
||||
candidates: list[tuple[int, int, str, str]] = []
|
||||
|
||||
for order, (namespace, base) in enumerate(
|
||||
(
|
||||
("input", folder_paths.get_input_directory()),
|
||||
("output", folder_paths.get_output_directory()),
|
||||
("temp", folder_paths.get_temp_directory()),
|
||||
("models", getattr(folder_paths, "models_dir", "")),
|
||||
)
|
||||
):
|
||||
if not base:
|
||||
continue
|
||||
base_abs = os.path.abspath(base)
|
||||
if _is_relative_to(fp_abs, base_abs):
|
||||
candidates.append((len(base_abs), -order, namespace, base_abs))
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
_base_len, _order, namespace, base = max(candidates)
|
||||
rel = _compute_relative_path(fp_abs, base)
|
||||
public_path = f"{namespace}/{rel}" if rel else namespace
|
||||
return public_path, rel or None
|
||||
|
||||
|
||||
def compute_display_name(file_path: str) -> str | None:
|
||||
"""Return the asset's `display_name`, or None for unknown paths."""
|
||||
result = compute_asset_response_paths(file_path)
|
||||
return result[1] if result else None
|
||||
|
||||
|
||||
def compute_logical_path(file_path: str) -> str | None:
|
||||
"""Return the internal namespaced storage locator, or None for unknown paths."""
|
||||
result = compute_asset_response_paths(file_path)
|
||||
return result[0] if result else None
|
||||
|
||||
|
||||
def compute_loader_path(file_path: str) -> str | None:
|
||||
"""
|
||||
Return the asset's in-root loader path: the path relative to the last
|
||||
well-known folder (the model category), using forward slashes, eg:
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
using forward slashes, eg:
|
||||
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
|
||||
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
|
||||
|
||||
This is the value model loaders consume (the model category is dropped). It
|
||||
is persisted as ``AssetReference.loader_path`` and served as the public
|
||||
Asset response `loader_path` field. The human-facing `display_name` comes
|
||||
from compute_asset_response_paths().
|
||||
|
||||
For input/output/temp paths the full path relative to that root is returned.
|
||||
For paths outside any known root, returns None.
|
||||
For non-model paths, returns None.
|
||||
"""
|
||||
try:
|
||||
root_category, rel_path = get_asset_category_and_relative_path(file_path)
|
||||
@ -188,10 +116,9 @@ def get_asset_category_and_relative_path(
|
||||
def _compute_relative(child: str, parent: str) -> str:
|
||||
# Normalize relative path, stripping any leading ".." components
|
||||
# by anchoring to root (os.sep) then computing relpath back from it.
|
||||
rel = os.path.relpath(
|
||||
return os.path.relpath(
|
||||
os.path.join(os.sep, os.path.relpath(child, parent)), os.sep
|
||||
)
|
||||
return "" if rel == "." else rel.replace(os.sep, "/")
|
||||
|
||||
# 1) input
|
||||
input_base = os.path.abspath(folder_paths.get_input_directory())
|
||||
@ -210,7 +137,7 @@ def get_asset_category_and_relative_path(
|
||||
|
||||
# 4) models (check deepest matching base to avoid ambiguity)
|
||||
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
|
||||
for bucket, bases, _exts in get_comfy_models_folders():
|
||||
for bucket, bases in get_comfy_models_folders():
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
if not _check_is_within(fp_abs, base_abs):
|
||||
@ -222,112 +149,25 @@ def get_asset_category_and_relative_path(
|
||||
if best is not None:
|
||||
_, bucket, rel_inside = best
|
||||
combined = os.path.join(bucket, rel_inside)
|
||||
normalized = os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
||||
return "models", normalized.replace(os.sep, "/")
|
||||
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
||||
|
||||
raise ValueError(
|
||||
f"Path is not within input, output, temp, or configured model bases: {file_path}"
|
||||
)
|
||||
|
||||
|
||||
def get_backend_system_tags_from_path(path: str) -> list[str]:
|
||||
"""Return trusted backend tags derived from current filesystem facts.
|
||||
|
||||
The returned tags are only the backend-generated system tags: ``models``,
|
||||
``model_type:<folder_name>``, ``input``, ``output``, and ``temp``. Model
|
||||
type tags are based on registered folder names, not path components.
|
||||
|
||||
A ``model_type:<folder_name>`` tag is only emitted when the file's
|
||||
extension is accepted by that folder's registered extension set, so
|
||||
categories sharing a base directory (e.g. ``diffusion_models`` and a
|
||||
custom ``unet_gguf``) tag only the files they can actually load. Files
|
||||
under a model base whose extension matches no category still get the
|
||||
``models`` tag.
|
||||
"""
|
||||
fp_abs = os.path.abspath(path)
|
||||
fp_path = Path(fp_abs)
|
||||
tags: list[str] = []
|
||||
|
||||
def _add(tag: str) -> None:
|
||||
if tag not in tags:
|
||||
tags.append(tag)
|
||||
|
||||
for role, base in (
|
||||
("input", folder_paths.get_input_directory()),
|
||||
("output", folder_paths.get_output_directory()),
|
||||
("temp", folder_paths.get_temp_directory()),
|
||||
):
|
||||
if fp_path.is_relative_to(os.path.abspath(base)):
|
||||
_add(role)
|
||||
|
||||
ext = os.path.splitext(fp_abs)[1].lower()
|
||||
model_types: list[str] = []
|
||||
under_models_base = False
|
||||
for folder_name, bases, extensions in get_comfy_models_folders():
|
||||
for base in bases:
|
||||
if fp_path.is_relative_to(os.path.abspath(base)):
|
||||
under_models_base = True
|
||||
# Empty set accepts any extension, matching
|
||||
# folder_paths.filter_files_extensions semantics.
|
||||
if not extensions or ext in extensions:
|
||||
model_types.append(folder_name)
|
||||
break
|
||||
|
||||
if under_models_base:
|
||||
_add("models")
|
||||
for folder_name in model_types:
|
||||
_add(f"model_type:{folder_name}")
|
||||
|
||||
if not tags:
|
||||
raise ValueError(
|
||||
f"Path is not within input, output, temp, or configured model bases: {path}"
|
||||
)
|
||||
return tags
|
||||
|
||||
|
||||
def get_known_subfolder_tags(subfolder: str | None) -> list[str]:
|
||||
"""Return tags for known UI/input subfolder names."""
|
||||
if subfolder in _KNOWN_SUBFOLDER_TAGS:
|
||||
return [subfolder]
|
||||
return []
|
||||
|
||||
|
||||
def get_known_input_subfolder_tags_from_path(path: str) -> list[str]:
|
||||
"""Return known input-layout tags for files in canonical input subfolders.
|
||||
|
||||
These are compatibility tags for current UI-origin input directories such as
|
||||
``pasted`` and ``webcam``. They are intentionally narrow: only files directly
|
||||
inside a known top-level input directory receive the matching tag.
|
||||
"""
|
||||
fp_abs = os.path.abspath(path)
|
||||
input_base = os.path.abspath(folder_paths.get_input_directory())
|
||||
if not Path(fp_abs).is_relative_to(input_base):
|
||||
return []
|
||||
|
||||
rel = os.path.relpath(fp_abs, input_base)
|
||||
parts = Path(rel).parts
|
||||
if len(parts) == 2:
|
||||
return get_known_subfolder_tags(parts[0])
|
||||
return []
|
||||
|
||||
|
||||
def get_path_derived_tags_from_path(path: str) -> list[str]:
|
||||
"""Return all backend-derived tags for an asset path."""
|
||||
tags = get_backend_system_tags_from_path(path)
|
||||
for tag in get_known_input_subfolder_tags_from_path(path):
|
||||
if tag not in tags:
|
||||
tags.append(tag)
|
||||
return tags
|
||||
|
||||
|
||||
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
"""Return (name, tags) derived from a filesystem path.
|
||||
|
||||
- name: base filename with extension
|
||||
- tags: backend-derived tags from root/model classification and known input
|
||||
subfolder layout conventions
|
||||
- tags: [root_category] + parent folder names in order
|
||||
|
||||
Raises:
|
||||
ValueError: path does not belong to any known root.
|
||||
"""
|
||||
return Path(file_path).name, get_path_derived_tags_from_path(file_path)
|
||||
root_category, some_path = get_asset_category_and_relative_path(file_path)
|
||||
p = Path(some_path)
|
||||
parent_parts = [
|
||||
part for part in p.parent.parts if part not in (".", "..", p.anchor)
|
||||
]
|
||||
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
|
||||
|
||||
@ -25,7 +25,6 @@ class ReferenceData:
|
||||
preview_id: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
loader_path: str | None = None
|
||||
system_metadata: dict[str, Any] | None = None
|
||||
job_id: str | None = None
|
||||
last_access_time: datetime | None = None
|
||||
@ -94,7 +93,6 @@ def extract_reference_data(ref: AssetReference) -> ReferenceData:
|
||||
id=ref.id,
|
||||
name=ref.name,
|
||||
file_path=ref.file_path,
|
||||
loader_path=ref.loader_path,
|
||||
user_metadata=ref.user_metadata,
|
||||
preview_id=ref.preview_id,
|
||||
system_metadata=ref.system_metadata,
|
||||
|
||||
@ -35,17 +35,7 @@ class ModelFileManager:
|
||||
for folder in model_types:
|
||||
if folder in folder_black_list:
|
||||
continue
|
||||
# Effective display filter: the folder's registered extension
|
||||
# set, or the global supported_pt_extensions for match-all
|
||||
# folders (empty set), resolved live so runtime registrations
|
||||
# by custom nodes are reflected.
|
||||
registered = folder_paths.folder_names_and_paths[folder][1]
|
||||
effective = set(registered) if registered else set(folder_paths.supported_pt_extensions)
|
||||
output_folders.append({
|
||||
"name": folder,
|
||||
"folders": folder_paths.get_folder_paths(folder),
|
||||
"extensions": sorted(effective),
|
||||
})
|
||||
output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
|
||||
return web.json_response(output_folders)
|
||||
|
||||
# NOTE: This is an experiment to replace `/models/{folder}`
|
||||
|
||||
@ -1216,7 +1216,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
bias_dtype=input.dtype,
|
||||
offloadable=True,
|
||||
compute_dtype=compute_dtype,
|
||||
want_requant=True,
|
||||
want_requant=want_requant,
|
||||
)
|
||||
weight = weight.to(dtype=input.dtype)
|
||||
else:
|
||||
|
||||
@ -100,7 +100,6 @@ def _parse_cli_feature_flags() -> dict[str, Any]:
|
||||
# Default server capabilities
|
||||
_CORE_FEATURE_FLAGS: dict[str, Any] = {
|
||||
"supports_preview_metadata": True,
|
||||
"supports_model_type_tags": True,
|
||||
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
|
||||
"extension": {"manager": {"supports_v4": True}},
|
||||
"node_replacements": True,
|
||||
|
||||
@ -1261,6 +1261,158 @@ class DynamicSlot(ComfyTypeI):
|
||||
out_dict[input_type][finalized_id] = value
|
||||
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
|
||||
|
||||
@comfytype(io_type="COMFY_DYNAMICGROUP_V3")
|
||||
class DynamicGroup(ComfyTypeI):
|
||||
"""A repeatable group of widget inputs (e.g. lora_name + strength stacked into N rows).
|
||||
|
||||
At execution time the node receives a ``list[dict]`` where each element is a row.
|
||||
|
||||
Example::
|
||||
|
||||
io.DynamicGroup.Input(
|
||||
"loras",
|
||||
template=[
|
||||
io.Combo.Input("lora_name", options=folder_paths.get_filename_list("loras")),
|
||||
io.Float.Input("strength", default=1.0, min=-100, max=100, step=0.01),
|
||||
],
|
||||
min=0,
|
||||
max=50,
|
||||
)
|
||||
# execute receives: loras: list[dict] = [{"lora_name": "x.safetensors", "strength": 1.0}, ...]
|
||||
"""
|
||||
|
||||
Type = list[dict[str, Any]]
|
||||
_MaxRows = 100
|
||||
|
||||
class Input(DynamicInput):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
template: list["Input"],
|
||||
min: int = 0,
|
||||
max: int = 50,
|
||||
display_name: str = None,
|
||||
optional: bool = False,
|
||||
tooltip: str = None,
|
||||
lazy: bool = None,
|
||||
extra_dict=None,
|
||||
group_name: str = "Group",
|
||||
):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
# Validate template entries: only WidgetInput subclasses, no nesting
|
||||
assert len(template) > 0, "DynamicGroup template must have at least one field."
|
||||
for t in template:
|
||||
assert isinstance(t, WidgetInput), (
|
||||
f"DynamicGroup template field '{t.id}' must be a WidgetInput subclass "
|
||||
f"(Combo, Float, Int, String, Boolean, Color). Got {type(t).__name__}."
|
||||
)
|
||||
assert not isinstance(t, DynamicInput), (
|
||||
f"DynamicGroup template field '{t.id}' must not be a DynamicInput. "
|
||||
"Nesting dynamic inputs inside DynamicGroup is not supported."
|
||||
)
|
||||
# Enforce unique field ids within template
|
||||
field_ids = [t.id for t in template]
|
||||
assert len(field_ids) == len(set(field_ids)), (
|
||||
f"DynamicGroup template field ids must be unique within a row. Got: {field_ids}"
|
||||
)
|
||||
# Reject "." in group id and template field ids: slot_id encoding uses "." as a
|
||||
# delimiter (<group_id>.<row>.<field_id>), so any "." in these names would cause
|
||||
# path.split(".") to produce the wrong number of segments during decoding.
|
||||
assert "." not in id, (
|
||||
f"DynamicGroup id must not contain '.'. Got: '{id}'"
|
||||
)
|
||||
for t in template:
|
||||
assert "." not in t.id, (
|
||||
f"DynamicGroup template field id must not contain '.'. Got: '{t.id}'"
|
||||
)
|
||||
assert min >= 0, "DynamicGroup min must be >= 0."
|
||||
assert max >= 1, "DynamicGroup max must be >= 1."
|
||||
assert max <= DynamicGroup._MaxRows, f"DynamicGroup max must be <= {DynamicGroup._MaxRows}."
|
||||
assert min <= max, "DynamicGroup min must be <= max."
|
||||
self.template = template
|
||||
self.min = min
|
||||
self.max = max
|
||||
self.group_name = group_name
|
||||
|
||||
def get_all(self) -> list["Input"]:
|
||||
return [self] + list(self.template)
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"template": create_input_dict_v1(self.template),
|
||||
"min": self.min,
|
||||
"max": self.max,
|
||||
"group_name": self.group_name,
|
||||
})
|
||||
|
||||
def validate(self):
|
||||
for t in self.template:
|
||||
t.validate()
|
||||
|
||||
@staticmethod
|
||||
def _expand_schema_for_dynamic(
|
||||
out_dict: dict[str, Any],
|
||||
live_inputs: dict[str, Any],
|
||||
value: tuple[str, dict[str, Any]],
|
||||
input_type: str,
|
||||
curr_prefix: list[str] | None,
|
||||
):
|
||||
info = value[1]
|
||||
min_rows: int = info.get("min", 0)
|
||||
max_rows: int = info.get("max", DynamicGroup._MaxRows)
|
||||
template: dict[str, Any] = info.get("template", {})
|
||||
|
||||
# Collect all template field specs across required/optional sections
|
||||
field_specs: list[tuple[str, tuple[str, dict[str, Any]], bool]] = []
|
||||
for field_required_key in ("required", "optional"):
|
||||
section = template.get(field_required_key, {})
|
||||
is_required_field = field_required_key == "required"
|
||||
for field_id, field_value in section.items():
|
||||
field_specs.append((field_id, field_value, is_required_field))
|
||||
|
||||
# Determine how many rows are currently present by scanning live_inputs
|
||||
finalized_prefix = finalize_prefix(curr_prefix)
|
||||
present_rows = 0
|
||||
for live_key in live_inputs:
|
||||
# Keys look like "<prefix>.<row>.<field_id>"
|
||||
if live_key.startswith(finalized_prefix + "."):
|
||||
remainder = live_key[len(finalized_prefix) + 1:]
|
||||
parts = remainder.split(".", 1)
|
||||
if len(parts) >= 1:
|
||||
try:
|
||||
row_idx = int(parts[0])
|
||||
present_rows = max(present_rows, row_idx + 1)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if present_rows > max_rows:
|
||||
raise ValueError(
|
||||
f"DynamicGroup input '{finalized_prefix}' received {present_rows} rows but max is {max_rows}."
|
||||
)
|
||||
row_count = max(min_rows, present_rows)
|
||||
|
||||
for row in range(row_count):
|
||||
for field_id, field_value, is_required_field in field_specs:
|
||||
slot_id = f"{finalized_prefix}.{row}.{field_id}"
|
||||
# The first `min_rows` rows are required if the field itself is required
|
||||
if row < min_rows and is_required_field:
|
||||
out_dict["required"][slot_id] = field_value
|
||||
else:
|
||||
out_dict["optional"][slot_id] = field_value
|
||||
# Register into dynamic_paths so build_nested_inputs places value at the right path
|
||||
out_dict["dynamic_paths"][slot_id] = slot_id
|
||||
|
||||
# Track the list root path so build_nested_inputs can convert the index dict to a list
|
||||
out_dict.setdefault("list_paths", set()).add(finalized_prefix)
|
||||
|
||||
# Handle the empty case (0 rows) – emit an empty-list default for the parent.
|
||||
# This must only fire when there are genuinely no rows; otherwise the parent
|
||||
# path would clobber the per-row dict built from the slot ids above.
|
||||
if row_count == 0:
|
||||
out_dict["dynamic_paths"][finalized_prefix] = finalized_prefix
|
||||
out_dict["dynamic_paths_default_value"][finalized_prefix] = DynamicPathsDefaultValue.EMPTY_LIST
|
||||
|
||||
|
||||
@comfytype(io_type="IMAGECOMPARE")
|
||||
class ImageCompare(ComfyTypeI):
|
||||
Type = dict
|
||||
@ -1418,6 +1570,8 @@ def setup_dynamic_input_funcs():
|
||||
register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic)
|
||||
# DynamicSlot.Input
|
||||
register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic)
|
||||
# DynamicGroup.Input
|
||||
register_dynamic_input_func(DynamicGroup.io_type, DynamicGroup._expand_schema_for_dynamic)
|
||||
|
||||
if len(DYNAMIC_INPUT_LOOKUP) == 0:
|
||||
setup_dynamic_input_funcs()
|
||||
@ -1429,6 +1583,8 @@ class V3Data(TypedDict):
|
||||
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
|
||||
dynamic_paths_default_value: dict[str, Any]
|
||||
'Dictionary where the keys are the input ids and the values are a string from DynamicPathsDefaultValue for the inputs if value is None.'
|
||||
list_paths: set[str]
|
||||
'Set of top-level keys whose index-keyed dict values should be converted to a sorted list[dict] after build_nested_inputs runs.'
|
||||
create_dynamic_tuple: bool
|
||||
'When True, the value of the dynamic input will be in the format (value, path_key).'
|
||||
|
||||
@ -1770,6 +1926,7 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
|
||||
"optional": {},
|
||||
"dynamic_paths": {},
|
||||
"dynamic_paths_default_value": {},
|
||||
"list_paths": set(),
|
||||
}
|
||||
d = d.copy()
|
||||
# ignore hidden for parsing
|
||||
@ -1785,6 +1942,10 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
|
||||
dynamic_paths_default_value = out_dict.pop("dynamic_paths_default_value", None)
|
||||
if dynamic_paths_default_value is not None and len(dynamic_paths_default_value) > 0:
|
||||
v3_data["dynamic_paths_default_value"] = dynamic_paths_default_value
|
||||
# list_paths: keys whose nested dict should be post-converted to a sorted list[dict]
|
||||
list_paths = out_dict.pop("list_paths", None)
|
||||
if list_paths:
|
||||
v3_data["list_paths"] = list_paths
|
||||
return out_dict, hidden, v3_data
|
||||
|
||||
def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
|
||||
@ -1820,10 +1981,12 @@ def add_to_dict_v1(i: Input, d: dict):
|
||||
|
||||
class DynamicPathsDefaultValue:
|
||||
EMPTY_DICT = "empty_dict"
|
||||
EMPTY_LIST = "empty_list"
|
||||
|
||||
def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
paths = v3_data.get("dynamic_paths", None)
|
||||
default_value_dict = v3_data.get("dynamic_paths_default_value", {})
|
||||
list_paths: set[str] = v3_data.get("list_paths", set()) or set()
|
||||
if paths is None:
|
||||
return values
|
||||
values = values.copy()
|
||||
@ -1846,6 +2009,8 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
default_option = default_value_dict.get(key, None)
|
||||
if default_option == DynamicPathsDefaultValue.EMPTY_DICT:
|
||||
value = {}
|
||||
elif default_option == DynamicPathsDefaultValue.EMPTY_LIST:
|
||||
value = []
|
||||
if create_tuple:
|
||||
value = (value, key)
|
||||
current[p] = value
|
||||
@ -1853,6 +2018,34 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
current = current.setdefault(p, {})
|
||||
|
||||
values.update(result)
|
||||
|
||||
# Post-pass: convert index-keyed dicts to sorted lists for io.DynamicGroup fields
|
||||
for list_path in list_paths:
|
||||
parts = list_path.split(".")
|
||||
# Navigate to the parent container, then convert the leaf
|
||||
container = values
|
||||
for part in parts[:-1]:
|
||||
if not isinstance(container, dict) or part not in container:
|
||||
container = None
|
||||
break
|
||||
container = container[part]
|
||||
if container is None:
|
||||
continue
|
||||
leaf_key = parts[-1]
|
||||
leaf = container.get(leaf_key, None)
|
||||
if isinstance(leaf, dict):
|
||||
try:
|
||||
sorted_rows = [leaf[k] for k in sorted(leaf.keys(), key=int)]
|
||||
container[leaf_key] = sorted_rows
|
||||
except (ValueError, TypeError):
|
||||
# Keys are not all integers; leave as-is
|
||||
pass
|
||||
elif isinstance(leaf, list):
|
||||
# Already a list (e.g. the EMPTY_LIST default was applied above)
|
||||
pass
|
||||
elif leaf is None:
|
||||
container[leaf_key] = []
|
||||
|
||||
return values
|
||||
|
||||
|
||||
@ -2417,7 +2610,9 @@ __all__ = [
|
||||
# Dynamic Types
|
||||
"MatchType",
|
||||
"DynamicCombo",
|
||||
"DynamicSlot",
|
||||
"Autogrow",
|
||||
"DynamicGroup",
|
||||
# Other classes
|
||||
"HiddenHolder",
|
||||
"Hidden",
|
||||
|
||||
39
openapi.yaml
39
openapi.yaml
@ -7,18 +7,18 @@ components:
|
||||
description: Timestamp when the asset was created
|
||||
format: date-time
|
||||
type: string
|
||||
display_name:
|
||||
description: Display name of the asset. Mirrors name for backwards compatibility.
|
||||
nullable: true
|
||||
type: string
|
||||
file_path:
|
||||
description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors")
|
||||
nullable: true
|
||||
type: string
|
||||
hash:
|
||||
description: Blake3 hash of the asset content.
|
||||
pattern: ^blake3:[a-f0-9]{64}$
|
||||
type: string
|
||||
loader_path:
|
||||
description: 'In-root loader path for filesystem-backed assets: the path relative to its storage root with the top-level model category dropped (e.g. `models/checkpoints/foo/bar.safetensors` -> `foo/bar.safetensors`). This is the value model loaders consume. `None` when the file is not within a recognized root or model category.'
|
||||
nullable: true
|
||||
type: string
|
||||
display_name:
|
||||
description: 'Human-facing label for filesystem-backed assets: the path below the top-level storage namespace (e.g. `checkpoints/foo/bar.safetensors` under `models/`). Not unique.'
|
||||
nullable: true
|
||||
type: string
|
||||
id:
|
||||
description: Unique identifier for the asset
|
||||
format: uuid
|
||||
@ -144,6 +144,14 @@ components:
|
||||
AssetUpdated:
|
||||
description: Response returned when an existing asset is successfully updated.
|
||||
properties:
|
||||
display_name:
|
||||
description: Display name of the asset. Mirrors name for backwards compatibility.
|
||||
nullable: true
|
||||
type: string
|
||||
file_path:
|
||||
description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors")
|
||||
nullable: true
|
||||
type: string
|
||||
hash:
|
||||
description: Blake3 hash of the asset content.
|
||||
pattern: ^blake3:[a-f0-9]{64}$
|
||||
@ -775,14 +783,6 @@ components:
|
||||
ModelFolder:
|
||||
description: Represents a folder containing models
|
||||
properties:
|
||||
extensions:
|
||||
description: 'Effective file-extension display filter for this folder: the registered extension set, or the global supported model extensions for folders registered without one (match-all). Resolved live, so runtime registrations by custom nodes are reflected.'
|
||||
example:
|
||||
- .ckpt
|
||||
- .safetensors
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
folders:
|
||||
description: List of paths where models of this type are stored
|
||||
example:
|
||||
@ -1644,7 +1644,7 @@ paths:
|
||||
format: uuid
|
||||
type: string
|
||||
tags:
|
||||
description: JSON-encoded array of tag strings. For new byte uploads, include exactly one destination role (`input`, `output`, or `models`); `models` uploads also require exactly one `model_type:<folder_name>` tag. Extra tags are stored as labels and do not create path components.
|
||||
description: JSON-encoded array of freeform tag strings, e.g. '["models","checkpoint"]'. Common types include "models", "input", "output", and "temp", but any tag can be used in any order.
|
||||
type: string
|
||||
user_metadata:
|
||||
description: Custom JSON metadata as a string
|
||||
@ -1829,7 +1829,7 @@ paths:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Asset'
|
||||
$ref: '#/components/schemas/AssetUpdated'
|
||||
description: Asset updated successfully
|
||||
"400":
|
||||
content:
|
||||
@ -2470,9 +2470,6 @@ paths:
|
||||
supports_preview_metadata:
|
||||
description: Whether the server supports preview metadata
|
||||
type: boolean
|
||||
supports_model_type_tags:
|
||||
description: Whether the server supports namespaced model type asset tags
|
||||
type: boolean
|
||||
type: object
|
||||
description: Success
|
||||
headers:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
comfyui-frontend-package==1.45.19
|
||||
comfyui-workflow-templates==0.10.7
|
||||
comfyui-embedded-docs==0.5.6
|
||||
comfyui-embedded-docs==0.5.5
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
@ -22,7 +22,7 @@ alembic
|
||||
SQLAlchemy>=2.0.0
|
||||
filelock
|
||||
av>=16.0.0
|
||||
comfy-kitchen==0.2.14
|
||||
comfy-kitchen==0.2.13
|
||||
comfy-aimdo==0.4.10
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
|
||||
@ -46,7 +46,6 @@ from comfy_api.internal import _ComfyNodeInternal
|
||||
from app.assets.seeder import asset_seeder
|
||||
from app.assets.api.routes import register_assets_routes
|
||||
from app.assets.services.ingest import register_file_in_place
|
||||
from app.assets.services.path_utils import get_known_subfolder_tags
|
||||
from app.assets.services.asset_management import resolve_hash_to_path
|
||||
|
||||
from app.user_manager import UserManager
|
||||
@ -441,9 +440,7 @@ class PromptServer():
|
||||
if args.enable_assets:
|
||||
try:
|
||||
tag = image_upload_type if image_upload_type in ("input", "output") else "input"
|
||||
tags = [tag]
|
||||
tags.extend(get_known_subfolder_tags(subfolder))
|
||||
result = register_file_in_place(abs_path=filepath, name=filename, tags=tags)
|
||||
result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag])
|
||||
resp["asset"] = {
|
||||
"id": result.ref.id,
|
||||
"name": result.ref.name,
|
||||
|
||||
@ -24,28 +24,6 @@ def app(model_manager):
|
||||
app.add_routes(routes)
|
||||
return app
|
||||
|
||||
async def test_get_model_folders_includes_effective_extensions(aiohttp_client, app, tmp_path):
|
||||
"""Folders expose their effective display filter: the registered extension
|
||||
set, or the global supported_pt_extensions for match-all (empty) folders."""
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_checkpoints': ([str(tmp_path)], {'.safetensors', '.ckpt'}),
|
||||
'test_configs': ([str(tmp_path)], ['.yaml']),
|
||||
'test_match_all': ([str(tmp_path)], set()),
|
||||
'configs': ([str(tmp_path)], ['.yaml']),
|
||||
}), patch('folder_paths.supported_pt_extensions', {'.safetensors', '.bin'}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get('/experiment/models')
|
||||
|
||||
assert response.status == 200
|
||||
folders = {f['name']: f for f in await response.json()}
|
||||
|
||||
assert 'configs' not in folders # blocklisted
|
||||
assert folders['test_checkpoints']['folders'] == [str(tmp_path)]
|
||||
assert folders['test_checkpoints']['extensions'] == ['.ckpt', '.safetensors']
|
||||
assert folders['test_configs']['extensions'] == ['.yaml']
|
||||
# Match-all folders substitute the live global set.
|
||||
assert folders['test_match_all']['extensions'] == ['.bin', '.safetensors']
|
||||
|
||||
async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path):
|
||||
img = Image.new('RGB', (100, 100), 'white')
|
||||
img_byte_arr = BytesIO()
|
||||
|
||||
@ -8,7 +8,6 @@ upgrade/downgrade for 0003+.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sqlite3
|
||||
|
||||
import pytest
|
||||
from alembic import command
|
||||
@ -31,12 +30,6 @@ def _make_config(db_path: str) -> Config:
|
||||
return cfg
|
||||
|
||||
|
||||
def _sqlite_path(cfg: Config) -> str:
|
||||
url = cfg.get_main_option("sqlalchemy.url")
|
||||
assert url is not None and url.startswith("sqlite:///")
|
||||
return url.removeprefix("sqlite:///")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def migration_db(tmp_path):
|
||||
"""Yield an alembic Config pre-upgraded to the baseline revision."""
|
||||
@ -62,26 +55,3 @@ def test_upgrade_downgrade_cycle(migration_db):
|
||||
command.upgrade(migration_db, "head")
|
||||
command.downgrade(migration_db, _BASELINE)
|
||||
command.upgrade(migration_db, "head")
|
||||
|
||||
|
||||
def test_case_sensitive_tags_downgrade_normalizes_existing_tags(migration_db):
|
||||
"""Downgrading 0005 folds mixed-case tag vocabulary before restoring CHECK."""
|
||||
command.upgrade(migration_db, "0005_allow_case_sensitive_tags")
|
||||
|
||||
db_path = _sqlite_path(migration_db)
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
conn.execute("INSERT INTO tags(name) VALUES (?)", ("NewTag",))
|
||||
conn.execute("INSERT INTO tags(name) VALUES (?)", ("newtag",))
|
||||
conn.execute("INSERT INTO tags(name) VALUES (?)", ("model_type:LLM",))
|
||||
|
||||
command.downgrade(migration_db, "0004_drop_tag_type")
|
||||
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
tags = {row[0] for row in conn.execute("SELECT name FROM tags")}
|
||||
assert "newtag" in tags
|
||||
assert "model_type:llm" in tags
|
||||
assert "NewTag" not in tags
|
||||
assert "model_type:LLM" not in tags
|
||||
|
||||
with pytest.raises(sqlite3.IntegrityError):
|
||||
conn.execute("INSERT INTO tags(name) VALUES (?)", ("Upper",))
|
||||
|
||||
@ -234,7 +234,7 @@ def seeded_asset(request: pytest.FixtureRequest, http: requests.Session, api_bas
|
||||
p = getattr(request, "param", {}) or {}
|
||||
tags: Optional[list[str]] = p.get("tags")
|
||||
if tags is None:
|
||||
tags = ["models", "model_type:checkpoints", "unit-tests", "alpha"]
|
||||
tags = ["models", "checkpoints", "unit-tests", "alpha"]
|
||||
meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None}
|
||||
# Unique content per test so the seed always creates a fresh asset (201).
|
||||
# Delete is now always a soft delete, so content from a prior test survives
|
||||
|
||||
@ -133,66 +133,6 @@ class TestListReferencesPage:
|
||||
assert total == 1
|
||||
assert refs[0].name == "tagged"
|
||||
|
||||
def test_include_tags_filter_ands_persisted_model_tags(self, session: Session):
|
||||
asset = _make_asset(session, "hash-model-tags")
|
||||
checkpoint = _make_reference(session, asset, name="checkpoint")
|
||||
lora = _make_reference(session, asset, name="lora")
|
||||
input_ref = _make_reference(session, asset, name="input")
|
||||
ensure_tags_exist(
|
||||
session,
|
||||
["models", "model_type:checkpoints", "model_type:loras", "unit-tests"],
|
||||
)
|
||||
add_tags_to_reference(
|
||||
session,
|
||||
reference_id=checkpoint.id,
|
||||
tags=["models", "model_type:checkpoints", "unit-tests"],
|
||||
origin="automatic",
|
||||
)
|
||||
add_tags_to_reference(
|
||||
session,
|
||||
reference_id=lora.id,
|
||||
tags=["models", "model_type:loras", "unit-tests"],
|
||||
origin="automatic",
|
||||
)
|
||||
add_tags_to_reference(
|
||||
session,
|
||||
reference_id=input_ref.id,
|
||||
tags=["unit-tests"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
refs, _, total = list_references_page(
|
||||
session,
|
||||
include_tags=["models", "model_type:checkpoints", "unit-tests"],
|
||||
)
|
||||
|
||||
assert total == 1
|
||||
assert refs[0].id == checkpoint.id
|
||||
|
||||
def test_include_tags_filter_preserves_model_type_case(self, session: Session):
|
||||
asset = _make_asset(session, "hash-model-case")
|
||||
ref = _make_reference(session, asset, name="llm")
|
||||
ensure_tags_exist(session, ["models", "model_type:LLM"])
|
||||
add_tags_to_reference(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["models", "model_type:LLM"],
|
||||
origin="automatic",
|
||||
)
|
||||
session.commit()
|
||||
|
||||
refs, _, total = list_references_page(
|
||||
session, include_tags=["models", "model_type:LLM"]
|
||||
)
|
||||
refs_lower, _, total_lower = list_references_page(
|
||||
session, include_tags=["models", "model_type:llm"]
|
||||
)
|
||||
|
||||
assert total == 1
|
||||
assert refs[0].id == ref.id
|
||||
assert total_lower == 0
|
||||
assert refs_lower == []
|
||||
|
||||
def test_exclude_tags_filter(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(session, asset, name="keep")
|
||||
|
||||
@ -58,7 +58,7 @@ class TestEnsureTagsExist:
|
||||
session.commit()
|
||||
|
||||
tags = session.query(Tag).all()
|
||||
assert {t.name for t in tags} == {"ALPHA", "Beta", "alpha"}
|
||||
assert {t.name for t in tags} == {"alpha", "beta"}
|
||||
|
||||
def test_empty_list_is_noop(self, session: Session):
|
||||
ensure_tags_exist(session, [])
|
||||
@ -258,16 +258,6 @@ class TestListTagsWithUsage:
|
||||
tag_names = {name for name, _ in rows}
|
||||
assert tag_names == {"alpha", "alphabet"}
|
||||
|
||||
def test_prefix_filter_is_case_sensitive(self, session: Session):
|
||||
ensure_tags_exist(session, ["model_type:LLM", "model_type:llm"])
|
||||
session.commit()
|
||||
|
||||
rows, total = list_tags_with_usage(session, prefix="model_type:L")
|
||||
|
||||
tag_names = {name for name, _ in rows}
|
||||
assert tag_names == {"model_type:LLM"}
|
||||
assert total == 1
|
||||
|
||||
def test_order_by_name(self, session: Session):
|
||||
ensure_tags_exist(session, ["zebra", "alpha", "middle"])
|
||||
session.commit()
|
||||
|
||||
@ -1,82 +0,0 @@
|
||||
"""Tests for how _build_asset_response derives the response `loader_path`.
|
||||
|
||||
Guards the persist-and-read contract: the response reads the stored
|
||||
`loader_path` directly, and only recomputes when the column is NULL (rows
|
||||
written before the column existed).
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.assets.api.routes import _build_asset_response
|
||||
from app.assets.services.schemas import AssetDetailResult, ReferenceData
|
||||
|
||||
_TS = datetime(2024, 1, 1, 0, 0, 0)
|
||||
|
||||
|
||||
def _make_result(
|
||||
*, file_path: str | None, loader_path: str | None
|
||||
) -> AssetDetailResult:
|
||||
ref = ReferenceData(
|
||||
id="ref-1",
|
||||
name="model.safetensors",
|
||||
file_path=file_path,
|
||||
loader_path=loader_path,
|
||||
user_metadata=None,
|
||||
preview_id=None,
|
||||
created_at=_TS,
|
||||
updated_at=_TS,
|
||||
last_access_time=_TS,
|
||||
)
|
||||
return AssetDetailResult(ref=ref, asset=None, tags=[])
|
||||
|
||||
|
||||
def test_uses_persisted_loader_path_without_recomputing():
|
||||
"""A stored loader_path is returned verbatim, not re-derived from file_path.
|
||||
|
||||
The sentinel value could never be produced by compute_loader_path for this
|
||||
file_path, so seeing it in the response proves the stored column is read.
|
||||
"""
|
||||
result = _make_result(
|
||||
file_path="/unmatched/root/model.safetensors",
|
||||
loader_path="SENTINEL/stored.safetensors",
|
||||
)
|
||||
|
||||
resp = _build_asset_response(result)
|
||||
|
||||
assert resp.loader_path == "SENTINEL/stored.safetensors"
|
||||
|
||||
|
||||
def test_falls_back_to_compute_when_stored_loader_path_is_null(tmp_path: Path):
|
||||
"""A NULL column (pre-migration row) is backfilled at read time."""
|
||||
models = tmp_path / "models"
|
||||
ckpt = models / "checkpoints"
|
||||
ckpt.mkdir(parents=True)
|
||||
f = ckpt / "bar.safetensors"
|
||||
f.touch()
|
||||
|
||||
with patch("app.assets.services.path_utils.folder_paths") as mock_fp, patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[("checkpoints", [str(ckpt)], {".safetensors"})],
|
||||
):
|
||||
mock_fp.get_input_directory.return_value = str(tmp_path / "in")
|
||||
mock_fp.get_output_directory.return_value = str(tmp_path / "out")
|
||||
mock_fp.get_temp_directory.return_value = str(tmp_path / "tmp")
|
||||
mock_fp.models_dir = str(models)
|
||||
|
||||
result = _make_result(file_path=str(f), loader_path=None)
|
||||
resp = _build_asset_response(result)
|
||||
|
||||
assert resp.loader_path == "bar.safetensors"
|
||||
assert resp.display_name == "checkpoints/bar.safetensors"
|
||||
|
||||
|
||||
def test_all_path_fields_null_without_file_path():
|
||||
"""API-created / hash-only references (no file_path) expose no paths."""
|
||||
result = _make_result(file_path=None, loader_path=None)
|
||||
|
||||
resp = _build_asset_response(result)
|
||||
|
||||
assert resp.loader_path is None
|
||||
assert resp.display_name is None
|
||||
@ -1,14 +1,10 @@
|
||||
"""Tests for bulk ingest services."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference
|
||||
from app.assets.database.queries import get_reference_tags
|
||||
from app.assets.scanner import build_asset_specs
|
||||
from app.assets.services.bulk_ingest import SeedAssetSpec, batch_insert_seed_assets
|
||||
|
||||
|
||||
@ -105,184 +101,6 @@ class TestBatchInsertSeedAssets:
|
||||
asset = session.query(Asset).filter_by(id=ref.asset_id).first()
|
||||
assert asset.mime_type == expected_mime, f"Expected {expected_mime} for {filename}, got {asset.mime_type}"
|
||||
|
||||
def test_duplicate_paths_merge_tags_before_insert(
|
||||
self, session: Session, temp_dir: Path
|
||||
):
|
||||
"""Overlapping model-folder registrations can emit the same path twice."""
|
||||
file_path = temp_dir / "shared.safetensors"
|
||||
file_path.write_bytes(b"shared model")
|
||||
|
||||
specs: list[SeedAssetSpec] = [
|
||||
{
|
||||
"abs_path": str(file_path),
|
||||
"size_bytes": 12,
|
||||
"mtime_ns": 1234567890000000000,
|
||||
"info_name": "Shared Model",
|
||||
"tags": ["models", "model_type:checkpoints"],
|
||||
"fname": "shared.safetensors",
|
||||
"metadata": None,
|
||||
"hash": None,
|
||||
"mime_type": "application/safetensors",
|
||||
},
|
||||
{
|
||||
"abs_path": str(file_path),
|
||||
"size_bytes": 12,
|
||||
"mtime_ns": 1234567890000000000,
|
||||
"info_name": "Shared Model",
|
||||
"tags": ["models", "model_type:diffusion_models"],
|
||||
"fname": "shared.safetensors",
|
||||
"metadata": None,
|
||||
"hash": None,
|
||||
"mime_type": "application/safetensors",
|
||||
},
|
||||
]
|
||||
|
||||
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
|
||||
|
||||
assert result.inserted_refs == 1
|
||||
assert result.won_paths == 1
|
||||
refs = session.query(AssetReference).all()
|
||||
assert len(refs) == 1
|
||||
assert set(get_reference_tags(session, reference_id=refs[0].id)) == {
|
||||
"models",
|
||||
"model_type:checkpoints",
|
||||
"model_type:diffusion_models",
|
||||
}
|
||||
|
||||
def test_duplicate_paths_are_merged_after_abspath_normalization(
|
||||
self, session: Session, temp_dir: Path, monkeypatch
|
||||
):
|
||||
"""The scanner may emit equivalent paths with different spelling."""
|
||||
file_path = temp_dir / "same-file.safetensors"
|
||||
file_path.write_bytes(b"shared model")
|
||||
monkeypatch.chdir(temp_dir)
|
||||
relative_path = file_path.name
|
||||
absolute_path = os.path.abspath(relative_path)
|
||||
|
||||
specs: list[SeedAssetSpec] = [
|
||||
{
|
||||
"abs_path": relative_path,
|
||||
"size_bytes": 12,
|
||||
"mtime_ns": 1234567890000000000,
|
||||
"info_name": "Shared Model",
|
||||
"tags": ["models", "model_type:checkpoints"],
|
||||
"fname": "same-file.safetensors",
|
||||
"metadata": None,
|
||||
"hash": None,
|
||||
"mime_type": "application/safetensors",
|
||||
},
|
||||
{
|
||||
"abs_path": absolute_path,
|
||||
"size_bytes": 12,
|
||||
"mtime_ns": 1234567890000000000,
|
||||
"info_name": "Shared Model",
|
||||
"tags": ["models", "model_type:diffusion_models"],
|
||||
"fname": "same-file.safetensors",
|
||||
"metadata": None,
|
||||
"hash": None,
|
||||
"mime_type": "application/safetensors",
|
||||
},
|
||||
]
|
||||
|
||||
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
|
||||
|
||||
assert result.inserted_refs == 1
|
||||
assert result.won_paths == 1
|
||||
refs = session.query(AssetReference).all()
|
||||
assert len(refs) == 1
|
||||
assert refs[0].file_path == absolute_path
|
||||
# loader_path is persisted from the spec's fname (compute_loader_path).
|
||||
assert refs[0].loader_path == "same-file.safetensors"
|
||||
assert set(get_reference_tags(session, reference_id=refs[0].id)) == {
|
||||
"models",
|
||||
"model_type:checkpoints",
|
||||
"model_type:diffusion_models",
|
||||
}
|
||||
|
||||
def test_scanner_duplicate_shared_model_paths_keep_all_model_type_tags(
|
||||
self, session: Session, temp_dir: Path
|
||||
):
|
||||
"""Shared extra model roots make scanner collection emit duplicate paths."""
|
||||
shared_root = temp_dir / "shared"
|
||||
input_dir = temp_dir / "input"
|
||||
output_dir = temp_dir / "output"
|
||||
temp_root = temp_dir / "temp"
|
||||
for directory in (shared_root, input_dir, output_dir, temp_root):
|
||||
directory.mkdir()
|
||||
file_path = shared_root / "dual_use_model.safetensors"
|
||||
file_path.write_bytes(b"shared model")
|
||||
|
||||
with (
|
||||
patch("app.assets.services.path_utils.folder_paths") as mock_fp,
|
||||
patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[
|
||||
("checkpoints", [str(shared_root)], {".safetensors"}),
|
||||
("diffusion_models", [str(shared_root)], {".safetensors"}),
|
||||
],
|
||||
),
|
||||
):
|
||||
mock_fp.get_input_directory.return_value = str(input_dir)
|
||||
mock_fp.get_output_directory.return_value = str(output_dir)
|
||||
mock_fp.get_temp_directory.return_value = str(temp_root)
|
||||
|
||||
specs, tag_pool, skipped = build_asset_specs(
|
||||
paths=[str(file_path), str(file_path)],
|
||||
existing_paths=set(),
|
||||
enable_metadata_extraction=False,
|
||||
compute_hashes=False,
|
||||
)
|
||||
|
||||
assert skipped == 0
|
||||
assert len(specs) == 2
|
||||
assert tag_pool == {
|
||||
"models",
|
||||
"model_type:checkpoints",
|
||||
"model_type:diffusion_models",
|
||||
}
|
||||
|
||||
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
|
||||
|
||||
assert result.inserted_refs == 1
|
||||
assert result.won_paths == 1
|
||||
refs = session.query(AssetReference).all()
|
||||
assert len(refs) == 1
|
||||
assert set(get_reference_tags(session, reference_id=refs[0].id)) == {
|
||||
"models",
|
||||
"model_type:checkpoints",
|
||||
"model_type:diffusion_models",
|
||||
}
|
||||
|
||||
def test_loader_path_persisted_as_null_when_fname_is_none(
|
||||
self, session: Session, temp_dir: Path
|
||||
):
|
||||
"""A file with no in-root loader path (fname=None, e.g. an orphan under
|
||||
models_root) persists loader_path as NULL rather than a synthesized value."""
|
||||
file_path = temp_dir / "orphan.bin"
|
||||
file_path.write_bytes(b"x")
|
||||
|
||||
specs: list[SeedAssetSpec] = [
|
||||
{
|
||||
"abs_path": str(file_path),
|
||||
"size_bytes": 1,
|
||||
"mtime_ns": 1234567890000000000,
|
||||
"info_name": "orphan.bin",
|
||||
"tags": [],
|
||||
"fname": None,
|
||||
"metadata": None,
|
||||
"hash": None,
|
||||
"mime_type": None,
|
||||
}
|
||||
]
|
||||
|
||||
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
|
||||
|
||||
assert result.inserted_refs == 1
|
||||
refs = session.query(AssetReference).all()
|
||||
assert len(refs) == 1
|
||||
assert refs[0].file_path == str(file_path)
|
||||
assert refs[0].loader_path is None
|
||||
|
||||
|
||||
class TestMetadataExtraction:
|
||||
def test_extracts_mime_type_for_model_files(self, temp_dir: Path):
|
||||
|
||||
@ -94,47 +94,6 @@ class TestIngestFileFromPath:
|
||||
ref_tags = get_reference_tags(session, reference_id=result.reference_id)
|
||||
assert set(ref_tags) == {"models", "checkpoints"}
|
||||
|
||||
def test_path_derived_tags_use_automatic_origin(
|
||||
self, mock_create_session, temp_dir: Path, session: Session
|
||||
):
|
||||
input_dir = temp_dir / "input"
|
||||
output_dir = temp_dir / "output"
|
||||
temp_root = temp_dir / "temp"
|
||||
for directory in (input_dir, output_dir, temp_root):
|
||||
directory.mkdir()
|
||||
file_path = input_dir / "pasted" / "tagged.png"
|
||||
file_path.parent.mkdir()
|
||||
file_path.write_bytes(b"data")
|
||||
|
||||
with (
|
||||
patch("app.assets.services.path_utils.folder_paths") as mock_fp,
|
||||
patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
mock_fp.get_input_directory.return_value = str(input_dir)
|
||||
mock_fp.get_output_directory.return_value = str(output_dir)
|
||||
mock_fp.get_temp_directory.return_value = str(temp_root)
|
||||
|
||||
result = _ingest_file_from_path(
|
||||
abs_path=str(file_path),
|
||||
asset_hash="blake3:pathorigin",
|
||||
size_bytes=4,
|
||||
mtime_ns=1234567890000000000,
|
||||
info_name="Tagged Asset",
|
||||
tags=["input", "manual-label"],
|
||||
)
|
||||
|
||||
assert result.reference_id is not None
|
||||
links = session.query(AssetReferenceTag).filter_by(
|
||||
asset_reference_id=result.reference_id
|
||||
)
|
||||
origin_by_tag = {link.tag_name: link.origin for link in links}
|
||||
assert origin_by_tag["input"] == "automatic"
|
||||
assert origin_by_tag["pasted"] == "automatic"
|
||||
assert origin_by_tag["manual-label"] == "manual"
|
||||
|
||||
def test_idempotent_upsert(self, mock_create_session, temp_dir: Path, session: Session):
|
||||
file_path = temp_dir / "dup.bin"
|
||||
file_path.write_bytes(b"content")
|
||||
|
||||
@ -6,16 +6,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.assets.services.path_utils import (
|
||||
compute_display_name,
|
||||
compute_loader_path,
|
||||
compute_logical_path,
|
||||
get_asset_category_and_relative_path,
|
||||
get_known_input_subfolder_tags_from_path,
|
||||
get_known_subfolder_tags,
|
||||
get_name_and_tags_from_asset_path,
|
||||
resolve_destination_from_tags,
|
||||
)
|
||||
from app.assets.services.path_utils import get_asset_category_and_relative_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -26,8 +17,7 @@ def fake_dirs():
|
||||
input_dir = root_path / "input"
|
||||
output_dir = root_path / "output"
|
||||
temp_dir = root_path / "temp"
|
||||
models_root = root_path / "models"
|
||||
models_dir = models_root / "checkpoints"
|
||||
models_dir = root_path / "models" / "checkpoints"
|
||||
for d in (input_dir, output_dir, temp_dir, models_dir):
|
||||
d.mkdir(parents=True)
|
||||
|
||||
@ -35,17 +25,15 @@ def fake_dirs():
|
||||
mock_fp.get_input_directory.return_value = str(input_dir)
|
||||
mock_fp.get_output_directory.return_value = str(output_dir)
|
||||
mock_fp.get_temp_directory.return_value = str(temp_dir)
|
||||
mock_fp.models_dir = str(models_root)
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[("checkpoints", [str(models_dir)], {".safetensors"})],
|
||||
return_value=[("checkpoints", [str(models_dir)])],
|
||||
):
|
||||
yield {
|
||||
"input": input_dir,
|
||||
"output": output_dir,
|
||||
"temp": temp_dir,
|
||||
"models_root": models_root,
|
||||
"models": models_dir,
|
||||
}
|
||||
|
||||
@ -88,502 +76,6 @@ class TestGetAssetCategoryAndRelativePath:
|
||||
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||
assert cat == "models"
|
||||
|
||||
def test_model_path_tags_include_registered_model_type_only(self, fake_dirs):
|
||||
f = fake_dirs["models"] / "subdir" / "model.safetensors"
|
||||
f.parent.mkdir()
|
||||
f.touch()
|
||||
|
||||
_name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
|
||||
assert "models" in tags
|
||||
assert "model_type:checkpoints" in tags
|
||||
assert "checkpoints" not in tags
|
||||
assert "subdir" not in tags
|
||||
|
||||
def test_model_type_preserves_registered_folder_case(self, fake_dirs):
|
||||
llm_dir = fake_dirs["models"].parent / "LLM"
|
||||
llm_dir.mkdir()
|
||||
f = llm_dir / "model.safetensors"
|
||||
f.touch()
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[("LLM", [str(llm_dir)], {".safetensors"})],
|
||||
):
|
||||
_name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
|
||||
assert "models" in tags
|
||||
assert "model_type:LLM" in tags
|
||||
assert "model_type:llm" not in tags
|
||||
|
||||
def test_path_components_do_not_create_model_type_tags(self, fake_dirs):
|
||||
f = fake_dirs["models"] / "loras" / "model.safetensors"
|
||||
f.parent.mkdir()
|
||||
f.touch()
|
||||
|
||||
_name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
|
||||
assert "models" in tags
|
||||
assert "model_type:checkpoints" in tags
|
||||
assert "loras" not in tags
|
||||
assert "model_type:loras" not in tags
|
||||
|
||||
def test_shared_root_returns_all_matching_model_type_tags(self, fake_dirs):
|
||||
shared_root = fake_dirs["models"].parent / "shared"
|
||||
shared_root.mkdir()
|
||||
f = shared_root / "foo.safetensors"
|
||||
f.touch()
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[
|
||||
("checkpoints", [str(shared_root)], {".safetensors"}),
|
||||
("loras", [str(shared_root)], {".safetensors"}),
|
||||
],
|
||||
):
|
||||
_name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
|
||||
assert "models" in tags
|
||||
assert "model_type:checkpoints" in tags
|
||||
assert "model_type:loras" in tags
|
||||
|
||||
def test_shared_root_model_type_tags_respect_bucket_extensions(self, fake_dirs):
|
||||
"""Buckets sharing a base dir only tag files matching their extensions."""
|
||||
shared_root = fake_dirs["models"].parent / "unet"
|
||||
shared_root.mkdir()
|
||||
safetensors_file = shared_root / "wan.safetensors"
|
||||
gguf_file = shared_root / "wan.gguf"
|
||||
safetensors_file.touch()
|
||||
gguf_file.touch()
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[
|
||||
("diffusion_models", [str(shared_root)], {".safetensors"}),
|
||||
("unet_gguf", [str(shared_root)], {".gguf"}),
|
||||
],
|
||||
):
|
||||
_name, safetensors_tags = get_name_and_tags_from_asset_path(str(safetensors_file))
|
||||
_name, gguf_tags = get_name_and_tags_from_asset_path(str(gguf_file))
|
||||
|
||||
assert "model_type:diffusion_models" in safetensors_tags
|
||||
assert "model_type:unet_gguf" not in safetensors_tags
|
||||
assert "model_type:unet_gguf" in gguf_tags
|
||||
assert "model_type:diffusion_models" not in gguf_tags
|
||||
|
||||
def test_empty_extension_set_tags_any_extension(self, fake_dirs):
|
||||
"""Custom buckets registered without extensions accept every file."""
|
||||
custom_root = fake_dirs["models"].parent / "custom_bucket"
|
||||
custom_root.mkdir()
|
||||
f = custom_root / "weights.bin"
|
||||
f.touch()
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[("custom_bucket", [str(custom_root)], set())],
|
||||
):
|
||||
_name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
|
||||
assert "models" in tags
|
||||
assert "model_type:custom_bucket" in tags
|
||||
|
||||
def test_no_extension_match_keeps_models_tag_without_model_type(self, fake_dirs):
|
||||
f = fake_dirs["models"] / "notes.txt"
|
||||
f.touch()
|
||||
|
||||
_name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
|
||||
assert "models" in tags
|
||||
assert not any(tag.startswith("model_type:") for tag in tags)
|
||||
|
||||
def test_output_backed_registered_folder_gets_model_and_output_tags(self, fake_dirs):
|
||||
output_checkpoints_dir = fake_dirs["output"] / "checkpoints"
|
||||
output_checkpoints_dir.mkdir()
|
||||
f = output_checkpoints_dir / "saved.safetensors"
|
||||
f.touch()
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[("checkpoints", [str(output_checkpoints_dir)], {".safetensors"})],
|
||||
):
|
||||
_name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
|
||||
assert "models" in tags
|
||||
assert "model_type:checkpoints" in tags
|
||||
assert "output" in tags
|
||||
|
||||
def test_temp_path_tags_include_temp_not_output_or_preview(self, fake_dirs):
|
||||
f = fake_dirs["temp"] / "preview.png"
|
||||
f.touch()
|
||||
|
||||
_name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
|
||||
assert "temp" in tags
|
||||
assert "output" not in tags
|
||||
assert "preview:true" not in tags
|
||||
|
||||
def test_known_subfolder_tags_are_centralized(self):
|
||||
assert get_known_subfolder_tags("pasted") == ["pasted"]
|
||||
assert get_known_subfolder_tags("arbitrary") == []
|
||||
|
||||
def test_known_input_subfolder_tags_are_path_derived_for_direct_children(self, fake_dirs):
|
||||
f = fake_dirs["input"] / "pasted" / "image.png"
|
||||
f.parent.mkdir()
|
||||
f.touch()
|
||||
|
||||
assert get_known_input_subfolder_tags_from_path(str(f)) == ["pasted"]
|
||||
|
||||
_name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert "input" in tags
|
||||
assert "pasted" in tags
|
||||
|
||||
def test_known_input_subfolder_tags_do_not_apply_to_nested_or_other_roots(self, fake_dirs):
|
||||
nested = fake_dirs["input"] / "pasted" / "session" / "image.png"
|
||||
output = fake_dirs["output"] / "pasted" / "image.png"
|
||||
for path in (nested, output):
|
||||
path.parent.mkdir(parents=True)
|
||||
path.touch()
|
||||
|
||||
assert get_known_input_subfolder_tags_from_path(str(nested)) == []
|
||||
assert get_known_input_subfolder_tags_from_path(str(output)) == []
|
||||
|
||||
def test_unknown_path_raises(self, fake_dirs):
|
||||
with pytest.raises(ValueError, match="not within"):
|
||||
get_asset_category_and_relative_path("/some/random/path.png")
|
||||
|
||||
|
||||
class TestResponseStoragePaths:
|
||||
def test_input_file_path_and_display_name_include_subfolder(self, fake_dirs):
|
||||
sub = fake_dirs["input"] / "some" / "folder"
|
||||
sub.mkdir(parents=True)
|
||||
f = sub / "image.png"
|
||||
f.touch()
|
||||
|
||||
assert compute_logical_path(str(f)) == "input/some/folder/image.png"
|
||||
assert compute_display_name(str(f)) == "some/folder/image.png"
|
||||
|
||||
def test_output_file_path_and_display_name_include_subfolder(self, fake_dirs):
|
||||
sub = fake_dirs["output"] / "renders"
|
||||
sub.mkdir()
|
||||
f = sub / "ComfyUI_00001_.png"
|
||||
f.touch()
|
||||
|
||||
assert compute_logical_path(str(f)) == "output/renders/ComfyUI_00001_.png"
|
||||
assert compute_display_name(str(f)) == "renders/ComfyUI_00001_.png"
|
||||
|
||||
def test_temp_file_path_and_display_name(self, fake_dirs):
|
||||
f = fake_dirs["temp"] / "preview.png"
|
||||
f.touch()
|
||||
|
||||
assert compute_logical_path(str(f)) == "temp/preview.png"
|
||||
assert compute_display_name(str(f)) == "preview.png"
|
||||
|
||||
def test_exact_storage_root_has_no_display_name(self, fake_dirs):
|
||||
assert compute_logical_path(str(fake_dirs["input"])) == "input"
|
||||
assert compute_display_name(str(fake_dirs["input"])) is None
|
||||
|
||||
def test_longest_matching_builtin_root_wins(self, fake_dirs, tmp_path: Path):
|
||||
nested_output = fake_dirs["input"] / "nested-output"
|
||||
nested_output.mkdir()
|
||||
f = nested_output / "image.png"
|
||||
f.touch()
|
||||
|
||||
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
|
||||
mock_fp.get_input_directory.return_value = str(fake_dirs["input"])
|
||||
mock_fp.get_output_directory.return_value = str(nested_output)
|
||||
mock_fp.get_temp_directory.return_value = str(tmp_path / "temp")
|
||||
mock_fp.models_dir = str(fake_dirs["models_root"])
|
||||
|
||||
assert compute_logical_path(str(f)) == "output/image.png"
|
||||
assert compute_display_name(str(f)) == "image.png"
|
||||
|
||||
def test_model_file_path_is_relative_to_physical_models_root(self, fake_dirs):
|
||||
sub = fake_dirs["models"] / "flux"
|
||||
sub.mkdir()
|
||||
f = sub / "model.safetensors"
|
||||
f.touch()
|
||||
|
||||
assert compute_logical_path(str(f)) == "models/checkpoints/flux/model.safetensors"
|
||||
assert compute_display_name(str(f)) == "checkpoints/flux/model.safetensors"
|
||||
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "model.safetensors"
|
||||
assert "models" in tags
|
||||
assert "model_type:checkpoints" in tags
|
||||
assert "checkpoints" not in tags
|
||||
assert "flux" not in tags
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"folder_name",
|
||||
["checkpoints", "clip", "vae", "diffusion_models", "loras"],
|
||||
)
|
||||
def test_output_model_folder_uses_output_storage_file_path(self, fake_dirs, folder_name):
|
||||
output_model_dir = fake_dirs["output"] / folder_name
|
||||
output_model_dir.mkdir(exist_ok=True)
|
||||
default_model_dir = fake_dirs["models_root"] / folder_name
|
||||
default_model_dir.mkdir(exist_ok=True)
|
||||
f = output_model_dir / "saved.safetensors"
|
||||
f.touch()
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[
|
||||
(folder_name, [str(default_model_dir), str(output_model_dir)], {".safetensors"})
|
||||
],
|
||||
):
|
||||
assert compute_logical_path(str(f)) == f"output/{folder_name}/saved.safetensors"
|
||||
assert compute_display_name(str(f)) == f"{folder_name}/saved.safetensors"
|
||||
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "saved.safetensors"
|
||||
assert "output" in tags
|
||||
assert "models" in tags
|
||||
assert f"model_type:{folder_name}" in tags
|
||||
assert folder_name not in tags
|
||||
|
||||
def test_output_model_subfolder_uses_output_storage_file_path(self, fake_dirs):
|
||||
folder_name = "loras"
|
||||
output_model_dir = fake_dirs["output"] / folder_name
|
||||
subdir = output_model_dir / "experiments"
|
||||
subdir.mkdir(parents=True)
|
||||
f = subdir / "my_lora.safetensors"
|
||||
f.touch()
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[(folder_name, [str(output_model_dir)], {".safetensors"})],
|
||||
):
|
||||
assert (
|
||||
compute_logical_path(str(f))
|
||||
== "output/loras/experiments/my_lora.safetensors"
|
||||
)
|
||||
assert compute_display_name(str(f)) == "loras/experiments/my_lora.safetensors"
|
||||
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "my_lora.safetensors"
|
||||
assert "output" in tags
|
||||
assert "models" in tags
|
||||
assert "model_type:loras" in tags
|
||||
assert "loras" not in tags
|
||||
assert "experiments" not in tags
|
||||
|
||||
def test_external_model_folder_without_provenance_has_no_file_path(self, tmp_path: Path):
|
||||
external_checkpoints_dir = tmp_path / "external" / "not_named_like_category"
|
||||
external_checkpoints_dir.mkdir(parents=True)
|
||||
f = external_checkpoints_dir / "external.safetensors"
|
||||
f.touch()
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[("checkpoints", [str(external_checkpoints_dir)], {".safetensors"})],
|
||||
):
|
||||
assert compute_logical_path(str(f)) is None
|
||||
assert compute_display_name(str(f)) is None
|
||||
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "external.safetensors"
|
||||
assert "models" in tags
|
||||
assert "model_type:checkpoints" in tags
|
||||
|
||||
def test_same_relative_model_file_under_multiple_external_roots_has_no_storage_file_path(
|
||||
self, tmp_path: Path
|
||||
):
|
||||
foo_dir = tmp_path / "foo"
|
||||
bar_dir = tmp_path / "bar"
|
||||
foo_dir.mkdir()
|
||||
bar_dir.mkdir()
|
||||
foo_file = foo_dir / "baz.safetensors"
|
||||
bar_file = bar_dir / "baz.safetensors"
|
||||
foo_file.touch()
|
||||
bar_file.touch()
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[("checkpoints", [str(foo_dir), str(bar_dir)], {".safetensors"})],
|
||||
):
|
||||
assert compute_logical_path(str(foo_file)) is None
|
||||
assert compute_logical_path(str(bar_file)) is None
|
||||
assert compute_display_name(str(foo_file)) is None
|
||||
assert compute_display_name(str(bar_file)) is None
|
||||
|
||||
def test_output_clip_folder_uses_output_storage_and_text_encoder_tag(self, fake_dirs):
|
||||
output_clip_dir = fake_dirs["output"] / "clip"
|
||||
output_clip_dir.mkdir()
|
||||
f = output_clip_dir / "clip_l.safetensors"
|
||||
f.touch()
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[("text_encoders", [str(output_clip_dir)], {".safetensors"})],
|
||||
):
|
||||
assert compute_logical_path(str(f)) == "output/clip/clip_l.safetensors"
|
||||
assert compute_display_name(str(f)) == "clip/clip_l.safetensors"
|
||||
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "clip_l.safetensors"
|
||||
assert "output" in tags
|
||||
assert "models" in tags
|
||||
assert "model_type:text_encoders" in tags
|
||||
assert "clip" not in tags
|
||||
|
||||
def test_physical_unet_folder_uses_storage_path_and_diffusion_models_tag(self, fake_dirs):
|
||||
unet_dir = fake_dirs["models_root"] / "unet"
|
||||
diffusion_models_dir = fake_dirs["models_root"] / "diffusion_models"
|
||||
unet_dir.mkdir()
|
||||
diffusion_models_dir.mkdir()
|
||||
f = unet_dir / "wan.safetensors"
|
||||
f.touch()
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[
|
||||
("diffusion_models", [str(unet_dir), str(diffusion_models_dir)], {".safetensors"})
|
||||
],
|
||||
):
|
||||
assert compute_logical_path(str(f)) == "models/unet/wan.safetensors"
|
||||
assert compute_display_name(str(f)) == "unet/wan.safetensors"
|
||||
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "wan.safetensors"
|
||||
assert "models" in tags
|
||||
assert "model_type:diffusion_models" in tags
|
||||
assert "unet" not in tags
|
||||
|
||||
def test_unregistered_file_under_physical_models_root_still_has_storage_file_path(self, fake_dirs):
|
||||
f = fake_dirs["models_root"] / "not_registered" / "orphan.bin"
|
||||
f.parent.mkdir()
|
||||
f.touch()
|
||||
|
||||
assert compute_logical_path(str(f)) == "models/not_registered/orphan.bin"
|
||||
assert compute_display_name(str(f)) == "not_registered/orphan.bin"
|
||||
|
||||
def test_output_checkpoint_folder_without_registration_has_only_output_tag(self, fake_dirs):
|
||||
f = fake_dirs["output"] / "checkpoints" / "saved.safetensors"
|
||||
f.parent.mkdir(exist_ok=True)
|
||||
f.touch()
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[],
|
||||
):
|
||||
assert compute_logical_path(str(f)) == "output/checkpoints/saved.safetensors"
|
||||
assert compute_display_name(str(f)) == "checkpoints/saved.safetensors"
|
||||
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "saved.safetensors"
|
||||
assert "output" in tags
|
||||
assert "models" not in tags
|
||||
assert not any(tag.startswith("model_type:") for tag in tags)
|
||||
|
||||
def test_unknown_path_returns_none(self):
|
||||
assert compute_logical_path("/some/random/path.png") is None
|
||||
assert compute_display_name("/some/random/path.png") is None
|
||||
|
||||
|
||||
class TestLoaderPath:
|
||||
"""In-root loader path: relative to the storage root, model category dropped."""
|
||||
|
||||
def test_model_loader_path_drops_category(self, fake_dirs):
|
||||
sub = fake_dirs["models"] / "flux"
|
||||
sub.mkdir()
|
||||
f = sub / "model.safetensors"
|
||||
f.touch()
|
||||
|
||||
# logical_path keeps the category, file_path (loader) drops it
|
||||
assert compute_logical_path(str(f)) == "models/checkpoints/flux/model.safetensors"
|
||||
assert compute_loader_path(str(f)) == "flux/model.safetensors"
|
||||
|
||||
def test_model_loader_path_flat_file(self, fake_dirs):
|
||||
f = fake_dirs["models"] / "model.safetensors"
|
||||
f.touch()
|
||||
|
||||
assert compute_loader_path(str(f)) == "model.safetensors"
|
||||
|
||||
def test_input_loader_path_keeps_subfolders(self, fake_dirs):
|
||||
sub = fake_dirs["input"] / "some" / "folder"
|
||||
sub.mkdir(parents=True)
|
||||
f = sub / "image.png"
|
||||
f.touch()
|
||||
|
||||
assert compute_loader_path(str(f)) == "some/folder/image.png"
|
||||
|
||||
def test_temp_loader_path(self, fake_dirs):
|
||||
f = fake_dirs["temp"] / "preview.png"
|
||||
f.touch()
|
||||
|
||||
assert compute_loader_path(str(f)) == "preview.png"
|
||||
|
||||
def test_unregistered_file_under_models_root_has_no_loader_path(self, fake_dirs):
|
||||
# Under models_root but not within any registered category base.
|
||||
f = fake_dirs["models_root"] / "not_registered" / "orphan.bin"
|
||||
f.parent.mkdir()
|
||||
f.touch()
|
||||
|
||||
# It still has a namespaced logical_path, but no loader path.
|
||||
assert compute_logical_path(str(f)) == "models/not_registered/orphan.bin"
|
||||
assert compute_loader_path(str(f)) is None
|
||||
|
||||
def test_extra_path_model_has_loader_path_but_no_logical_path(self, tmp_path: Path):
|
||||
"""Registered category base outside models_dir (extra_model_paths style).
|
||||
|
||||
Loadable, so loader_path resolves; but it is not under any canonical
|
||||
storage root, so logical_path/display_name are None. This asymmetry is
|
||||
intentional: loader_path resolves every registered model-folder base,
|
||||
logical_path only resolves the canonical storage roots.
|
||||
"""
|
||||
extra = tmp_path / "extra_ckpts"
|
||||
extra.mkdir()
|
||||
f = extra / "foo.safetensors"
|
||||
f.touch()
|
||||
|
||||
with patch("app.assets.services.path_utils.folder_paths") as mock_fp, patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[("checkpoints", [str(extra)], {".safetensors"})],
|
||||
):
|
||||
mock_fp.get_input_directory.return_value = str(tmp_path / "in")
|
||||
mock_fp.get_output_directory.return_value = str(tmp_path / "out")
|
||||
mock_fp.get_temp_directory.return_value = str(tmp_path / "tmp")
|
||||
mock_fp.models_dir = str(tmp_path / "models") # extra is NOT under this
|
||||
|
||||
assert compute_loader_path(str(f)) == "foo.safetensors"
|
||||
assert compute_logical_path(str(f)) is None
|
||||
assert compute_display_name(str(f)) is None
|
||||
|
||||
def test_unknown_path_returns_none(self):
|
||||
assert compute_loader_path("/some/random/path.png") is None
|
||||
|
||||
|
||||
class TestResolveDestinationFromTags:
|
||||
def test_extra_tags_are_not_path_components(self, fake_dirs):
|
||||
base_dir, subdirs = resolve_destination_from_tags(["input", "unit-tests", "foo"])
|
||||
|
||||
assert base_dir == os.path.abspath(fake_dirs["input"])
|
||||
assert subdirs == []
|
||||
|
||||
def test_model_upload_rejects_non_writable_registered_folders(self):
|
||||
with tempfile.TemporaryDirectory() as root:
|
||||
root_path = Path(root)
|
||||
checkpoints_dir = root_path / "models" / "checkpoints"
|
||||
configs_dir = root_path / "models" / "configs"
|
||||
custom_nodes_dir = root_path / "custom_nodes"
|
||||
for path in (checkpoints_dir, configs_dir, custom_nodes_dir):
|
||||
path.mkdir(parents=True)
|
||||
|
||||
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
|
||||
mock_fp.folder_names_and_paths = {
|
||||
"checkpoints": ([str(checkpoints_dir)], set()),
|
||||
"configs": ([str(configs_dir)], set()),
|
||||
"custom_nodes": ([str(custom_nodes_dir)], set()),
|
||||
}
|
||||
|
||||
base_dir, subdirs = resolve_destination_from_tags(
|
||||
["models", "model_type:checkpoints"]
|
||||
)
|
||||
assert base_dir == os.path.abspath(checkpoints_dir)
|
||||
assert subdirs == []
|
||||
|
||||
for folder_name in ("configs", "custom_nodes"):
|
||||
with pytest.raises(ValueError, match="unknown model category"):
|
||||
resolve_destination_from_tags(
|
||||
["models", f"model_type:{folder_name}"]
|
||||
)
|
||||
|
||||
@ -19,8 +19,7 @@ def test_seed_asset_removed_when_file_is_deleted(
|
||||
"""Asset without hash (seed) whose file disappears:
|
||||
after triggering sync_seed_assets, Asset + AssetInfo disappear.
|
||||
"""
|
||||
# Create a file directly under input/unit-tests/<case>. Backend tags only
|
||||
# classify the root; nested path components are not exposed as tags.
|
||||
# Create a file directly under input/unit-tests/<case> so tags include "unit-tests"
|
||||
case_dir = comfy_tmp_base_dir / root / "unit-tests" / "syncseed"
|
||||
case_dir.mkdir(parents=True, exist_ok=True)
|
||||
name = f"seed_{uuid.uuid4().hex[:8]}.bin"
|
||||
@ -33,7 +32,7 @@ def test_seed_asset_removed_when_file_is_deleted(
|
||||
# Verify it is visible via API and carries no hash (seed)
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": root, "name_contains": name},
|
||||
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
|
||||
timeout=120,
|
||||
)
|
||||
body1 = r1.json()
|
||||
@ -55,7 +54,7 @@ def test_seed_asset_removed_when_file_is_deleted(
|
||||
# It should disappear (AssetInfo and seed Asset gone)
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": root, "name_contains": name},
|
||||
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
|
||||
timeout=120,
|
||||
)
|
||||
body2 = r2.json()
|
||||
@ -133,7 +132,7 @@ def test_hashed_asset_two_asset_infos_both_get_missing(
|
||||
second_id = b2["id"]
|
||||
|
||||
# Remove the single underlying file
|
||||
p = comfy_tmp_base_dir / "input" / get_asset_filename(created["asset_hash"], ".png")
|
||||
p = comfy_tmp_base_dir / "input" / "unit-tests" / "multiinfo" / get_asset_filename(b2["asset_hash"], ".png")
|
||||
assert p.exists()
|
||||
p.unlink()
|
||||
|
||||
@ -251,7 +250,8 @@ def test_missing_tag_clears_on_fastpass_when_mtime_and_size_match(
|
||||
|
||||
a = asset_factory(name, [root, "unit-tests", scope], {}, data)
|
||||
aid = a["id"]
|
||||
p = comfy_tmp_base_dir / root / get_asset_filename(a["asset_hash"], ".bin")
|
||||
base = comfy_tmp_base_dir / root / "unit-tests" / scope
|
||||
p = base / get_asset_filename(a["asset_hash"], ".bin")
|
||||
st0 = p.stat()
|
||||
orig_mtime_ns = getattr(st0, "st_mtime_ns", int(st0.st_mtime * 1_000_000_000))
|
||||
|
||||
|
||||
@ -290,7 +290,7 @@ def test_metadata_filename_is_set_for_seed_asset_without_hash(
|
||||
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": root, "name_contains": name},
|
||||
params={"include_tags": f"unit-tests,{scope}", "name_contains": name},
|
||||
timeout=120,
|
||||
)
|
||||
body = r1.json()
|
||||
|
||||
@ -95,7 +95,7 @@ def test_download_chooses_existing_state_and_updates_access_time(
|
||||
assert t1 > t0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "model_type:checkpoints"]}], indirect=True)
|
||||
@pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "checkpoints"]}], indirect=True)
|
||||
def test_download_missing_file_returns_404(
|
||||
http: requests.Session, api_base: str, comfy_tmp_base_dir: Path, seeded_asset: dict
|
||||
):
|
||||
|
||||
@ -13,7 +13,7 @@ def _seed(asset_factory, make_asset_bytes, count: int, tag: str) -> list[str]:
|
||||
for n in names:
|
||||
asset_factory(
|
||||
n,
|
||||
["models", "model_type:checkpoints", "unit-tests", tag],
|
||||
["models", "checkpoints", "unit-tests", tag],
|
||||
{},
|
||||
make_asset_bytes(n, size=2048),
|
||||
)
|
||||
@ -208,7 +208,7 @@ def test_cursor_walks_for_non_name_sorts(sort_field, http: requests.Session, api
|
||||
names = []
|
||||
for i in range(4):
|
||||
n = f"cursor_{sort_field}_{i:02d}.safetensors"
|
||||
asset_factory(n, ["models", "model_type:checkpoints", "unit-tests", f"cursor-{sort_field}"], {}, make_asset_bytes(n, size=2048 + i))
|
||||
asset_factory(n, ["models", "checkpoints", "unit-tests", f"cursor-{sort_field}"], {}, make_asset_bytes(n, size=2048 + i))
|
||||
names.append(n)
|
||||
|
||||
params = {
|
||||
|
||||
@ -11,7 +11,7 @@ def test_list_assets_paging_and_sort(http: requests.Session, api_base: str, asse
|
||||
for n in names:
|
||||
asset_factory(
|
||||
n,
|
||||
["models", "model_type:checkpoints", "unit-tests", "paging"],
|
||||
["models", "checkpoints", "unit-tests", "paging"],
|
||||
{"epoch": 1},
|
||||
make_asset_bytes(n, size=2048),
|
||||
)
|
||||
@ -45,8 +45,8 @@ def test_list_assets_paging_and_sort(http: requests.Session, api_base: str, asse
|
||||
|
||||
|
||||
def test_list_assets_include_exclude_and_name_contains(http: requests.Session, api_base: str, asset_factory):
|
||||
a = asset_factory("inc_a.safetensors", ["models", "model_type:checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024)
|
||||
b = asset_factory("inc_b.safetensors", ["models", "model_type:checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024)
|
||||
a = asset_factory("inc_a.safetensors", ["models", "checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024)
|
||||
b = asset_factory("inc_b.safetensors", ["models", "checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024)
|
||||
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
@ -81,7 +81,7 @@ def test_list_assets_include_exclude_and_name_contains(http: requests.Session, a
|
||||
|
||||
|
||||
def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "model_type:checkpoints", "unit-tests", "lf-size"]
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-size"]
|
||||
n1, n2, n3 = "sz1.safetensors", "sz2.safetensors", "sz3.safetensors"
|
||||
asset_factory(n1, t, {}, make_asset_bytes(n1, 1024))
|
||||
asset_factory(n2, t, {}, make_asset_bytes(n2, 2048))
|
||||
@ -108,7 +108,7 @@ def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, mak
|
||||
|
||||
|
||||
def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "model_type:checkpoints", "unit-tests", "lf-upd"]
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-upd"]
|
||||
a1 = asset_factory("upd_a.safetensors", t, {}, make_asset_bytes("upd_a", 1200))
|
||||
a2 = asset_factory("upd_b.safetensors", t, {}, make_asset_bytes("upd_b", 1200))
|
||||
|
||||
@ -131,7 +131,7 @@ def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make
|
||||
|
||||
|
||||
def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "model_type:checkpoints", "unit-tests", "lf-access"]
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-access"]
|
||||
asset_factory("acc_a.safetensors", t, {}, make_asset_bytes("acc_a", 1100))
|
||||
time.sleep(0.02)
|
||||
a2 = asset_factory("acc_b.safetensors", t, {}, make_asset_bytes("acc_b", 1100))
|
||||
@ -154,14 +154,14 @@ def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory
|
||||
|
||||
|
||||
def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "model_type:checkpoints", "unit-tests", "lf-include"]
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-include"]
|
||||
a = asset_factory("incvar_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("iva"))
|
||||
asset_factory("incvar_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("ivb"))
|
||||
|
||||
# CSV tag filters are whitespace-trimmed and case-sensitive.
|
||||
# CSV + case-insensitive
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-include,alpha"},
|
||||
params={"include_tags": "UNIT-TESTS,LF-INCLUDE,alpha"},
|
||||
timeout=120,
|
||||
)
|
||||
b1 = r1.json()
|
||||
@ -196,14 +196,14 @@ def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factor
|
||||
|
||||
|
||||
def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "model_type:checkpoints", "unit-tests", "lf-exclude"]
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-exclude"]
|
||||
a = asset_factory("ex_a_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("exa", 900))
|
||||
asset_factory("ex_b_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("exb", 900))
|
||||
|
||||
# Exclude filters are case-sensitive.
|
||||
# Exclude uppercase should work
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "beta"},
|
||||
params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "BETA"},
|
||||
timeout=120,
|
||||
)
|
||||
b1 = r1.json()
|
||||
@ -225,7 +225,7 @@ def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory,
|
||||
|
||||
|
||||
def test_list_assets_name_contains_case_and_specials(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "model_type:checkpoints", "unit-tests", "lf-name"]
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-name"]
|
||||
a1 = asset_factory("CaseMix.SAFE", t, {}, make_asset_bytes("cm", 800))
|
||||
a2 = asset_factory("case-other.safetensors", t, {}, make_asset_bytes("co", 800))
|
||||
|
||||
@ -261,7 +261,7 @@ def test_list_assets_name_contains_case_and_specials(http, api_base, asset_facto
|
||||
|
||||
|
||||
def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "model_type:checkpoints", "unit-tests", "lf-pagelimits"]
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-pagelimits"]
|
||||
asset_factory("pl1.safetensors", t, {}, make_asset_bytes("pl1", 600))
|
||||
asset_factory("pl2.safetensors", t, {}, make_asset_bytes("pl2", 600))
|
||||
asset_factory("pl3.safetensors", t, {}, make_asset_bytes("pl3", 600))
|
||||
@ -319,7 +319,7 @@ def test_list_assets_name_contains_literal_underscore(
|
||||
- foobar.safetensors (must NOT match)
|
||||
"""
|
||||
scope = f"lf-underscore-{uuid.uuid4().hex[:6]}"
|
||||
tags = ["models", "model_type:checkpoints", "unit-tests", scope]
|
||||
tags = ["models", "checkpoints", "unit-tests", scope]
|
||||
|
||||
a = asset_factory("foo_bar.safetensors", tags, {}, make_asset_bytes("a", 700))
|
||||
b = asset_factory("fooxbar.safetensors", tags, {}, make_asset_bytes("b", 700))
|
||||
|
||||
@ -5,7 +5,7 @@ def test_meta_and_across_keys_and_types(
|
||||
http, api_base: str, asset_factory, make_asset_bytes
|
||||
):
|
||||
name = "mf_and_mix.safetensors"
|
||||
tags = ["models", "model_type:checkpoints", "unit-tests", "mf-and"]
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-and"]
|
||||
meta = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23}
|
||||
asset_factory(name, tags, meta, make_asset_bytes(name, 4096))
|
||||
|
||||
@ -41,7 +41,7 @@ def test_meta_and_across_keys_and_types(
|
||||
|
||||
def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_types.safetensors"
|
||||
tags = ["models", "model_type:checkpoints", "unit-tests", "mf-types"]
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-types"]
|
||||
meta = {"epoch": 1, "active": True}
|
||||
asset_factory(name, tags, meta, make_asset_bytes(name))
|
||||
|
||||
@ -95,7 +95,7 @@ def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory,
|
||||
|
||||
def test_meta_any_of_list_of_scalars(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_list_scalars.safetensors"
|
||||
tags = ["models", "model_type:checkpoints", "unit-tests", "mf-list"]
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-list"]
|
||||
meta = {"flags": ["red", "green"]}
|
||||
asset_factory(name, tags, meta, make_asset_bytes(name, 3000))
|
||||
|
||||
@ -134,7 +134,7 @@ def test_meta_none_semantics_missing_or_null_and_any_of_with_none(
|
||||
http, api_base, asset_factory, make_asset_bytes
|
||||
):
|
||||
# a1: key missing; a2: explicit null; a3: concrete value
|
||||
t = ["models", "model_type:checkpoints", "unit-tests", "mf-none"]
|
||||
t = ["models", "checkpoints", "unit-tests", "mf-none"]
|
||||
a1 = asset_factory("mf_none_missing.safetensors", t, {"x": 1}, make_asset_bytes("a1"))
|
||||
a2 = asset_factory("mf_none_null.safetensors", t, {"maybe": None}, make_asset_bytes("a2"))
|
||||
a3 = asset_factory("mf_none_value.safetensors", t, {"maybe": "x"}, make_asset_bytes("a3"))
|
||||
@ -166,7 +166,7 @@ def test_meta_none_semantics_missing_or_null_and_any_of_with_none(
|
||||
|
||||
def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_nested_json.safetensors"
|
||||
tags = ["models", "model_type:checkpoints", "unit-tests", "mf-nested"]
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-nested"]
|
||||
cfg = {"optimizer": "adam", "lr": 0.001, "schedule": {"type": "cosine", "warmup": 100}}
|
||||
asset_factory(name, tags, {"config": cfg}, make_asset_bytes(name, 2200))
|
||||
|
||||
@ -197,7 +197,7 @@ def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_as
|
||||
|
||||
def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_list_objects.safetensors"
|
||||
tags = ["models", "model_type:checkpoints", "unit-tests", "mf-objlist"]
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-objlist"]
|
||||
transforms = [{"type": "crop", "size": 128}, {"type": "flip", "p": 0.5}]
|
||||
asset_factory(name, tags, {"transforms": transforms}, make_asset_bytes(name, 2048))
|
||||
|
||||
@ -228,7 +228,7 @@ def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_b
|
||||
|
||||
def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_keys_unicode.safetensors"
|
||||
tags = ["models", "model_type:checkpoints", "unit-tests", "mf-keys"]
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-keys"]
|
||||
meta = {
|
||||
"weird.key": "v1",
|
||||
"path/like": 7,
|
||||
@ -259,7 +259,7 @@ def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_
|
||||
|
||||
|
||||
def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "model_type:checkpoints", "unit-tests", "mf-zero-bool"]
|
||||
t = ["models", "checkpoints", "unit-tests", "mf-zero-bool"]
|
||||
a0 = asset_factory("mf_zero_count.safetensors", t, {"count": 0}, make_asset_bytes("z", 1025))
|
||||
a1 = asset_factory("mf_bool_list.safetensors", t, {"choices": [True, False]}, make_asset_bytes("b", 1026))
|
||||
|
||||
@ -286,7 +286,7 @@ def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_as
|
||||
|
||||
def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_mixed_list.safetensors"
|
||||
tags = ["models", "model_type:checkpoints", "unit-tests", "mf-mixed"]
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-mixed"]
|
||||
meta = {"mix": ["1", 1, True, None]}
|
||||
asset_factory(name, tags, meta, make_asset_bytes(name, 1999))
|
||||
|
||||
@ -311,7 +311,7 @@ def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, mak
|
||||
|
||||
def test_meta_unknown_key_and_none_behavior_with_scope_tags(http, api_base, asset_factory, make_asset_bytes):
|
||||
# Use a unique scope tag to avoid interference
|
||||
t = ["models", "model_type:checkpoints", "unit-tests", "mf-unknown-scope"]
|
||||
t = ["models", "checkpoints", "unit-tests", "mf-unknown-scope"]
|
||||
x = asset_factory("mf_unknown_a.safetensors", t, {"k1": 1}, make_asset_bytes("ua"))
|
||||
y = asset_factory("mf_unknown_b.safetensors", t, {"k2": 2}, make_asset_bytes("ub"))
|
||||
|
||||
@ -340,13 +340,13 @@ def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_
|
||||
# alpha matches epoch=1; beta has epoch=2
|
||||
a = asset_factory(
|
||||
"mf_tag_alpha.safetensors",
|
||||
["models", "model_type:checkpoints", "unit-tests", "mf-tag", "alpha"],
|
||||
["models", "checkpoints", "unit-tests", "mf-tag", "alpha"],
|
||||
{"epoch": 1},
|
||||
make_asset_bytes("alpha"),
|
||||
)
|
||||
b = asset_factory(
|
||||
"mf_tag_beta.safetensors",
|
||||
["models", "model_type:checkpoints", "unit-tests", "mf-tag", "beta"],
|
||||
["models", "checkpoints", "unit-tests", "mf-tag", "beta"],
|
||||
{"epoch": 2},
|
||||
make_asset_bytes("beta"),
|
||||
)
|
||||
@ -367,7 +367,7 @@ def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_
|
||||
|
||||
def test_meta_sort_and_paging_under_filter(http, api_base, asset_factory, make_asset_bytes):
|
||||
# Three assets in same scope with different sizes and a common filter key
|
||||
t = ["models", "model_type:checkpoints", "unit-tests", "mf-sort"]
|
||||
t = ["models", "checkpoints", "unit-tests", "mf-sort"]
|
||||
n1, n2, n3 = "mf_sort_1.safetensors", "mf_sort_2.safetensors", "mf_sort_3.safetensors"
|
||||
asset_factory(n1, t, {"group": "g"}, make_asset_bytes(n1, 1024))
|
||||
asset_factory(n2, t, {"group": "g"}, make_asset_bytes(n2, 2048))
|
||||
|
||||
@ -29,7 +29,7 @@ def create_seed_file(comfy_tmp_base_dir: Path):
|
||||
def find_asset(http: requests.Session, api_base: str):
|
||||
"""Query API for assets matching scope and optional name."""
|
||||
def _find(scope: str, name: str | None = None) -> list[dict]:
|
||||
params = {"limit": "500"}
|
||||
params = {"include_tags": f"unit-tests,{scope}"}
|
||||
if name:
|
||||
params["name_contains"] = name
|
||||
r = http.get(f"{api_base}/api/assets", params=params, timeout=120)
|
||||
@ -91,7 +91,7 @@ def test_hashed_asset_not_pruned_when_file_missing(
|
||||
data = make_asset_bytes("test", 2048)
|
||||
a = asset_factory("test.bin", ["input", "unit-tests", scope], {}, data)
|
||||
|
||||
path = comfy_tmp_base_dir / "input" / get_asset_filename(a["asset_hash"], ".bin")
|
||||
path = comfy_tmp_base_dir / "input" / "unit-tests" / scope / get_asset_filename(a["asset_hash"], ".bin")
|
||||
path.unlink()
|
||||
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
@ -108,20 +108,18 @@ def test_prune_across_multiple_roots(
|
||||
):
|
||||
"""Prune correctly handles assets across input and output roots."""
|
||||
scope = f"multi-{uuid.uuid4().hex[:6]}"
|
||||
input_name = f"{scope}-input.bin"
|
||||
output_name = f"{scope}-output.bin"
|
||||
input_fp = create_seed_file("input", scope, input_name)
|
||||
create_seed_file("output", scope, output_name)
|
||||
input_fp = create_seed_file("input", scope, "input.bin")
|
||||
create_seed_file("output", scope, "output.bin")
|
||||
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
assert find_asset(scope, input_name)
|
||||
assert find_asset(scope, output_name)
|
||||
assert len(find_asset(scope)) == 2
|
||||
|
||||
input_fp.unlink()
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
assert not find_asset(scope, input_name)
|
||||
assert find_asset(scope, output_name)
|
||||
remaining = find_asset(scope)
|
||||
assert len(remaining) == 1
|
||||
assert remaining[0]["name"] == "output.bin"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dirname", ["100%_done", "my_folder_name", "has spaces"])
|
||||
|
||||
@ -10,9 +10,9 @@ def test_tags_present(http: requests.Session, api_base: str, seeded_asset: dict)
|
||||
body1 = r1.json()
|
||||
assert r1.status_code == 200
|
||||
names = [t["name"] for t in body1["tags"]]
|
||||
# A few selected contract tags should exist.
|
||||
# A few system tags from migration should exist:
|
||||
assert "models" in names
|
||||
assert "model_type:checkpoints" in names
|
||||
assert "checkpoints" in names
|
||||
|
||||
# Only used tags before we add anything new from this test cycle
|
||||
r2 = http.get(api_base + "/api/tags", params={"include_zero": "false"}, timeout=120)
|
||||
@ -21,7 +21,7 @@ def test_tags_present(http: requests.Session, api_base: str, seeded_asset: dict)
|
||||
# We already seeded one asset via fixture, so used tags must be non-empty
|
||||
used_names = [t["name"] for t in body2["tags"]]
|
||||
assert "models" in used_names
|
||||
assert "model_type:checkpoints" in used_names
|
||||
assert "checkpoints" in used_names
|
||||
|
||||
# Prefix filter should refine the list
|
||||
r3 = http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}, timeout=120)
|
||||
@ -45,7 +45,7 @@ def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory,
|
||||
body1 = r1.json()
|
||||
assert r1.status_code == 200
|
||||
names = [t["name"] for t in body1["tags"]]
|
||||
assert "models" in names and "model_type:checkpoints" in names
|
||||
assert "models" in names and "checkpoints" in names
|
||||
|
||||
# Create a short-lived asset under input with a unique custom tag
|
||||
scope = f"tags-empty-usage-{uuid.uuid4().hex[:6]}"
|
||||
@ -89,28 +89,28 @@ def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory,
|
||||
def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
|
||||
# Add tags with duplicates while preserving source case.
|
||||
payload_add = {"tags": ["NewTag", "unit-tests", "NewTag", "BETA"]}
|
||||
# Add tags with duplicates and mixed case
|
||||
payload_add = {"tags": ["NewTag", "unit-tests", "newtag", "BETA"]}
|
||||
r1 = http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add, timeout=120)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 200, b1
|
||||
# stripped, deduplicated; 'unit-tests' was already present from the seed
|
||||
assert set(b1["added"]) == {"NewTag", "BETA"}
|
||||
# normalized, deduplicated; 'unit-tests' was already present from the seed
|
||||
assert set(b1["added"]) == {"newtag", "beta"}
|
||||
assert set(b1["already_present"]) == {"unit-tests"}
|
||||
assert "NewTag" in b1["total_tags"] and "BETA" in b1["total_tags"]
|
||||
assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"]
|
||||
|
||||
rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
g = rg.json()
|
||||
assert rg.status_code == 200
|
||||
tags_now = set(g["tags"])
|
||||
assert {"NewTag", "BETA"}.issubset(tags_now)
|
||||
assert {"newtag", "beta"}.issubset(tags_now)
|
||||
|
||||
# Remove a tag and a non-existent tag
|
||||
payload_del = {"tags": ["NewTag", "does-not-exist"]}
|
||||
payload_del = {"tags": ["newtag", "does-not-exist"]}
|
||||
r2 = http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del, timeout=120)
|
||||
b2 = r2.json()
|
||||
assert r2.status_code == 200
|
||||
assert set(b2["removed"]) == {"NewTag"}
|
||||
assert set(b2["removed"]) == {"newtag"}
|
||||
assert set(b2["not_present"]) == {"does-not-exist"}
|
||||
|
||||
# Verify remaining tags after deletion
|
||||
@ -118,44 +118,8 @@ def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset
|
||||
g2 = rg2.json()
|
||||
assert rg2.status_code == 200
|
||||
tags_later = set(g2["tags"])
|
||||
assert "NewTag" not in tags_later
|
||||
assert "BETA" in tags_later # still present
|
||||
|
||||
|
||||
def test_add_system_looking_tags_allowed_as_labels(
|
||||
http: requests.Session, api_base: str, seeded_asset: dict
|
||||
):
|
||||
aid = seeded_asset["id"]
|
||||
|
||||
response = http.post(
|
||||
f"{api_base}/api/assets/{aid}/tags",
|
||||
json={
|
||||
"tags": [
|
||||
"models",
|
||||
"model_type:manual",
|
||||
"model:true",
|
||||
"models:foo",
|
||||
"input:true",
|
||||
"output:true",
|
||||
"uploaded:true",
|
||||
"temp:true",
|
||||
"temporary",
|
||||
]
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
body = response.json()
|
||||
|
||||
assert response.status_code == 200, body
|
||||
assert "models" in body["total_tags"]
|
||||
assert "model_type:manual" in body["total_tags"]
|
||||
assert "model:true" in body["total_tags"]
|
||||
assert "models:foo" in body["total_tags"]
|
||||
assert "input:true" in body["total_tags"]
|
||||
assert "output:true" in body["total_tags"]
|
||||
assert "uploaded:true" in body["total_tags"]
|
||||
assert "temp:true" in body["total_tags"]
|
||||
assert "temporary" in body["total_tags"]
|
||||
assert "newtag" not in tags_later
|
||||
assert "beta" in tags_later # still present
|
||||
|
||||
|
||||
def test_tags_list_order_and_prefix(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||
|
||||
@ -1,14 +1,11 @@
|
||||
import json
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import pytest
|
||||
|
||||
from app.assets.api.schemas_in import UploadAssetSpec
|
||||
from app.assets.api.schemas_out import Asset, AssetCreated
|
||||
from helpers import get_asset_filename
|
||||
|
||||
|
||||
def test_asset_created_inherits_hash_field():
|
||||
@ -23,18 +20,9 @@ def test_asset_created_inherits_hash_field():
|
||||
assert AssetCreated.model_fields["hash"].annotation == Asset.model_fields["hash"].annotation
|
||||
|
||||
|
||||
def test_upload_asset_spec_ignores_subfolder_field():
|
||||
spec = UploadAssetSpec.model_validate(
|
||||
{"tags": ["input"], "subfolder": "pasted", "name": "image.png"}
|
||||
)
|
||||
|
||||
assert "subfolder" not in UploadAssetSpec.model_fields
|
||||
assert not hasattr(spec, "subfolder")
|
||||
|
||||
|
||||
def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, make_asset_bytes):
|
||||
name = "dup_a.safetensors"
|
||||
tags = ["models", "model_type:checkpoints", "unit-tests", "alpha"]
|
||||
tags = ["models", "checkpoints", "unit-tests", "alpha"]
|
||||
meta = {"purpose": "dup"}
|
||||
data = make_asset_bytes(name)
|
||||
files = {"file": (name, data, "application/octet-stream")}
|
||||
@ -55,8 +43,6 @@ def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, ma
|
||||
assert a2["asset_hash"] == a1["asset_hash"]
|
||||
assert a2["hash"] == a1["hash"]
|
||||
assert a2["id"] != a1["id"] # new reference with same content
|
||||
assert a2.get("loader_path") is None
|
||||
assert a2.get("display_name") is None
|
||||
|
||||
# Third upload with the same data but different name also creates new AssetReference
|
||||
files = {"file": (name, data, "application/octet-stream")}
|
||||
@ -67,14 +53,12 @@ def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, ma
|
||||
assert a3["asset_hash"] == a1["asset_hash"]
|
||||
assert a3["id"] != a1["id"]
|
||||
assert a3["id"] != a2["id"]
|
||||
assert a3.get("loader_path") is None
|
||||
assert a3.get("display_name") is None
|
||||
|
||||
|
||||
def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_base: str):
|
||||
# Seed a small file first
|
||||
name = "fastpath_seed.safetensors"
|
||||
tags = ["input", "unit-tests"]
|
||||
tags = ["models", "checkpoints", "unit-tests"]
|
||||
meta = {}
|
||||
files = {"file": (name, b"B" * 1024, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)}
|
||||
@ -85,10 +69,9 @@ def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_
|
||||
assert b1["hash"] == h
|
||||
|
||||
# Now POST /api/assets with only hash and no file
|
||||
hash_only_tags = ["models", "checkpoints", "unit-tests", "hash-labels"]
|
||||
files = [
|
||||
("hash", (None, h)),
|
||||
("tags", (None, json.dumps(hash_only_tags))),
|
||||
("tags", (None, json.dumps(tags))),
|
||||
("name", (None, "fastpath_copy.safetensors")),
|
||||
("user_metadata", (None, json.dumps({"purpose": "copy"}))),
|
||||
]
|
||||
@ -98,53 +81,6 @@ def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_
|
||||
assert b2["created_new"] is False
|
||||
assert b2["asset_hash"] == h
|
||||
assert b2["hash"] == h
|
||||
assert "models" in b2["tags"]
|
||||
assert "checkpoints" in b2["tags"]
|
||||
assert "uploaded" not in b2["tags"]
|
||||
assert not any(tag.startswith("model_type:") for tag in b2["tags"])
|
||||
assert b2.get("loader_path") is None
|
||||
assert b2.get("display_name") is None
|
||||
|
||||
rg = http.get(f"{api_base}/api/assets/{b2['id']}", timeout=120)
|
||||
detail = rg.json()
|
||||
assert rg.status_code == 200, detail
|
||||
assert detail.get("loader_path") is None
|
||||
assert detail.get("display_name") is None
|
||||
|
||||
|
||||
def test_create_from_hash_with_model_tags_does_not_synthesize_loader_path(
|
||||
http: requests.Session, api_base: str
|
||||
):
|
||||
seed_name = "from_hash_seed.safetensors"
|
||||
seed_tags = ["models", "model_type:checkpoints", "unit-tests"]
|
||||
files = {"file": (seed_name, b"D" * 1024, "application/octet-stream")}
|
||||
form = {
|
||||
"tags": json.dumps(seed_tags),
|
||||
"name": seed_name,
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
seed_r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
seed = seed_r.json()
|
||||
assert seed_r.status_code == 201, seed
|
||||
|
||||
payload = {
|
||||
"hash": seed["asset_hash"],
|
||||
"name": "from_hash_copy.safetensors",
|
||||
"tags": ["models", "model_type:checkpoints", "unit-tests", "spoofed"],
|
||||
}
|
||||
created_r = http.post(api_base + "/api/assets/from-hash", json=payload, timeout=120)
|
||||
created = created_r.json()
|
||||
assert created_r.status_code == 201, created
|
||||
assert created["created_new"] is False
|
||||
assert created["asset_hash"] == seed["asset_hash"]
|
||||
assert created.get("loader_path") is None
|
||||
assert created.get("display_name") is None
|
||||
|
||||
detail_r = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120)
|
||||
detail = detail_r.json()
|
||||
assert detail_r.status_code == 200, detail
|
||||
assert detail.get("loader_path") is None
|
||||
assert detail.get("display_name") is None
|
||||
|
||||
|
||||
def test_upload_fastpath_with_known_hash_and_file(
|
||||
@ -152,7 +88,7 @@ def test_upload_fastpath_with_known_hash_and_file(
|
||||
):
|
||||
# Seed
|
||||
files = {"file": ("seed.safetensors", b"C" * 128, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "fp"]), "name": "seed.safetensors", "user_metadata": json.dumps({})}
|
||||
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "fp"]), "name": "seed.safetensors", "user_metadata": json.dumps({})}
|
||||
r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
b1 = r1.json()
|
||||
assert r1.status_code == 201, b1
|
||||
@ -168,49 +104,11 @@ def test_upload_fastpath_with_known_hash_and_file(
|
||||
assert b2["created_new"] is False
|
||||
assert b2["asset_hash"] == h
|
||||
assert b2["hash"] == h
|
||||
assert "checkpoints" in b2["tags"]
|
||||
assert "uploaded" not in b2["tags"]
|
||||
assert not any(tag == "model_type:checkpoints" for tag in b2["tags"])
|
||||
|
||||
|
||||
def test_duplicate_byte_upload_is_reference_only_and_does_not_need_destination(
|
||||
http: requests.Session, api_base: str
|
||||
):
|
||||
data = b"duplicate-reference-only" * 64
|
||||
seed_files = {"file": ("duplicate-seed.bin", data, "application/octet-stream")}
|
||||
seed_form = {
|
||||
"tags": json.dumps(["input", "unit-tests", "duplicate-seed"]),
|
||||
"name": "duplicate-seed.bin",
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
seed_response = http.post(api_base + "/api/assets", data=seed_form, files=seed_files, timeout=120)
|
||||
seed = seed_response.json()
|
||||
assert seed_response.status_code == 201, seed
|
||||
|
||||
duplicate_files = {"file": ("duplicate-copy.bin", data, "application/octet-stream")}
|
||||
duplicate_form = {
|
||||
"tags": json.dumps(["not-a-destination", "unit-tests", "duplicate-copy"]),
|
||||
"name": "duplicate-copy.bin",
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
duplicate_response = http.post(
|
||||
api_base + "/api/assets", data=duplicate_form, files=duplicate_files, timeout=120
|
||||
)
|
||||
duplicate = duplicate_response.json()
|
||||
|
||||
assert duplicate_response.status_code == 200, duplicate
|
||||
assert duplicate["created_new"] is False
|
||||
assert duplicate["asset_hash"] == seed["asset_hash"]
|
||||
assert "not-a-destination" in duplicate["tags"]
|
||||
assert "uploaded" not in duplicate["tags"]
|
||||
assert "input" not in duplicate["tags"]
|
||||
assert duplicate.get("loader_path") is None
|
||||
assert duplicate.get("display_name") is None
|
||||
|
||||
|
||||
def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base: str):
|
||||
data = [
|
||||
("tags", "models,model_type:checkpoints"),
|
||||
("tags", "models,checkpoints"),
|
||||
("tags", json.dumps(["unit-tests", "alpha"])),
|
||||
("name", "merge.safetensors"),
|
||||
("user_metadata", json.dumps({"u": 1})),
|
||||
@ -226,71 +124,7 @@ def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base
|
||||
detail = rg.json()
|
||||
assert rg.status_code == 200, detail
|
||||
tags = set(detail["tags"])
|
||||
assert {"models", "model_type:checkpoints", "unit-tests", "alpha"}.issubset(tags)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
(
|
||||
"tags",
|
||||
"extension",
|
||||
"expected_display_prefix",
|
||||
),
|
||||
[
|
||||
(["input", "unit-tests"], ".png", ""),
|
||||
(
|
||||
["models", "model_type:checkpoints", "unit-tests"],
|
||||
".safetensors",
|
||||
"checkpoints/",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_upload_response_includes_loader_path_and_display_name(
|
||||
tags: list[str],
|
||||
extension: str,
|
||||
expected_display_prefix: str,
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
make_asset_bytes,
|
||||
):
|
||||
scope = f"response-paths-{uuid.uuid4().hex[:6]}"
|
||||
scoped_tags = [*tags, scope]
|
||||
name = f"asset_response_path{extension}"
|
||||
|
||||
files = {"file": (name, make_asset_bytes(name, 1024), "application/octet-stream")}
|
||||
form = {
|
||||
"tags": json.dumps(scoped_tags),
|
||||
"name": name,
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
created_r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
created = created_r.json()
|
||||
assert created_r.status_code in (200, 201), created
|
||||
stored_filename = get_asset_filename(created["asset_hash"], extension)
|
||||
expected_suffix = stored_filename
|
||||
expected_display_name = f"{expected_display_prefix}{expected_suffix}"
|
||||
# In-root loader path: model category dropped, no subfolders here -> just the filename.
|
||||
expected_loader_path = expected_suffix
|
||||
|
||||
assert created["loader_path"] == expected_loader_path
|
||||
assert created["display_name"] == expected_display_name
|
||||
assert "logical_path" not in created
|
||||
|
||||
detail_r = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120)
|
||||
detail = detail_r.json()
|
||||
assert detail_r.status_code == 200, detail
|
||||
assert detail["loader_path"] == expected_loader_path
|
||||
assert detail["display_name"] == expected_display_name
|
||||
|
||||
list_r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": f"unit-tests,{scope}", "limit": "50"},
|
||||
timeout=120,
|
||||
)
|
||||
listed = list_r.json()
|
||||
assert list_r.status_code == 200, listed
|
||||
match = next(a for a in listed["assets"] if a["id"] == created["id"])
|
||||
assert match["loader_path"] == expected_loader_path
|
||||
assert match["display_name"] == expected_display_name
|
||||
assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
@ -358,55 +192,16 @@ def test_create_from_hash_endpoint_404(http: requests.Session, api_base: str):
|
||||
assert body["error"]["code"] == "ASSET_NOT_FOUND"
|
||||
|
||||
|
||||
def test_create_from_hash_accepts_arbitrary_system_looking_tags(
|
||||
http: requests.Session, api_base: str
|
||||
):
|
||||
files = {"file": ("hash-seed.bin", b"hash-seed" * 64, "application/octet-stream")}
|
||||
form = {
|
||||
"tags": json.dumps(["input", "unit-tests", "hash-seed"]),
|
||||
"name": "hash-seed.bin",
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
seed_response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
seed = seed_response.json()
|
||||
assert seed_response.status_code == 201, seed
|
||||
|
||||
response = http.post(
|
||||
api_base + "/api/assets/from-hash",
|
||||
json={
|
||||
"hash": seed["asset_hash"],
|
||||
"name": "hash-copy.bin",
|
||||
"tags": [
|
||||
"models",
|
||||
"model:true",
|
||||
"models:foo",
|
||||
"temporary:true",
|
||||
"unit-tests",
|
||||
"hash-copy",
|
||||
],
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
body = response.json()
|
||||
|
||||
assert response.status_code == 201, body
|
||||
assert "models" in body["tags"]
|
||||
assert "model:true" in body["tags"]
|
||||
assert "models:foo" in body["tags"]
|
||||
assert "temporary:true" in body["tags"]
|
||||
assert "uploaded" not in body["tags"]
|
||||
|
||||
|
||||
def test_upload_zero_byte_rejected(http: requests.Session, api_base: str):
|
||||
files = {"file": ("empty.safetensors", b"", "application/octet-stream")}
|
||||
form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "edge"]), "name": "empty.safetensors", "user_metadata": json.dumps({})}
|
||||
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "empty.safetensors", "user_metadata": json.dumps({})}
|
||||
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 400
|
||||
assert body["error"]["code"] == "EMPTY_UPLOAD"
|
||||
|
||||
|
||||
def test_upload_rejects_arbitrary_labels_without_required_destination_role(http: requests.Session, api_base: str):
|
||||
def test_upload_invalid_root_tag_rejected(http: requests.Session, api_base: str):
|
||||
files = {"file": ("badroot.bin", b"A" * 64, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(["not-a-root", "whatever"]), "name": "badroot.bin", "user_metadata": json.dumps({})}
|
||||
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
@ -417,7 +212,7 @@ def test_upload_rejects_arbitrary_labels_without_required_destination_role(http:
|
||||
|
||||
def test_upload_user_metadata_must_be_json(http: requests.Session, api_base: str):
|
||||
files = {"file": ("badmeta.bin", b"A" * 128, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "edge"]), "name": "badmeta.bin", "user_metadata": "{not json}"}
|
||||
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "badmeta.bin", "user_metadata": "{not json}"}
|
||||
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 400
|
||||
@ -433,7 +228,7 @@ def test_upload_requires_multipart(http: requests.Session, api_base: str):
|
||||
|
||||
def test_upload_missing_file_and_hash(http: requests.Session, api_base: str):
|
||||
files = [
|
||||
("tags", (None, json.dumps(["models", "model_type:checkpoints", "unit-tests"]))),
|
||||
("tags", (None, json.dumps(["models", "checkpoints", "unit-tests"]))),
|
||||
("name", (None, "x.safetensors")),
|
||||
]
|
||||
r = http.post(api_base + "/api/assets", files=files, timeout=120)
|
||||
@ -442,33 +237,17 @@ def test_upload_missing_file_and_hash(http: requests.Session, api_base: str):
|
||||
assert body["error"]["code"] == "MISSING_FILE"
|
||||
|
||||
|
||||
def test_upload_models_unknown_model_type(http: requests.Session, api_base: str):
|
||||
def test_upload_models_unknown_category(http: requests.Session, api_base: str):
|
||||
files = {"file": ("m.safetensors", b"A" * 128, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(["models", "model_type:no_such_category", "unit-tests"]), "name": "m.safetensors"}
|
||||
form = {"tags": json.dumps(["models", "no_such_category", "unit-tests"]), "name": "m.safetensors"}
|
||||
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 400, body
|
||||
assert r.status_code == 400
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
assert body["error"]["message"].startswith("unknown models category")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_type", ["configs", "custom_nodes"])
|
||||
def test_upload_models_rejects_non_model_registered_folder(
|
||||
model_type: str, http: requests.Session, api_base: str
|
||||
):
|
||||
files = {"file": ("not-a-model.py", b"A" * 128, "application/octet-stream")}
|
||||
form = {
|
||||
"tags": json.dumps(["models", f"model_type:{model_type}", "unit-tests"]),
|
||||
"name": "not-a-model.py",
|
||||
}
|
||||
|
||||
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = response.json()
|
||||
|
||||
assert response.status_code == 400, body
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
|
||||
def test_upload_models_requires_model_type(http: requests.Session, api_base: str):
|
||||
def test_upload_models_requires_category(http: requests.Session, api_base: str):
|
||||
files = {"file": ("nocat.safetensors", b"A" * 64, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(["models"]), "name": "nocat.safetensors", "user_metadata": json.dumps({})}
|
||||
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
@ -477,152 +256,13 @@ def test_upload_models_requires_model_type(http: requests.Session, api_base: str
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
|
||||
def test_upload_extra_tags_are_labels_not_path_components(http: requests.Session, api_base: str):
|
||||
def test_upload_tags_traversal_guard(http: requests.Session, api_base: str):
|
||||
files = {"file": ("evil.safetensors", b"A" * 256, "application/octet-stream")}
|
||||
form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "..", "zzz"]), "name": "evil.safetensors"}
|
||||
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"]), "name": "evil.safetensors"}
|
||||
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 201, body
|
||||
assert ".." in body["tags"]
|
||||
assert "zzz" in body["tags"]
|
||||
assert "models" in body["tags"]
|
||||
assert "model_type:checkpoints" in body["tags"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("subfolder", "expected_tag", "unexpected_tags"),
|
||||
[
|
||||
("custom/session", None, {"custom", "session"}),
|
||||
("pasted", "pasted", set()),
|
||||
],
|
||||
)
|
||||
def test_upload_image_accepts_arbitrary_subfolder_but_only_known_values_become_tags(
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
subfolder: str,
|
||||
expected_tag: str | None,
|
||||
unexpected_tags: set[str],
|
||||
):
|
||||
name = f"upload-image-{uuid.uuid4().hex}.png"
|
||||
files = {"image": (name, b"image-upload" * 64, "image/png")}
|
||||
form = {"type": "input", "subfolder": subfolder}
|
||||
|
||||
response = http.post(api_base + "/upload/image", data=form, files=files, timeout=120)
|
||||
body = response.json()
|
||||
|
||||
assert response.status_code == 200, body
|
||||
assert body["subfolder"] == subfolder
|
||||
assert (comfy_tmp_base_dir / "input" / subfolder / body["name"]).exists()
|
||||
|
||||
asset = body["asset"]
|
||||
tags = set(asset["tags"])
|
||||
assert "input" in tags
|
||||
assert "uploaded" in tags
|
||||
if expected_tag:
|
||||
assert expected_tag in tags
|
||||
assert tags.isdisjoint(unexpected_tags)
|
||||
|
||||
|
||||
def test_multipart_upload_accepts_system_looking_extra_labels(
|
||||
http: requests.Session, api_base: str
|
||||
):
|
||||
files = {"file": ("relaxed-labels.bin", b"relaxed" * 64, "application/octet-stream")}
|
||||
form = {
|
||||
"tags": json.dumps(
|
||||
[
|
||||
"input",
|
||||
"unit-tests",
|
||||
"model:true",
|
||||
"models:foo",
|
||||
"temporary",
|
||||
"uploaded:true",
|
||||
]
|
||||
),
|
||||
"name": "relaxed-labels.bin",
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = response.json()
|
||||
|
||||
assert response.status_code == 201, body
|
||||
assert "input" in body["tags"]
|
||||
assert "model:true" in body["tags"]
|
||||
assert "models:foo" in body["tags"]
|
||||
assert "temporary" in body["tags"]
|
||||
assert "uploaded:true" in body["tags"]
|
||||
|
||||
|
||||
def test_multipart_upload_rejects_ambiguous_destination_roles(
|
||||
http: requests.Session, api_base: str
|
||||
):
|
||||
files = {"file": ("ambiguous.bin", b"ambiguous" * 64, "application/octet-stream")}
|
||||
form = {
|
||||
"tags": json.dumps(["input", "output", "unit-tests"]),
|
||||
"name": "ambiguous.bin",
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = response.json()
|
||||
|
||||
assert response.status_code == 400, body
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
|
||||
def test_multipart_upload_rejects_multiple_model_types_for_models_destination(
|
||||
http: requests.Session, api_base: str
|
||||
):
|
||||
files = {"file": ("ambiguous-model.safetensors", b"ambiguous-model" * 64, "application/octet-stream")}
|
||||
form = {
|
||||
"tags": json.dumps(
|
||||
["models", "model_type:checkpoints", "model_type:loras", "unit-tests"]
|
||||
),
|
||||
"name": "ambiguous-model.safetensors",
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = response.json()
|
||||
|
||||
assert response.status_code == 400, body
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("tags", "expected_root", "extension"),
|
||||
[
|
||||
(["input", "unit-tests", "upload-location-input"], "input", ".bin"),
|
||||
(["output", "unit-tests", "upload-location-output"], "output", ".bin"),
|
||||
(
|
||||
["models", "model_type:checkpoints", "unit-tests", "upload-location-model"],
|
||||
"models/checkpoints",
|
||||
".safetensors",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_multipart_upload_role_selects_write_location(
|
||||
http: requests.Session,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
tags: list[str],
|
||||
expected_root: str,
|
||||
extension: str,
|
||||
):
|
||||
role = next(tag for tag in tags if tag in {"input", "models", "output"})
|
||||
name = f"{role}-role-upload{extension}"
|
||||
files = {"file": (name, f"{role}-role-bytes".encode() * 64, "application/octet-stream")}
|
||||
form = {
|
||||
"tags": json.dumps(tags),
|
||||
"name": name,
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
|
||||
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = response.json()
|
||||
|
||||
assert response.status_code == 201, body
|
||||
stored_name = get_asset_filename(body["asset_hash"], extension)
|
||||
expected_disk_path = comfy_tmp_base_dir / expected_root / stored_name
|
||||
assert expected_disk_path.exists()
|
||||
assert r.status_code == 400
|
||||
assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
|
||||
|
||||
|
||||
def test_upload_empty_tags_rejected(http: requests.Session, api_base: str):
|
||||
|
||||
204
tests-unit/comfy_api_test/io_dynamic_group_test.py
Normal file
204
tests-unit/comfy_api_test/io_dynamic_group_test.py
Normal file
@ -0,0 +1,204 @@
|
||||
"""Unit tests for io.DynamicGroup: expansion/reconstruction (0-row and N-row cases)."""
|
||||
import sys
|
||||
import types
|
||||
import pytest
|
||||
|
||||
# Stub torch (type-hint only in _io.py; real torch not available in unit-test env)
|
||||
if "torch" not in sys.modules:
|
||||
_torch_stub = types.ModuleType("torch")
|
||||
_torch_stub.Tensor = object # type: ignore[attr-defined]
|
||||
sys.modules["torch"] = _torch_stub
|
||||
|
||||
from comfy_api.latest._io import ( # noqa: E402
|
||||
DynamicGroup,
|
||||
Float,
|
||||
Int,
|
||||
String,
|
||||
Boolean,
|
||||
get_finalized_class_inputs,
|
||||
build_nested_inputs,
|
||||
create_input_dict_v1,
|
||||
setup_dynamic_input_funcs,
|
||||
)
|
||||
|
||||
# Make sure dynamic input funcs are registered (may already be done at import time)
|
||||
setup_dynamic_input_funcs()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_class_inputs(group_input: DynamicGroup.Input) -> dict:
|
||||
"""Wrap a DynamicGroup.Input into the required/optional dict structure."""
|
||||
return create_input_dict_v1([group_input])
|
||||
|
||||
|
||||
def _run(group_input: DynamicGroup.Input, live_values: dict) -> dict:
|
||||
"""End-to-end helper: expand schema + reconstruct values.
|
||||
|
||||
Mirrors the production split in execution.py:
|
||||
1. get_finalized_class_inputs (schema expansion, line 162)
|
||||
2. build_nested_inputs (value reconstruction, line 281)
|
||||
|
||||
The two steps are separate in production because the engine resolves
|
||||
linked node outputs between them, but in tests we supply values directly.
|
||||
"""
|
||||
class_inputs = _make_class_inputs(group_input)
|
||||
_, _, v3_data = get_finalized_class_inputs(class_inputs, live_values)
|
||||
return build_nested_inputs(dict(live_values), v3_data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema construction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDynamicGroupInputConstruction:
|
||||
def test_basic_construction(self):
|
||||
inp = DynamicGroup.Input(
|
||||
"loras",
|
||||
template=[
|
||||
Float.Input("strength", default=1.0),
|
||||
String.Input("name"),
|
||||
],
|
||||
min=0,
|
||||
max=10,
|
||||
)
|
||||
assert inp.id == "loras"
|
||||
assert inp.min == 0
|
||||
assert inp.max == 10
|
||||
assert len(inp.template) == 2
|
||||
|
||||
def test_get_all_includes_self_and_template(self):
|
||||
inp = DynamicGroup.Input(
|
||||
"items",
|
||||
template=[Float.Input("value")],
|
||||
)
|
||||
all_inputs = inp.get_all()
|
||||
assert all_inputs[0] is inp
|
||||
assert all_inputs[1].id == "value"
|
||||
|
||||
def test_as_dict_has_template_min_max(self):
|
||||
inp = DynamicGroup.Input(
|
||||
"items",
|
||||
template=[Float.Input("val", default=0.5)],
|
||||
min=1,
|
||||
max=5,
|
||||
)
|
||||
d = inp.as_dict()
|
||||
assert "template" in d
|
||||
assert d["min"] == 1
|
||||
assert d["max"] == 5
|
||||
|
||||
def test_duplicate_field_ids_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input(
|
||||
"bad",
|
||||
template=[Float.Input("x"), Float.Input("x")],
|
||||
)
|
||||
|
||||
def test_empty_template_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input("bad", template=[])
|
||||
|
||||
def test_min_gt_max_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input("bad", template=[Float.Input("x")], min=5, max=3)
|
||||
|
||||
def test_max_exceeds_limit_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input("bad", template=[Float.Input("x")], max=101)
|
||||
|
||||
def test_dynamic_input_in_template_raises(self):
|
||||
with pytest.raises(AssertionError):
|
||||
DynamicGroup.Input(
|
||||
"bad",
|
||||
template=[DynamicGroup.Input("nested", template=[Float.Input("x")])],
|
||||
)
|
||||
|
||||
def test_validate_calls_through(self):
|
||||
inp = DynamicGroup.Input("items", template=[Float.Input("val", min=-1.0, max=1.0)])
|
||||
inp.validate() # should not raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 0-row case
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestZeroRows:
|
||||
def test_empty_live_inputs_produces_empty_list(self):
|
||||
"""With min=0 and no live values, the result should be an empty list."""
|
||||
inp = DynamicGroup.Input("loras", template=[Float.Input("strength", default=1.0)], min=0, max=10)
|
||||
assert _run(inp, {}).get("loras") == []
|
||||
|
||||
def test_min_zero_with_values(self):
|
||||
"""min=0 but 2 rows of live data."""
|
||||
inp = DynamicGroup.Input("loras", template=[Float.Input("strength", default=1.0)], min=0, max=10)
|
||||
result = _run(inp, {"loras.0.strength": 0.8, "loras.1.strength": 0.5})
|
||||
assert result["loras"] == [{"strength": 0.8}, {"strength": 0.5}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# N-row case
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNRows:
|
||||
def test_two_rows_two_fields(self):
|
||||
"""Two rows with two fields each produce a list[dict]."""
|
||||
inp = DynamicGroup.Input(
|
||||
"loras",
|
||||
template=[String.Input("lora_name"), Float.Input("strength", default=1.0)],
|
||||
min=0, max=50,
|
||||
)
|
||||
result = _run(inp, {
|
||||
"loras.0.lora_name": "model_a.safetensors", "loras.0.strength": 0.9,
|
||||
"loras.1.lora_name": "model_b.safetensors", "loras.1.strength": 0.4,
|
||||
})
|
||||
assert result["loras"] == [
|
||||
{"lora_name": "model_a.safetensors", "strength": 0.9},
|
||||
{"lora_name": "model_b.safetensors", "strength": 0.4},
|
||||
]
|
||||
|
||||
def test_rows_are_sorted_by_index(self):
|
||||
"""Rows must be in ascending index order even if dict iteration is unordered."""
|
||||
inp = DynamicGroup.Input("items", template=[Int.Input("v", default=0)], min=0, max=10)
|
||||
result = _run(inp, {"items.0.v": 10, "items.2.v": 30, "items.1.v": 20})
|
||||
assert [row["v"] for row in result["items"]] == [10, 20, 30]
|
||||
|
||||
def test_min_rows_schema_slots(self):
|
||||
"""With min=2 and no live data, 2 slots must appear in the expanded schema."""
|
||||
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
|
||||
out, _, _ = get_finalized_class_inputs(_make_class_inputs(inp), {})
|
||||
all_slots = {**out.get("required", {}), **out.get("optional", {})}
|
||||
assert "items.0.val" in all_slots
|
||||
assert "items.1.val" in all_slots
|
||||
|
||||
def test_min_rows_reconstructs_when_no_values(self):
|
||||
"""min=2 with NO live values must still yield a 2-element list,
|
||||
not collapse to [] (regression: parent-path clobber)."""
|
||||
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
|
||||
result = _run(inp, {})
|
||||
assert len(result["items"]) == 2
|
||||
assert all("val" in row for row in result["items"])
|
||||
|
||||
def test_min_rows_reconstructs_with_partial_values(self):
|
||||
"""min=2 with only the first row's value present still yields 2 rows."""
|
||||
inp = DynamicGroup.Input("items", template=[Float.Input("val", default=0.0)], min=2, max=5)
|
||||
result = _run(inp, {"items.0.val": 0.7})
|
||||
assert len(result["items"]) == 2
|
||||
assert result["items"][0]["val"] == 0.7
|
||||
assert result["items"][1]["val"] is None
|
||||
|
||||
def test_list_paths_in_v3_data(self):
|
||||
"""list_paths must contain the group id so build_nested_inputs knows to convert."""
|
||||
inp = DynamicGroup.Input("things", template=[Boolean.Input("flag")], min=0, max=5)
|
||||
_, _, v3_data = get_finalized_class_inputs(_make_class_inputs(inp), {})
|
||||
assert "things" in v3_data.get("list_paths", set())
|
||||
|
||||
def test_no_leftover_flat_keys(self):
|
||||
"""Flat keys must be consumed; only the reconstructed list remains."""
|
||||
inp = DynamicGroup.Input("rows", template=[Float.Input("x", default=0.0)], min=0, max=5)
|
||||
result = _run(inp, {"rows.0.x": 1.0, "rows.1.x": 2.0})
|
||||
assert "rows.0.x" not in result
|
||||
assert "rows.1.x" not in result
|
||||
assert isinstance(result["rows"], list)
|
||||
@ -29,8 +29,6 @@ class TestFeatureFlags:
|
||||
features = get_server_features()
|
||||
assert "supports_preview_metadata" in features
|
||||
assert features["supports_preview_metadata"] is True
|
||||
assert "supports_model_type_tags" in features
|
||||
assert features["supports_model_type_tags"] is True
|
||||
assert "max_upload_size" in features
|
||||
assert isinstance(features["max_upload_size"], (int, float))
|
||||
|
||||
|
||||
@ -12,8 +12,6 @@ class TestWebSocketFeatureFlags:
|
||||
# Check expected server features
|
||||
assert "supports_preview_metadata" in features
|
||||
assert features["supports_preview_metadata"] is True
|
||||
assert "supports_model_type_tags" in features
|
||||
assert features["supports_model_type_tags"] is True
|
||||
assert "max_upload_size" in features
|
||||
assert isinstance(features["max_upload_size"], (int, float))
|
||||
|
||||
@ -77,5 +75,3 @@ class TestWebSocketFeatureFlags:
|
||||
assert server_message["type"] == "feature_flags"
|
||||
assert "supports_preview_metadata" in server_message["data"]
|
||||
assert server_message["data"]["supports_preview_metadata"] is True
|
||||
assert "supports_model_type_tags" in server_message["data"]
|
||||
assert server_message["data"]["supports_model_type_tags"] is True
|
||||
|
||||
Reference in New Issue
Block a user