Compare commits

..

2 Commits

Author SHA1 Message Date
bf00c39705 Don't instantiate nodes during validation
Addresses review feedback: the V1 executability check fell back to
constructing the node (class_def()) when the FUNCTION method wasn't found on
the class. That runs __init__ during validation, so a constructor's side
effects or failure could be misreported as invalid_node_definition for an
otherwise valid node.

Inspect only the class. No core/extra node defines its FUNCTION method on the
instance, so this loses no real coverage while removing the side-effect risk.

Replace the instance-fallback test with one asserting a node with a raising
__init__ but a valid class-level method still passes validation (i.e. it is
never instantiated).
2026-06-26 16:04:29 -07:00
82c954bd2a Validate that a node is executable before running the prompt
A node whose FUNCTION points at a method that does not exist (e.g. a typo in
a custom node), or a V3 node missing its execute override, was only detected
once that node ran -- after every upstream node had already executed. In a
multi-node workflow the user waited for the whole graph to run up to the
broken node before seeing the error.

validate_prompt already walks every node before execution; add an
executability check there so the error is reported up front and attributed
to the offending node (returned in node_errors), and nothing runs.

The check resolves the V1 FUNCTION method on the class (the common case) and
falls back to an instance, since the runtime invokes it on an instance and a
node may define FUNCTION or its method in __init__. V3 nodes are checked via
their existing VALIDATE_CLASS.

Add tests for V1 typo, V3 typo, good nodes, and a node whose method is
defined in __init__ (must not be falsely rejected).
2026-06-26 15:53:34 -07:00
31 changed files with 381 additions and 1379 deletions

View File

@ -1,107 +0,0 @@
"""
Allow case-sensitive tag names.
Revision ID: 0005_allow_case_sensitive_tags
Revises: 0004_drop_tag_type
Create Date: 2026-06-16
"""
import sqlalchemy as sa
from alembic import op
revision = "0005_allow_case_sensitive_tags"
down_revision = "0004_drop_tag_type"
branch_labels = None
depends_on = None
def upgrade() -> None:
bind = op.get_bind()
if bind.dialect.name == "sqlite":
# SQLite cannot ALTER/DROP CHECK constraints. Recreate the small tag
# vocabulary table without the lowercase constraint while preserving
# existing tag names.
op.execute("PRAGMA foreign_keys=OFF")
try:
op.execute(
"CREATE TABLE tags_new ("
"name VARCHAR(512) NOT NULL, "
"CONSTRAINT pk_tags PRIMARY KEY (name)"
")"
)
op.execute("INSERT INTO tags_new(name) SELECT name FROM tags")
op.execute("DROP TABLE tags")
op.execute("ALTER TABLE tags_new RENAME TO tags")
finally:
op.execute("PRAGMA foreign_keys=ON")
return
op.drop_constraint("ck_tags_ck_tags_lowercase", "tags", type_="check")
def downgrade() -> None:
# Existing mixed-case tags cannot satisfy the old constraint. Lowercase them
# before restoring it, merging duplicate vocabulary/link rows that collide.
bind = op.get_bind()
tag_names = [row[0] for row in bind.execute(sa.text("SELECT name FROM tags"))]
existing_names = set(tag_names)
lowercase_names = sorted({name.lower() for name in tag_names})
missing_lowercase_rows = [
{"name": name} for name in lowercase_names if name not in existing_names
]
if missing_lowercase_rows:
bind.execute(sa.text("INSERT INTO tags(name) VALUES (:name)"), missing_lowercase_rows)
link_rows = bind.execute(
sa.text(
"SELECT asset_reference_id, tag_name, origin, added_at "
"FROM asset_reference_tags "
"ORDER BY asset_reference_id, tag_name"
)
).mappings()
deduped_links = {}
for row in link_rows:
key = (row["asset_reference_id"], row["tag_name"].lower())
deduped_links.setdefault(
key,
{
"asset_reference_id": row["asset_reference_id"],
"tag_name": row["tag_name"].lower(),
"origin": row["origin"],
"added_at": row["added_at"],
},
)
op.execute("DELETE FROM asset_reference_tags")
if deduped_links:
bind.execute(
sa.text(
"INSERT INTO asset_reference_tags "
"(asset_reference_id, tag_name, origin, added_at) "
"VALUES (:asset_reference_id, :tag_name, :origin, :added_at)"
),
list(deduped_links.values()),
)
op.execute("DELETE FROM tags WHERE name != lower(name)")
if bind.dialect.name == "sqlite":
op.execute("PRAGMA foreign_keys=OFF")
try:
op.execute(
"CREATE TABLE tags_new ("
"name VARCHAR(512) NOT NULL, "
"CONSTRAINT pk_tags PRIMARY KEY (name), "
"CONSTRAINT ck_tags_lowercase CHECK (name = lower(name))"
")"
)
op.execute("INSERT INTO tags_new(name) SELECT name FROM tags")
op.execute("DROP TABLE tags")
op.execute("ALTER TABLE tags_new RENAME TO tags")
finally:
op.execute("PRAGMA foreign_keys=ON")
return
op.create_check_constraint(
"ck_tags_ck_tags_lowercase", "tags", "name = lower(name)"
)

View File

@ -10,6 +10,7 @@ from typing import Any
from aiohttp import web
from pydantic import ValidationError
import folder_paths
from app import user_manager
from app.assets.api import schemas_in, schemas_out
from app.assets.services import schemas
@ -39,7 +40,6 @@ from app.assets.services import (
upload_from_temp_path,
)
from app.assets.services.cursor import InvalidCursorError
from app.assets.services.path_utils import compute_asset_response_paths
from app.assets.services.tagging import list_tag_histogram
ROUTES = web.RouteTableDef()
@ -161,18 +161,11 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
preview_url = None
else:
preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata)
if result.ref.file_path:
paths = compute_asset_response_paths(result.ref.file_path)
file_path, display_name = paths if paths else (None, None)
else:
file_path, display_name = None, None
asset_content_hash = result.asset.hash if result.asset else None
return schemas_out.Asset(
id=result.ref.id,
name=result.ref.name,
hash=asset_content_hash,
file_path=file_path,
display_name=display_name,
asset_hash=asset_content_hash,
size=int(result.asset.size_bytes) if result.asset else None,
mime_type=result.asset.mime_type if result.asset else None,
@ -423,6 +416,17 @@ async def upload_asset(request: web.Request) -> web.Response:
400, "INVALID_BODY", f"Validation failed: {ve.json()}"
)
if spec.tags and spec.tags[0] == "models":
if (
len(spec.tags) < 2
or spec.tags[1] not in folder_paths.folder_names_and_paths
):
delete_temp_file_if_exists(parsed.tmp_path)
category = spec.tags[1] if len(spec.tags) >= 2 else ""
return _build_error_response(
400, "INVALID_BODY", f"unknown models category '{category}'"
)
try:
# Fast path: hash exists, create AssetReference without writing anything
if spec.hash and parsed.provided_hash_exists is True:
@ -466,7 +470,7 @@ async def upload_asset(request: web.Request) -> web.Response:
return _build_error_response(400, e.code, str(e))
except ValueError as e:
delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(400, "INVALID_BODY", str(e))
return _build_error_response(400, "BAD_REQUEST", str(e))
except HashMismatchError as e:
delete_temp_file_if_exists(parsed.tmp_path)
return _build_error_response(400, "HASH_MISMATCH", str(e))

View File

@ -140,7 +140,7 @@ class CreateFromHashBody(BaseModel):
if v is None:
return []
if isinstance(v, list):
out = [str(t).strip() for t in v if str(t).strip()]
out = [str(t).strip().lower() for t in v if str(t).strip()]
seen = set()
dedup = []
for t in out:
@ -149,7 +149,7 @@ class CreateFromHashBody(BaseModel):
dedup.append(t)
return dedup
if isinstance(v, str):
return list(dict.fromkeys(t.strip() for t in v.split(",") if t.strip()))
return [t.strip().lower() for t in v.split(",") if t.strip()]
return []
@ -206,7 +206,7 @@ class TagsListQuery(BaseModel):
if v is None:
return v
v = v.strip()
return v or None
return v.lower() or None
class TagsAdd(BaseModel):
@ -220,7 +220,7 @@ class TagsAdd(BaseModel):
for t in v:
if not isinstance(t, str):
raise TypeError("tags must be strings")
tnorm = t.strip()
tnorm = t.strip().lower()
if tnorm:
out.append(tnorm)
seen = set()
@ -239,8 +239,8 @@ class TagsRemove(TagsAdd):
class UploadAssetSpec(BaseModel):
"""Upload Asset operation.
- tags: labels plus one destination role ('models'|'input'|'output') for new bytes;
if role == 'models', exactly one model_type:<folder_name> tag is required
- tags: optional list; if provided, first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category
- name: display name
- user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' for validation / fast-path
@ -309,7 +309,7 @@ class UploadAssetSpec(BaseModel):
norm = []
seen = set()
for t in items:
tnorm = str(t).strip()
tnorm = str(t).strip().lower()
if tnorm and tnorm not in seen:
seen.add(tnorm)
norm.append(tnorm)
@ -335,4 +335,14 @@ class UploadAssetSpec(BaseModel):
@model_validator(mode="after")
def _validate_order(self):
if not self.tags:
raise ValueError("at least one tag is required for uploads")
root = self.tags[0]
if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError(
"models uploads require a category tag as the second tag"
)
return self

View File

@ -9,20 +9,8 @@ class Asset(BaseModel):
``id`` here is the AssetReference id, not the content-addressed Asset id."""
id: str
name: str = Field(
...,
deprecated=True,
description="Reference label, often caller-provided or derived from the filename. Deprecated for storage path/display semantics; use `file_path` and `display_name` when present.",
)
name: str
hash: str | None = None
file_path: str | None = Field(
default=None,
description="Runtime storage locator for filesystem-backed assets, using Comfy storage namespaces such as `input/`, `output/`, `temp/`, or `models/`. Not an absolute filesystem path, unique identity, or model loader path.",
)
display_name: str | None = Field(
default=None,
description="Human-facing label derived from `file_path`, usually the path below the top-level storage namespace. Not unique.",
)
asset_hash: str | None = None
size: int | None = None
mime_type: str | None = None

View File

@ -140,6 +140,7 @@ async def parse_multipart_upload(
provided_mime_type = ((await field.text()) or "").strip() or None
elif fname == "preview_id":
provided_preview_id = ((await field.text()) or "").strip() or None
if not file_present and not (provided_hash and provided_hash_exists):
raise UploadError(
400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'."

View File

@ -265,8 +265,6 @@ def list_tags_with_usage(
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
prefix_filter = prefix.strip() if prefix else ""
counts_sq = (
select(
AssetReferenceTag.tag_name.label("tag_name"),
@ -295,8 +293,9 @@ def list_tags_with_usage(
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix_filter:
q = q.where(func.substr(Tag.name, 1, len(prefix_filter)) == prefix_filter)
if prefix:
escaped, esc = escape_sql_like_string(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
@ -307,8 +306,9 @@ def list_tags_with_usage(
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix_filter:
total_q = total_q.where(func.substr(Tag.name, 1, len(prefix_filter)) == prefix_filter)
if prefix:
escaped, esc = escape_sql_like_string(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
visible_tags_sq = (
select(AssetReferenceTag.tag_name)

View File

@ -41,10 +41,10 @@ def get_utc_now() -> datetime:
def normalize_tags(tags: list[str] | None) -> list[str]:
"""
Normalize a list of tags by:
- Stripping whitespace.
- Removing exact duplicates while preserving order and case.
- Stripping whitespace and converting to lowercase.
- Removing duplicates.
"""
return list(dict.fromkeys(t.strip() for t in (tags or []) if (t or "").strip()))
return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip()))
def validate_blake3_hash(s: str) -> str:

View File

@ -35,7 +35,6 @@ from app.assets.services.image_dimensions import extract_image_dimensions
from app.assets.services.path_utils import (
compute_relative_filename,
get_name_and_tags_from_asset_path,
get_path_derived_tags_from_path,
resolve_destination_from_tags,
validate_path_within_base,
)
@ -102,32 +101,17 @@ def _ingest_file_from_path(
if preview_id and ref.preview_id != preview_id:
ref.preview_id = preview_id
try:
backend_tags = get_path_derived_tags_from_path(locator)
except ValueError:
backend_tags = []
caller_tags = normalize_tags(tags)
backend_tags = normalize_tags(backend_tags)
all_tags = normalize_tags([*caller_tags, *backend_tags])
if all_tags:
norm = normalize_tags(list(tags))
if norm:
if require_existing_tags:
validate_tags_exist(session, all_tags)
if backend_tags:
add_tags_to_reference(
session,
reference_id=reference_id,
tags=backend_tags,
origin="automatic",
create_if_missing=not require_existing_tags,
)
if caller_tags:
add_tags_to_reference(
session,
reference_id=reference_id,
tags=caller_tags,
origin=tag_origin,
create_if_missing=not require_existing_tags,
)
validate_tags_exist(session, norm)
add_tags_to_reference(
session,
reference_id=reference_id,
tags=norm,
origin=tag_origin,
create_if_missing=not require_existing_tags,
)
_update_metadata_with_filename(
session,
@ -490,10 +474,6 @@ def upload_from_temp_path(
existing = get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
# Once content is already known, duplicate byte uploads are treated as
# reference-only creation. Request tags are labels only here: do not
# require upload destination tags, do not move bytes, and do not
# synthesize path-derived classification or uploaded provenance.
with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
@ -555,7 +535,7 @@ def upload_from_temp_path(
owner_id=owner_id,
preview_id=preview_id,
user_metadata=user_metadata or {},
tags=[*(tags or []), "uploaded"],
tags=tags,
tag_origin="manual",
require_existing_tags=False,
)
@ -589,19 +569,15 @@ def register_file_in_place(
) -> UploadResult:
"""Register an already-saved file in the asset database without moving it.
This helper is used by upload paths that have already written bytes before
registering the file, so it records the same ``uploaded`` tag as the
multipart byte-upload path.
Tags are derived from trusted filesystem classification and merged with any
caller-provided tags, matching the behavior of the scanner.
Tags are derived from the filesystem path (root category + subfolder names),
merged with any caller-provided tags, matching the behavior of the scanner.
If the path is not under a known root, only the caller-provided tags are used.
"""
try:
_, path_tags = get_name_and_tags_from_asset_path(abs_path)
except ValueError:
path_tags = []
merged_tags = normalize_tags([*path_tags, *tags, "uploaded"])
merged_tags = normalize_tags([*path_tags, *tags])
try:
digest, _ = hashing.compute_blake3_hash(abs_path)

View File

@ -3,10 +3,10 @@ from pathlib import Path
from typing import Literal
import folder_paths
from app.assets.helpers import normalize_tags
_NON_MODEL_FOLDER_NAMES = frozenset({"configs", "custom_nodes"})
_KNOWN_SUBFOLDER_TAGS = frozenset({"3d", "pasted", "painter", "threed", "webcam"})
_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"})
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
@ -14,7 +14,7 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
Includes every category registered in folder_names_and_paths,
regardless of whether its paths are under the main models_dir,
but excludes non-model entries like configs and custom_nodes.
but excludes non-model entries like custom_nodes.
"""
targets: list[tuple[str, list[str]]] = []
for name, values in folder_paths.folder_names_and_paths.items():
@ -27,37 +27,35 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps upload routing tags -> (base_dir, subdirs_for_fs).
The request tags are only used to choose the write destination. Extra tags
remain labels; they do not become path components or trusted classification.
"""
destination_roles = [t for t in tags if t in {"input", "models", "output"}]
if len(destination_roles) != 1:
raise ValueError("uploads require exactly one destination role: input, models, or output")
root = destination_roles[0]
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
if not tags:
raise ValueError("tags must not be empty")
root = tags[0].lower()
if root == "models":
model_type_tags = [t for t in tags if t.startswith("model_type:")]
if len(model_type_tags) != 1:
raise ValueError("models uploads require exactly one model_type:<folder_name> tag")
folder_name = model_type_tags[0].split(":", 1)[1]
if not folder_name:
raise ValueError("models uploads require exactly one model_type:<folder_name> tag")
model_folder_paths = dict(get_comfy_models_folders())
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
try:
bases = model_folder_paths[folder_name]
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{folder_name}'")
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{folder_name}'")
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
elif root == "input":
base_dir = os.path.abspath(folder_paths.get_input_directory())
else:
raw_subdirs = tags[1:]
elif root == "output":
base_dir = os.path.abspath(folder_paths.get_output_directory())
raw_subdirs = tags[1:]
else:
raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'")
_sep_chars = frozenset(("/", "\\", os.sep))
for i in raw_subdirs:
if i in (".", "..") or _sep_chars & set(i):
raise ValueError("invalid path component in tags")
return base_dir, []
return base_dir, raw_subdirs if raw_subdirs else []
def validate_path_within_base(candidate: str, base: str) -> None:
@ -67,62 +65,6 @@ def validate_path_within_base(candidate: str, base: str) -> None:
raise ValueError("destination escapes base directory")
def _compute_relative_path(child: str, parent: str) -> str:
rel = os.path.relpath(os.path.abspath(child), os.path.abspath(parent))
if rel == ".":
return ""
return rel.replace(os.sep, "/")
def _is_relative_to(child: str, parent: str) -> bool:
return Path(os.path.abspath(child)).is_relative_to(os.path.abspath(parent))
def compute_asset_response_paths(file_path: str) -> tuple[str, str | None] | None:
"""Return public (file_path, display_name) response fields for a file path.
These fields are storage locators, not model-loader namespaces. Registered
model-folder membership is represented by backend tags such as
``model_type:<folder_name>``; response paths only use known storage roots.
"""
fp_abs = os.path.abspath(file_path)
candidates: list[tuple[int, int, str, str]] = []
for order, (namespace, base) in enumerate(
(
("input", folder_paths.get_input_directory()),
("output", folder_paths.get_output_directory()),
("temp", folder_paths.get_temp_directory()),
("models", getattr(folder_paths, "models_dir", "")),
)
):
if not base:
continue
base_abs = os.path.abspath(base)
if _is_relative_to(fp_abs, base_abs):
candidates.append((len(base_abs), -order, namespace, base_abs))
if not candidates:
return None
_base_len, _order, namespace, base = max(candidates)
rel = _compute_relative_path(fp_abs, base)
public_path = f"{namespace}/{rel}" if rel else namespace
return public_path, rel or None
def compute_display_name(file_path: str) -> str | None:
"""Return the asset's `display_name`, or None for unknown paths."""
result = compute_asset_response_paths(file_path)
return result[1] if result else None
def compute_file_path(file_path: str) -> str | None:
"""Return the asset's logical storage `file_path`, or None for unknown paths."""
result = compute_asset_response_paths(file_path)
return result[0] if result else None
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the model's path relative to the last well-known folder (the model category),
@ -130,9 +72,6 @@ def compute_relative_filename(file_path: str) -> str | None:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
This is legacy metadata/view filename logic, not the public Asset response
`display_name`. Response fields should use compute_asset_response_paths().
For non-model paths, returns None.
"""
try:
@ -177,10 +116,9 @@ def get_asset_category_and_relative_path(
def _compute_relative(child: str, parent: str) -> str:
# Normalize relative path, stripping any leading ".." components
# by anchoring to root (os.sep) then computing relpath back from it.
rel = os.path.relpath(
return os.path.relpath(
os.path.join(os.sep, os.path.relpath(child, parent)), os.sep
)
return "" if rel == "." else rel.replace(os.sep, "/")
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
@ -211,99 +149,25 @@ def get_asset_category_and_relative_path(
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
normalized = os.path.relpath(os.path.join(os.sep, combined), os.sep)
return "models", normalized.replace(os.sep, "/")
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(
f"Path is not within input, output, temp, or configured model bases: {file_path}"
)
def get_backend_system_tags_from_path(path: str) -> list[str]:
"""Return trusted backend tags derived from current filesystem facts.
The returned tags are only the backend-generated system tags: ``models``,
``model_type:<folder_name>``, ``input``, ``output``, and ``temp``. Model
type tags are based on registered folder names, not path components.
"""
fp_abs = os.path.abspath(path)
fp_path = Path(fp_abs)
tags: list[str] = []
def _add(tag: str) -> None:
if tag not in tags:
tags.append(tag)
for role, base in (
("input", folder_paths.get_input_directory()),
("output", folder_paths.get_output_directory()),
("temp", folder_paths.get_temp_directory()),
):
if fp_path.is_relative_to(os.path.abspath(base)):
_add(role)
model_types: list[str] = []
for folder_name, bases in get_comfy_models_folders():
for base in bases:
if fp_path.is_relative_to(os.path.abspath(base)):
model_types.append(folder_name)
break
if model_types:
_add("models")
for folder_name in model_types:
_add(f"model_type:{folder_name}")
if not tags:
raise ValueError(
f"Path is not within input, output, temp, or configured model bases: {path}"
)
return tags
def get_known_subfolder_tags(subfolder: str | None) -> list[str]:
"""Return tags for known UI/input subfolder names."""
if subfolder in _KNOWN_SUBFOLDER_TAGS:
return [subfolder]
return []
def get_known_input_subfolder_tags_from_path(path: str) -> list[str]:
"""Return known input-layout tags for files in canonical input subfolders.
These are compatibility tags for current UI-origin input directories such as
``pasted`` and ``webcam``. They are intentionally narrow: only files directly
inside a known top-level input directory receive the matching tag.
"""
fp_abs = os.path.abspath(path)
input_base = os.path.abspath(folder_paths.get_input_directory())
if not Path(fp_abs).is_relative_to(input_base):
return []
rel = os.path.relpath(fp_abs, input_base)
parts = Path(rel).parts
if len(parts) == 2:
return get_known_subfolder_tags(parts[0])
return []
def get_path_derived_tags_from_path(path: str) -> list[str]:
"""Return all backend-derived tags for an asset path."""
tags = get_backend_system_tags_from_path(path)
for tag in get_known_input_subfolder_tags_from_path(path):
if tag not in tags:
tags.append(tag)
return tags
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return (name, tags) derived from a filesystem path.
- name: base filename with extension
- tags: backend-derived tags from root/model classification and known input
subfolder layout conventions
- tags: [root_category] + parent folder names in order
Raises:
ValueError: path does not belong to any known root.
"""
return Path(file_path).name, get_path_derived_tags_from_path(file_path)
root_category, some_path = get_asset_category_and_relative_path(file_path)
p = Path(some_path)
parent_parts = [
part for part in p.parent.parts if part not in (".", "..", p.anchor)
]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))

View File

@ -100,7 +100,6 @@ def _parse_cli_feature_flags() -> dict[str, Any]:
# Default server capabilities
_CORE_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True,
"supports_model_type_tags": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
"node_replacements": True,

View File

@ -1113,6 +1113,32 @@ def full_type_name(klass):
return klass.__qualname__
return module + '.' + klass.__qualname__
def node_not_executable_reason(class_def, class_type):
"""Return a human-readable reason the node cannot be executed, or None if it's fine.
Catches a node whose declared entry point doesn't resolve to a real method
(e.g. a V1 ``FUNCTION = "invert"`` where the method is misspelled, or a V3 node
missing its ``execute`` override). Running this during validation surfaces the
problem before execution starts, instead of after upstream nodes have run.
Only the class is inspected; the node is never instantiated here, so a node's
``__init__`` side effects cannot run (or fail) during validation.
"""
try:
if issubclass(class_def, _ComfyNodeInternal):
# V3: validates that execute()/define_schema() overrides exist.
class_def.VALIDATE_CLASS()
return None
# V1: FUNCTION names the method to call; it must exist on the class.
function_name = getattr(class_def, "FUNCTION", None)
if function_name is None:
return f"'{class_type}' does not define FUNCTION"
if not callable(getattr(class_def, function_name, None)):
return f"'{class_type}' has no method '{function_name}' (declared in FUNCTION)"
return None
except Exception as ex:
return str(ex)
async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]):
outputs = set()
for x in prompt:
@ -1148,6 +1174,35 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[
}
return (False, error, [], {})
# Make sure the node is actually executable (its FUNCTION/execute entry
# point resolves to a real method) before we touch any schema-derived
# attributes below or start execution. Catches code typos up front and
# attributes the error to the offending node.
not_executable = node_not_executable_reason(class_, class_type)
if not_executable is not None:
node_title = prompt[x].get('_meta', {}).get('title', class_type)
error = {
"type": "invalid_node_definition",
"message": "Node is not executable",
"details": f"{not_executable} (Node ID '#{x}')",
"extra_info": {
"node_id": x,
"class_type": class_type,
"node_title": node_title,
}
}
node_errors = {x: {
"errors": [{
"type": "invalid_node_definition",
"message": "Node is not executable",
"details": not_executable,
"extra_info": {},
}],
"dependent_outputs": [],
"class_type": class_type,
}}
return (False, error, [], node_errors)
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
if partial_execution_list is None or x in partial_execution_list:
outputs.add(x)

View File

@ -7,18 +7,18 @@ components:
description: Timestamp when the asset was created
format: date-time
type: string
display_name:
description: Display name of the asset. Mirrors name for backwards compatibility.
nullable: true
type: string
file_path:
description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors")
nullable: true
type: string
hash:
description: Blake3 hash of the asset content.
pattern: ^blake3:[a-f0-9]{64}$
type: string
file_path:
description: Runtime storage locator for filesystem-backed assets, using Comfy storage namespaces such as `input/`, `output/`, `temp/`, or `models/`. Not an absolute filesystem path, unique identity, or model loader path.
nullable: true
type: string
display_name:
description: Human-facing label derived from `file_path`, usually the path below the top-level storage namespace. Not unique.
nullable: true
type: string
id:
description: Unique identifier for the asset
format: uuid
@ -144,6 +144,14 @@ components:
AssetUpdated:
description: Response returned when an existing asset is successfully updated.
properties:
display_name:
description: Display name of the asset. Mirrors name for backwards compatibility.
nullable: true
type: string
file_path:
description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors")
nullable: true
type: string
hash:
description: Blake3 hash of the asset content.
pattern: ^blake3:[a-f0-9]{64}$
@ -1636,7 +1644,7 @@ paths:
format: uuid
type: string
tags:
description: JSON-encoded array of tag strings. For new byte uploads, include exactly one destination role (`input`, `output`, or `models`); `models` uploads also require exactly one `model_type:<folder_name>` tag. Extra tags are stored as labels and do not create path components.
description: JSON-encoded array of freeform tag strings, e.g. '["models","checkpoint"]'. Common types include "models", "input", "output", and "temp", but any tag can be used in any order.
type: string
user_metadata:
description: Custom JSON metadata as a string
@ -1821,7 +1829,7 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/Asset'
$ref: '#/components/schemas/AssetUpdated'
description: Asset updated successfully
"400":
content:
@ -2462,9 +2470,6 @@ paths:
supports_preview_metadata:
description: Whether the server supports preview metadata
type: boolean
supports_model_type_tags:
description: Whether the server supports namespaced model type asset tags
type: boolean
type: object
description: Success
headers:

View File

@ -46,7 +46,6 @@ from comfy_api.internal import _ComfyNodeInternal
from app.assets.seeder import asset_seeder
from app.assets.api.routes import register_assets_routes
from app.assets.services.ingest import register_file_in_place
from app.assets.services.path_utils import get_known_subfolder_tags
from app.assets.services.asset_management import resolve_hash_to_path
from app.user_manager import UserManager
@ -441,9 +440,7 @@ class PromptServer():
if args.enable_assets:
try:
tag = image_upload_type if image_upload_type in ("input", "output") else "input"
tags = [tag]
tags.extend(get_known_subfolder_tags(subfolder))
result = register_file_in_place(abs_path=filepath, name=filename, tags=tags)
result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag])
resp["asset"] = {
"id": result.ref.id,
"name": result.ref.name,

View File

@ -8,7 +8,6 @@ upgrade/downgrade for 0003+.
"""
import os
import sqlite3
import pytest
from alembic import command
@ -31,12 +30,6 @@ def _make_config(db_path: str) -> Config:
return cfg
def _sqlite_path(cfg: Config) -> str:
url = cfg.get_main_option("sqlalchemy.url")
assert url is not None and url.startswith("sqlite:///")
return url.removeprefix("sqlite:///")
@pytest.fixture
def migration_db(tmp_path):
"""Yield an alembic Config pre-upgraded to the baseline revision."""
@ -62,26 +55,3 @@ def test_upgrade_downgrade_cycle(migration_db):
command.upgrade(migration_db, "head")
command.downgrade(migration_db, _BASELINE)
command.upgrade(migration_db, "head")
def test_case_sensitive_tags_downgrade_normalizes_existing_tags(migration_db):
"""Downgrading 0005 folds mixed-case tag vocabulary before restoring CHECK."""
command.upgrade(migration_db, "0005_allow_case_sensitive_tags")
db_path = _sqlite_path(migration_db)
with sqlite3.connect(db_path) as conn:
conn.execute("INSERT INTO tags(name) VALUES (?)", ("NewTag",))
conn.execute("INSERT INTO tags(name) VALUES (?)", ("newtag",))
conn.execute("INSERT INTO tags(name) VALUES (?)", ("model_type:LLM",))
command.downgrade(migration_db, "0004_drop_tag_type")
with sqlite3.connect(db_path) as conn:
tags = {row[0] for row in conn.execute("SELECT name FROM tags")}
assert "newtag" in tags
assert "model_type:llm" in tags
assert "NewTag" not in tags
assert "model_type:LLM" not in tags
with pytest.raises(sqlite3.IntegrityError):
conn.execute("INSERT INTO tags(name) VALUES (?)", ("Upper",))

View File

@ -234,7 +234,7 @@ def seeded_asset(request: pytest.FixtureRequest, http: requests.Session, api_bas
p = getattr(request, "param", {}) or {}
tags: Optional[list[str]] = p.get("tags")
if tags is None:
tags = ["models", "model_type:checkpoints", "unit-tests", "alpha"]
tags = ["models", "checkpoints", "unit-tests", "alpha"]
meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None}
# Unique content per test so the seed always creates a fresh asset (201).
# Delete is now always a soft delete, so content from a prior test survives

View File

@ -133,66 +133,6 @@ class TestListReferencesPage:
assert total == 1
assert refs[0].name == "tagged"
def test_include_tags_filter_ands_persisted_model_tags(self, session: Session):
asset = _make_asset(session, "hash-model-tags")
checkpoint = _make_reference(session, asset, name="checkpoint")
lora = _make_reference(session, asset, name="lora")
input_ref = _make_reference(session, asset, name="input")
ensure_tags_exist(
session,
["models", "model_type:checkpoints", "model_type:loras", "unit-tests"],
)
add_tags_to_reference(
session,
reference_id=checkpoint.id,
tags=["models", "model_type:checkpoints", "unit-tests"],
origin="automatic",
)
add_tags_to_reference(
session,
reference_id=lora.id,
tags=["models", "model_type:loras", "unit-tests"],
origin="automatic",
)
add_tags_to_reference(
session,
reference_id=input_ref.id,
tags=["unit-tests"],
)
session.commit()
refs, _, total = list_references_page(
session,
include_tags=["models", "model_type:checkpoints", "unit-tests"],
)
assert total == 1
assert refs[0].id == checkpoint.id
def test_include_tags_filter_preserves_model_type_case(self, session: Session):
asset = _make_asset(session, "hash-model-case")
ref = _make_reference(session, asset, name="llm")
ensure_tags_exist(session, ["models", "model_type:LLM"])
add_tags_to_reference(
session,
reference_id=ref.id,
tags=["models", "model_type:LLM"],
origin="automatic",
)
session.commit()
refs, _, total = list_references_page(
session, include_tags=["models", "model_type:LLM"]
)
refs_lower, _, total_lower = list_references_page(
session, include_tags=["models", "model_type:llm"]
)
assert total == 1
assert refs[0].id == ref.id
assert total_lower == 0
assert refs_lower == []
def test_exclude_tags_filter(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, name="keep")

View File

@ -58,7 +58,7 @@ class TestEnsureTagsExist:
session.commit()
tags = session.query(Tag).all()
assert {t.name for t in tags} == {"ALPHA", "Beta", "alpha"}
assert {t.name for t in tags} == {"alpha", "beta"}
def test_empty_list_is_noop(self, session: Session):
ensure_tags_exist(session, [])
@ -258,16 +258,6 @@ class TestListTagsWithUsage:
tag_names = {name for name, _ in rows}
assert tag_names == {"alpha", "alphabet"}
def test_prefix_filter_is_case_sensitive(self, session: Session):
ensure_tags_exist(session, ["model_type:LLM", "model_type:llm"])
session.commit()
rows, total = list_tags_with_usage(session, prefix="model_type:L")
tag_names = {name for name, _ in rows}
assert tag_names == {"model_type:LLM"}
assert total == 1
def test_order_by_name(self, session: Session):
ensure_tags_exist(session, ["zebra", "alpha", "middle"])
session.commit()

View File

@ -94,47 +94,6 @@ class TestIngestFileFromPath:
ref_tags = get_reference_tags(session, reference_id=result.reference_id)
assert set(ref_tags) == {"models", "checkpoints"}
def test_path_derived_tags_use_automatic_origin(
self, mock_create_session, temp_dir: Path, session: Session
):
input_dir = temp_dir / "input"
output_dir = temp_dir / "output"
temp_root = temp_dir / "temp"
for directory in (input_dir, output_dir, temp_root):
directory.mkdir()
file_path = input_dir / "pasted" / "tagged.png"
file_path.parent.mkdir()
file_path.write_bytes(b"data")
with (
patch("app.assets.services.path_utils.folder_paths") as mock_fp,
patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[],
),
):
mock_fp.get_input_directory.return_value = str(input_dir)
mock_fp.get_output_directory.return_value = str(output_dir)
mock_fp.get_temp_directory.return_value = str(temp_root)
result = _ingest_file_from_path(
abs_path=str(file_path),
asset_hash="blake3:pathorigin",
size_bytes=4,
mtime_ns=1234567890000000000,
info_name="Tagged Asset",
tags=["input", "manual-label"],
)
assert result.reference_id is not None
links = session.query(AssetReferenceTag).filter_by(
asset_reference_id=result.reference_id
)
origin_by_tag = {link.tag_name: link.origin for link in links}
assert origin_by_tag["input"] == "automatic"
assert origin_by_tag["pasted"] == "automatic"
assert origin_by_tag["manual-label"] == "manual"
def test_idempotent_upsert(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "dup.bin"
file_path.write_bytes(b"content")

View File

@ -6,15 +6,7 @@ from unittest.mock import patch
import pytest
from app.assets.services.path_utils import (
compute_display_name,
compute_file_path,
get_asset_category_and_relative_path,
get_known_input_subfolder_tags_from_path,
get_known_subfolder_tags,
get_name_and_tags_from_asset_path,
resolve_destination_from_tags,
)
from app.assets.services.path_utils import get_asset_category_and_relative_path
@pytest.fixture
@ -25,8 +17,7 @@ def fake_dirs():
input_dir = root_path / "input"
output_dir = root_path / "output"
temp_dir = root_path / "temp"
models_root = root_path / "models"
models_dir = models_root / "checkpoints"
models_dir = root_path / "models" / "checkpoints"
for d in (input_dir, output_dir, temp_dir, models_dir):
d.mkdir(parents=True)
@ -34,7 +25,6 @@ def fake_dirs():
mock_fp.get_input_directory.return_value = str(input_dir)
mock_fp.get_output_directory.return_value = str(output_dir)
mock_fp.get_temp_directory.return_value = str(temp_dir)
mock_fp.models_dir = str(models_root)
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
@ -44,7 +34,6 @@ def fake_dirs():
"input": input_dir,
"output": output_dir,
"temp": temp_dir,
"models_root": models_root,
"models": models_dir,
}
@ -87,376 +76,6 @@ class TestGetAssetCategoryAndRelativePath:
cat, rel = get_asset_category_and_relative_path(str(f))
assert cat == "models"
def test_model_path_tags_include_registered_model_type_only(self, fake_dirs):
f = fake_dirs["models"] / "subdir" / "model.safetensors"
f.parent.mkdir()
f.touch()
_name, tags = get_name_and_tags_from_asset_path(str(f))
assert "models" in tags
assert "model_type:checkpoints" in tags
assert "checkpoints" not in tags
assert "subdir" not in tags
def test_model_type_preserves_registered_folder_case(self, fake_dirs):
llm_dir = fake_dirs["models"].parent / "LLM"
llm_dir.mkdir()
f = llm_dir / "model.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("LLM", [str(llm_dir)])],
):
_name, tags = get_name_and_tags_from_asset_path(str(f))
assert "models" in tags
assert "model_type:LLM" in tags
assert "model_type:llm" not in tags
def test_path_components_do_not_create_model_type_tags(self, fake_dirs):
f = fake_dirs["models"] / "loras" / "model.safetensors"
f.parent.mkdir()
f.touch()
_name, tags = get_name_and_tags_from_asset_path(str(f))
assert "models" in tags
assert "model_type:checkpoints" in tags
assert "loras" not in tags
assert "model_type:loras" not in tags
def test_shared_root_returns_all_matching_model_type_tags(self, fake_dirs):
shared_root = fake_dirs["models"].parent / "shared"
shared_root.mkdir()
f = shared_root / "foo.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[
("checkpoints", [str(shared_root)]),
("loras", [str(shared_root)]),
],
):
_name, tags = get_name_and_tags_from_asset_path(str(f))
assert "models" in tags
assert "model_type:checkpoints" in tags
assert "model_type:loras" in tags
def test_output_backed_registered_folder_gets_model_and_output_tags(self, fake_dirs):
output_checkpoints_dir = fake_dirs["output"] / "checkpoints"
output_checkpoints_dir.mkdir()
f = output_checkpoints_dir / "saved.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("checkpoints", [str(output_checkpoints_dir)])],
):
_name, tags = get_name_and_tags_from_asset_path(str(f))
assert "models" in tags
assert "model_type:checkpoints" in tags
assert "output" in tags
def test_temp_path_tags_include_temp_not_output_or_preview(self, fake_dirs):
f = fake_dirs["temp"] / "preview.png"
f.touch()
_name, tags = get_name_and_tags_from_asset_path(str(f))
assert "temp" in tags
assert "output" not in tags
assert "preview:true" not in tags
def test_known_subfolder_tags_are_centralized(self):
assert get_known_subfolder_tags("pasted") == ["pasted"]
assert get_known_subfolder_tags("arbitrary") == []
def test_known_input_subfolder_tags_are_path_derived_for_direct_children(self, fake_dirs):
f = fake_dirs["input"] / "pasted" / "image.png"
f.parent.mkdir()
f.touch()
assert get_known_input_subfolder_tags_from_path(str(f)) == ["pasted"]
_name, tags = get_name_and_tags_from_asset_path(str(f))
assert "input" in tags
assert "pasted" in tags
def test_known_input_subfolder_tags_do_not_apply_to_nested_or_other_roots(self, fake_dirs):
nested = fake_dirs["input"] / "pasted" / "session" / "image.png"
output = fake_dirs["output"] / "pasted" / "image.png"
for path in (nested, output):
path.parent.mkdir(parents=True)
path.touch()
assert get_known_input_subfolder_tags_from_path(str(nested)) == []
assert get_known_input_subfolder_tags_from_path(str(output)) == []
def test_unknown_path_raises(self, fake_dirs):
with pytest.raises(ValueError, match="not within"):
get_asset_category_and_relative_path("/some/random/path.png")
class TestResponseStoragePaths:
def test_input_file_path_and_display_name_include_subfolder(self, fake_dirs):
sub = fake_dirs["input"] / "some" / "folder"
sub.mkdir(parents=True)
f = sub / "image.png"
f.touch()
assert compute_file_path(str(f)) == "input/some/folder/image.png"
assert compute_display_name(str(f)) == "some/folder/image.png"
def test_output_file_path_and_display_name_include_subfolder(self, fake_dirs):
sub = fake_dirs["output"] / "renders"
sub.mkdir()
f = sub / "ComfyUI_00001_.png"
f.touch()
assert compute_file_path(str(f)) == "output/renders/ComfyUI_00001_.png"
assert compute_display_name(str(f)) == "renders/ComfyUI_00001_.png"
def test_temp_file_path_and_display_name(self, fake_dirs):
f = fake_dirs["temp"] / "preview.png"
f.touch()
assert compute_file_path(str(f)) == "temp/preview.png"
assert compute_display_name(str(f)) == "preview.png"
def test_exact_storage_root_has_no_display_name(self, fake_dirs):
assert compute_file_path(str(fake_dirs["input"])) == "input"
assert compute_display_name(str(fake_dirs["input"])) is None
def test_longest_matching_builtin_root_wins(self, fake_dirs, tmp_path: Path):
nested_output = fake_dirs["input"] / "nested-output"
nested_output.mkdir()
f = nested_output / "image.png"
f.touch()
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
mock_fp.get_input_directory.return_value = str(fake_dirs["input"])
mock_fp.get_output_directory.return_value = str(nested_output)
mock_fp.get_temp_directory.return_value = str(tmp_path / "temp")
mock_fp.models_dir = str(fake_dirs["models_root"])
assert compute_file_path(str(f)) == "output/image.png"
assert compute_display_name(str(f)) == "image.png"
def test_model_file_path_is_relative_to_physical_models_root(self, fake_dirs):
sub = fake_dirs["models"] / "flux"
sub.mkdir()
f = sub / "model.safetensors"
f.touch()
assert compute_file_path(str(f)) == "models/checkpoints/flux/model.safetensors"
assert compute_display_name(str(f)) == "checkpoints/flux/model.safetensors"
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "model.safetensors"
assert "models" in tags
assert "model_type:checkpoints" in tags
assert "checkpoints" not in tags
assert "flux" not in tags
@pytest.mark.parametrize(
"folder_name",
["checkpoints", "clip", "vae", "diffusion_models", "loras"],
)
def test_output_model_folder_uses_output_storage_file_path(self, fake_dirs, folder_name):
output_model_dir = fake_dirs["output"] / folder_name
output_model_dir.mkdir(exist_ok=True)
default_model_dir = fake_dirs["models_root"] / folder_name
default_model_dir.mkdir(exist_ok=True)
f = output_model_dir / "saved.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[(folder_name, [str(default_model_dir), str(output_model_dir)])],
):
assert compute_file_path(str(f)) == f"output/{folder_name}/saved.safetensors"
assert compute_display_name(str(f)) == f"{folder_name}/saved.safetensors"
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "saved.safetensors"
assert "output" in tags
assert "models" in tags
assert f"model_type:{folder_name}" in tags
assert folder_name not in tags
def test_output_model_subfolder_uses_output_storage_file_path(self, fake_dirs):
folder_name = "loras"
output_model_dir = fake_dirs["output"] / folder_name
subdir = output_model_dir / "experiments"
subdir.mkdir(parents=True)
f = subdir / "my_lora.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[(folder_name, [str(output_model_dir)])],
):
assert (
compute_file_path(str(f))
== "output/loras/experiments/my_lora.safetensors"
)
assert compute_display_name(str(f)) == "loras/experiments/my_lora.safetensors"
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "my_lora.safetensors"
assert "output" in tags
assert "models" in tags
assert "model_type:loras" in tags
assert "loras" not in tags
assert "experiments" not in tags
def test_external_model_folder_without_provenance_has_no_file_path(self, tmp_path: Path):
external_checkpoints_dir = tmp_path / "external" / "not_named_like_category"
external_checkpoints_dir.mkdir(parents=True)
f = external_checkpoints_dir / "external.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("checkpoints", [str(external_checkpoints_dir)])],
):
assert compute_file_path(str(f)) is None
assert compute_display_name(str(f)) is None
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "external.safetensors"
assert "models" in tags
assert "model_type:checkpoints" in tags
def test_same_relative_model_file_under_multiple_external_roots_has_no_storage_file_path(
self, tmp_path: Path
):
foo_dir = tmp_path / "foo"
bar_dir = tmp_path / "bar"
foo_dir.mkdir()
bar_dir.mkdir()
foo_file = foo_dir / "baz.safetensors"
bar_file = bar_dir / "baz.safetensors"
foo_file.touch()
bar_file.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("checkpoints", [str(foo_dir), str(bar_dir)])],
):
assert compute_file_path(str(foo_file)) is None
assert compute_file_path(str(bar_file)) is None
assert compute_display_name(str(foo_file)) is None
assert compute_display_name(str(bar_file)) is None
def test_output_clip_folder_uses_output_storage_and_text_encoder_tag(self, fake_dirs):
output_clip_dir = fake_dirs["output"] / "clip"
output_clip_dir.mkdir()
f = output_clip_dir / "clip_l.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("text_encoders", [str(output_clip_dir)])],
):
assert compute_file_path(str(f)) == "output/clip/clip_l.safetensors"
assert compute_display_name(str(f)) == "clip/clip_l.safetensors"
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "clip_l.safetensors"
assert "output" in tags
assert "models" in tags
assert "model_type:text_encoders" in tags
assert "clip" not in tags
def test_physical_unet_folder_uses_storage_path_and_diffusion_models_tag(self, fake_dirs):
unet_dir = fake_dirs["models_root"] / "unet"
diffusion_models_dir = fake_dirs["models_root"] / "diffusion_models"
unet_dir.mkdir()
diffusion_models_dir.mkdir()
f = unet_dir / "wan.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("diffusion_models", [str(unet_dir), str(diffusion_models_dir)])],
):
assert compute_file_path(str(f)) == "models/unet/wan.safetensors"
assert compute_display_name(str(f)) == "unet/wan.safetensors"
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "wan.safetensors"
assert "models" in tags
assert "model_type:diffusion_models" in tags
assert "unet" not in tags
def test_unregistered_file_under_physical_models_root_still_has_storage_file_path(self, fake_dirs):
f = fake_dirs["models_root"] / "not_registered" / "orphan.bin"
f.parent.mkdir()
f.touch()
assert compute_file_path(str(f)) == "models/not_registered/orphan.bin"
assert compute_display_name(str(f)) == "not_registered/orphan.bin"
def test_output_checkpoint_folder_without_registration_has_only_output_tag(self, fake_dirs):
f = fake_dirs["output"] / "checkpoints" / "saved.safetensors"
f.parent.mkdir(exist_ok=True)
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[],
):
assert compute_file_path(str(f)) == "output/checkpoints/saved.safetensors"
assert compute_display_name(str(f)) == "checkpoints/saved.safetensors"
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "saved.safetensors"
assert "output" in tags
assert "models" not in tags
assert not any(tag.startswith("model_type:") for tag in tags)
def test_unknown_path_returns_none(self):
assert compute_file_path("/some/random/path.png") is None
assert compute_display_name("/some/random/path.png") is None
class TestResolveDestinationFromTags:
def test_extra_tags_are_not_path_components(self, fake_dirs):
base_dir, subdirs = resolve_destination_from_tags(["input", "unit-tests", "foo"])
assert base_dir == os.path.abspath(fake_dirs["input"])
assert subdirs == []
def test_model_upload_rejects_non_writable_registered_folders(self):
with tempfile.TemporaryDirectory() as root:
root_path = Path(root)
checkpoints_dir = root_path / "models" / "checkpoints"
configs_dir = root_path / "models" / "configs"
custom_nodes_dir = root_path / "custom_nodes"
for path in (checkpoints_dir, configs_dir, custom_nodes_dir):
path.mkdir(parents=True)
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
mock_fp.folder_names_and_paths = {
"checkpoints": ([str(checkpoints_dir)], set()),
"configs": ([str(configs_dir)], set()),
"custom_nodes": ([str(custom_nodes_dir)], set()),
}
base_dir, subdirs = resolve_destination_from_tags(
["models", "model_type:checkpoints"]
)
assert base_dir == os.path.abspath(checkpoints_dir)
assert subdirs == []
for folder_name in ("configs", "custom_nodes"):
with pytest.raises(ValueError, match="unknown model category"):
resolve_destination_from_tags(
["models", f"model_type:{folder_name}"]
)

View File

@ -19,8 +19,7 @@ def test_seed_asset_removed_when_file_is_deleted(
"""Asset without hash (seed) whose file disappears:
after triggering sync_seed_assets, Asset + AssetInfo disappear.
"""
# Create a file directly under input/unit-tests/<case>. Backend tags only
# classify the root; nested path components are not exposed as tags.
# Create a file directly under input/unit-tests/<case> so tags include "unit-tests"
case_dir = comfy_tmp_base_dir / root / "unit-tests" / "syncseed"
case_dir.mkdir(parents=True, exist_ok=True)
name = f"seed_{uuid.uuid4().hex[:8]}.bin"
@ -33,7 +32,7 @@ def test_seed_asset_removed_when_file_is_deleted(
# Verify it is visible via API and carries no hash (seed)
r1 = http.get(
api_base + "/api/assets",
params={"include_tags": root, "name_contains": name},
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
timeout=120,
)
body1 = r1.json()
@ -55,7 +54,7 @@ def test_seed_asset_removed_when_file_is_deleted(
# It should disappear (AssetInfo and seed Asset gone)
r2 = http.get(
api_base + "/api/assets",
params={"include_tags": root, "name_contains": name},
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
timeout=120,
)
body2 = r2.json()
@ -133,7 +132,7 @@ def test_hashed_asset_two_asset_infos_both_get_missing(
second_id = b2["id"]
# Remove the single underlying file
p = comfy_tmp_base_dir / "input" / get_asset_filename(created["asset_hash"], ".png")
p = comfy_tmp_base_dir / "input" / "unit-tests" / "multiinfo" / get_asset_filename(b2["asset_hash"], ".png")
assert p.exists()
p.unlink()
@ -251,7 +250,8 @@ def test_missing_tag_clears_on_fastpass_when_mtime_and_size_match(
a = asset_factory(name, [root, "unit-tests", scope], {}, data)
aid = a["id"]
p = comfy_tmp_base_dir / root / get_asset_filename(a["asset_hash"], ".bin")
base = comfy_tmp_base_dir / root / "unit-tests" / scope
p = base / get_asset_filename(a["asset_hash"], ".bin")
st0 = p.stat()
orig_mtime_ns = getattr(st0, "st_mtime_ns", int(st0.st_mtime * 1_000_000_000))

View File

@ -290,7 +290,7 @@ def test_metadata_filename_is_set_for_seed_asset_without_hash(
r1 = http.get(
api_base + "/api/assets",
params={"include_tags": root, "name_contains": name},
params={"include_tags": f"unit-tests,{scope}", "name_contains": name},
timeout=120,
)
body = r1.json()

View File

@ -95,7 +95,7 @@ def test_download_chooses_existing_state_and_updates_access_time(
assert t1 > t0
@pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "model_type:checkpoints"]}], indirect=True)
@pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "checkpoints"]}], indirect=True)
def test_download_missing_file_returns_404(
http: requests.Session, api_base: str, comfy_tmp_base_dir: Path, seeded_asset: dict
):

View File

@ -13,7 +13,7 @@ def _seed(asset_factory, make_asset_bytes, count: int, tag: str) -> list[str]:
for n in names:
asset_factory(
n,
["models", "model_type:checkpoints", "unit-tests", tag],
["models", "checkpoints", "unit-tests", tag],
{},
make_asset_bytes(n, size=2048),
)
@ -208,7 +208,7 @@ def test_cursor_walks_for_non_name_sorts(sort_field, http: requests.Session, api
names = []
for i in range(4):
n = f"cursor_{sort_field}_{i:02d}.safetensors"
asset_factory(n, ["models", "model_type:checkpoints", "unit-tests", f"cursor-{sort_field}"], {}, make_asset_bytes(n, size=2048 + i))
asset_factory(n, ["models", "checkpoints", "unit-tests", f"cursor-{sort_field}"], {}, make_asset_bytes(n, size=2048 + i))
names.append(n)
params = {

View File

@ -11,7 +11,7 @@ def test_list_assets_paging_and_sort(http: requests.Session, api_base: str, asse
for n in names:
asset_factory(
n,
["models", "model_type:checkpoints", "unit-tests", "paging"],
["models", "checkpoints", "unit-tests", "paging"],
{"epoch": 1},
make_asset_bytes(n, size=2048),
)
@ -45,8 +45,8 @@ def test_list_assets_paging_and_sort(http: requests.Session, api_base: str, asse
def test_list_assets_include_exclude_and_name_contains(http: requests.Session, api_base: str, asset_factory):
a = asset_factory("inc_a.safetensors", ["models", "model_type:checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024)
b = asset_factory("inc_b.safetensors", ["models", "model_type:checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024)
a = asset_factory("inc_a.safetensors", ["models", "checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024)
b = asset_factory("inc_b.safetensors", ["models", "checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024)
r = http.get(
api_base + "/api/assets",
@ -81,7 +81,7 @@ def test_list_assets_include_exclude_and_name_contains(http: requests.Session, a
def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "model_type:checkpoints", "unit-tests", "lf-size"]
t = ["models", "checkpoints", "unit-tests", "lf-size"]
n1, n2, n3 = "sz1.safetensors", "sz2.safetensors", "sz3.safetensors"
asset_factory(n1, t, {}, make_asset_bytes(n1, 1024))
asset_factory(n2, t, {}, make_asset_bytes(n2, 2048))
@ -108,7 +108,7 @@ def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, mak
def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "model_type:checkpoints", "unit-tests", "lf-upd"]
t = ["models", "checkpoints", "unit-tests", "lf-upd"]
a1 = asset_factory("upd_a.safetensors", t, {}, make_asset_bytes("upd_a", 1200))
a2 = asset_factory("upd_b.safetensors", t, {}, make_asset_bytes("upd_b", 1200))
@ -131,7 +131,7 @@ def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make
def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "model_type:checkpoints", "unit-tests", "lf-access"]
t = ["models", "checkpoints", "unit-tests", "lf-access"]
asset_factory("acc_a.safetensors", t, {}, make_asset_bytes("acc_a", 1100))
time.sleep(0.02)
a2 = asset_factory("acc_b.safetensors", t, {}, make_asset_bytes("acc_b", 1100))
@ -154,14 +154,14 @@ def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory
def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "model_type:checkpoints", "unit-tests", "lf-include"]
t = ["models", "checkpoints", "unit-tests", "lf-include"]
a = asset_factory("incvar_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("iva"))
asset_factory("incvar_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("ivb"))
# CSV tag filters are whitespace-trimmed and case-sensitive.
# CSV + case-insensitive
r1 = http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-include,alpha"},
params={"include_tags": "UNIT-TESTS,LF-INCLUDE,alpha"},
timeout=120,
)
b1 = r1.json()
@ -196,14 +196,14 @@ def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factor
def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "model_type:checkpoints", "unit-tests", "lf-exclude"]
t = ["models", "checkpoints", "unit-tests", "lf-exclude"]
a = asset_factory("ex_a_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("exa", 900))
asset_factory("ex_b_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("exb", 900))
# Exclude filters are case-sensitive.
# Exclude uppercase should work
r1 = http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "beta"},
params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "BETA"},
timeout=120,
)
b1 = r1.json()
@ -225,7 +225,7 @@ def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory,
def test_list_assets_name_contains_case_and_specials(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "model_type:checkpoints", "unit-tests", "lf-name"]
t = ["models", "checkpoints", "unit-tests", "lf-name"]
a1 = asset_factory("CaseMix.SAFE", t, {}, make_asset_bytes("cm", 800))
a2 = asset_factory("case-other.safetensors", t, {}, make_asset_bytes("co", 800))
@ -261,7 +261,7 @@ def test_list_assets_name_contains_case_and_specials(http, api_base, asset_facto
def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "model_type:checkpoints", "unit-tests", "lf-pagelimits"]
t = ["models", "checkpoints", "unit-tests", "lf-pagelimits"]
asset_factory("pl1.safetensors", t, {}, make_asset_bytes("pl1", 600))
asset_factory("pl2.safetensors", t, {}, make_asset_bytes("pl2", 600))
asset_factory("pl3.safetensors", t, {}, make_asset_bytes("pl3", 600))
@ -319,7 +319,7 @@ def test_list_assets_name_contains_literal_underscore(
- foobar.safetensors (must NOT match)
"""
scope = f"lf-underscore-{uuid.uuid4().hex[:6]}"
tags = ["models", "model_type:checkpoints", "unit-tests", scope]
tags = ["models", "checkpoints", "unit-tests", scope]
a = asset_factory("foo_bar.safetensors", tags, {}, make_asset_bytes("a", 700))
b = asset_factory("fooxbar.safetensors", tags, {}, make_asset_bytes("b", 700))

View File

@ -5,7 +5,7 @@ def test_meta_and_across_keys_and_types(
http, api_base: str, asset_factory, make_asset_bytes
):
name = "mf_and_mix.safetensors"
tags = ["models", "model_type:checkpoints", "unit-tests", "mf-and"]
tags = ["models", "checkpoints", "unit-tests", "mf-and"]
meta = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23}
asset_factory(name, tags, meta, make_asset_bytes(name, 4096))
@ -41,7 +41,7 @@ def test_meta_and_across_keys_and_types(
def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory, make_asset_bytes):
name = "mf_types.safetensors"
tags = ["models", "model_type:checkpoints", "unit-tests", "mf-types"]
tags = ["models", "checkpoints", "unit-tests", "mf-types"]
meta = {"epoch": 1, "active": True}
asset_factory(name, tags, meta, make_asset_bytes(name))
@ -95,7 +95,7 @@ def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory,
def test_meta_any_of_list_of_scalars(http, api_base, asset_factory, make_asset_bytes):
name = "mf_list_scalars.safetensors"
tags = ["models", "model_type:checkpoints", "unit-tests", "mf-list"]
tags = ["models", "checkpoints", "unit-tests", "mf-list"]
meta = {"flags": ["red", "green"]}
asset_factory(name, tags, meta, make_asset_bytes(name, 3000))
@ -134,7 +134,7 @@ def test_meta_none_semantics_missing_or_null_and_any_of_with_none(
http, api_base, asset_factory, make_asset_bytes
):
# a1: key missing; a2: explicit null; a3: concrete value
t = ["models", "model_type:checkpoints", "unit-tests", "mf-none"]
t = ["models", "checkpoints", "unit-tests", "mf-none"]
a1 = asset_factory("mf_none_missing.safetensors", t, {"x": 1}, make_asset_bytes("a1"))
a2 = asset_factory("mf_none_null.safetensors", t, {"maybe": None}, make_asset_bytes("a2"))
a3 = asset_factory("mf_none_value.safetensors", t, {"maybe": "x"}, make_asset_bytes("a3"))
@ -166,7 +166,7 @@ def test_meta_none_semantics_missing_or_null_and_any_of_with_none(
def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_asset_bytes):
name = "mf_nested_json.safetensors"
tags = ["models", "model_type:checkpoints", "unit-tests", "mf-nested"]
tags = ["models", "checkpoints", "unit-tests", "mf-nested"]
cfg = {"optimizer": "adam", "lr": 0.001, "schedule": {"type": "cosine", "warmup": 100}}
asset_factory(name, tags, {"config": cfg}, make_asset_bytes(name, 2200))
@ -197,7 +197,7 @@ def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_as
def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_bytes):
name = "mf_list_objects.safetensors"
tags = ["models", "model_type:checkpoints", "unit-tests", "mf-objlist"]
tags = ["models", "checkpoints", "unit-tests", "mf-objlist"]
transforms = [{"type": "crop", "size": 128}, {"type": "flip", "p": 0.5}]
asset_factory(name, tags, {"transforms": transforms}, make_asset_bytes(name, 2048))
@ -228,7 +228,7 @@ def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_b
def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_asset_bytes):
name = "mf_keys_unicode.safetensors"
tags = ["models", "model_type:checkpoints", "unit-tests", "mf-keys"]
tags = ["models", "checkpoints", "unit-tests", "mf-keys"]
meta = {
"weird.key": "v1",
"path/like": 7,
@ -259,7 +259,7 @@ def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_
def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "model_type:checkpoints", "unit-tests", "mf-zero-bool"]
t = ["models", "checkpoints", "unit-tests", "mf-zero-bool"]
a0 = asset_factory("mf_zero_count.safetensors", t, {"count": 0}, make_asset_bytes("z", 1025))
a1 = asset_factory("mf_bool_list.safetensors", t, {"choices": [True, False]}, make_asset_bytes("b", 1026))
@ -286,7 +286,7 @@ def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_as
def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, make_asset_bytes):
name = "mf_mixed_list.safetensors"
tags = ["models", "model_type:checkpoints", "unit-tests", "mf-mixed"]
tags = ["models", "checkpoints", "unit-tests", "mf-mixed"]
meta = {"mix": ["1", 1, True, None]}
asset_factory(name, tags, meta, make_asset_bytes(name, 1999))
@ -311,7 +311,7 @@ def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, mak
def test_meta_unknown_key_and_none_behavior_with_scope_tags(http, api_base, asset_factory, make_asset_bytes):
# Use a unique scope tag to avoid interference
t = ["models", "model_type:checkpoints", "unit-tests", "mf-unknown-scope"]
t = ["models", "checkpoints", "unit-tests", "mf-unknown-scope"]
x = asset_factory("mf_unknown_a.safetensors", t, {"k1": 1}, make_asset_bytes("ua"))
y = asset_factory("mf_unknown_b.safetensors", t, {"k2": 2}, make_asset_bytes("ub"))
@ -340,13 +340,13 @@ def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_
# alpha matches epoch=1; beta has epoch=2
a = asset_factory(
"mf_tag_alpha.safetensors",
["models", "model_type:checkpoints", "unit-tests", "mf-tag", "alpha"],
["models", "checkpoints", "unit-tests", "mf-tag", "alpha"],
{"epoch": 1},
make_asset_bytes("alpha"),
)
b = asset_factory(
"mf_tag_beta.safetensors",
["models", "model_type:checkpoints", "unit-tests", "mf-tag", "beta"],
["models", "checkpoints", "unit-tests", "mf-tag", "beta"],
{"epoch": 2},
make_asset_bytes("beta"),
)
@ -367,7 +367,7 @@ def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_
def test_meta_sort_and_paging_under_filter(http, api_base, asset_factory, make_asset_bytes):
# Three assets in same scope with different sizes and a common filter key
t = ["models", "model_type:checkpoints", "unit-tests", "mf-sort"]
t = ["models", "checkpoints", "unit-tests", "mf-sort"]
n1, n2, n3 = "mf_sort_1.safetensors", "mf_sort_2.safetensors", "mf_sort_3.safetensors"
asset_factory(n1, t, {"group": "g"}, make_asset_bytes(n1, 1024))
asset_factory(n2, t, {"group": "g"}, make_asset_bytes(n2, 2048))

View File

@ -29,7 +29,7 @@ def create_seed_file(comfy_tmp_base_dir: Path):
def find_asset(http: requests.Session, api_base: str):
"""Query API for assets matching scope and optional name."""
def _find(scope: str, name: str | None = None) -> list[dict]:
params = {"limit": "500"}
params = {"include_tags": f"unit-tests,{scope}"}
if name:
params["name_contains"] = name
r = http.get(f"{api_base}/api/assets", params=params, timeout=120)
@ -91,7 +91,7 @@ def test_hashed_asset_not_pruned_when_file_missing(
data = make_asset_bytes("test", 2048)
a = asset_factory("test.bin", ["input", "unit-tests", scope], {}, data)
path = comfy_tmp_base_dir / "input" / get_asset_filename(a["asset_hash"], ".bin")
path = comfy_tmp_base_dir / "input" / "unit-tests" / scope / get_asset_filename(a["asset_hash"], ".bin")
path.unlink()
trigger_sync_seed_assets(http, api_base)
@ -108,20 +108,18 @@ def test_prune_across_multiple_roots(
):
"""Prune correctly handles assets across input and output roots."""
scope = f"multi-{uuid.uuid4().hex[:6]}"
input_name = f"{scope}-input.bin"
output_name = f"{scope}-output.bin"
input_fp = create_seed_file("input", scope, input_name)
create_seed_file("output", scope, output_name)
input_fp = create_seed_file("input", scope, "input.bin")
create_seed_file("output", scope, "output.bin")
trigger_sync_seed_assets(http, api_base)
assert find_asset(scope, input_name)
assert find_asset(scope, output_name)
assert len(find_asset(scope)) == 2
input_fp.unlink()
trigger_sync_seed_assets(http, api_base)
assert not find_asset(scope, input_name)
assert find_asset(scope, output_name)
remaining = find_asset(scope)
assert len(remaining) == 1
assert remaining[0]["name"] == "output.bin"
@pytest.mark.parametrize("dirname", ["100%_done", "my_folder_name", "has spaces"])

View File

@ -10,9 +10,9 @@ def test_tags_present(http: requests.Session, api_base: str, seeded_asset: dict)
body1 = r1.json()
assert r1.status_code == 200
names = [t["name"] for t in body1["tags"]]
# A few selected contract tags should exist.
# A few system tags from migration should exist:
assert "models" in names
assert "model_type:checkpoints" in names
assert "checkpoints" in names
# Only used tags before we add anything new from this test cycle
r2 = http.get(api_base + "/api/tags", params={"include_zero": "false"}, timeout=120)
@ -21,7 +21,7 @@ def test_tags_present(http: requests.Session, api_base: str, seeded_asset: dict)
# We already seeded one asset via fixture, so used tags must be non-empty
used_names = [t["name"] for t in body2["tags"]]
assert "models" in used_names
assert "model_type:checkpoints" in used_names
assert "checkpoints" in used_names
# Prefix filter should refine the list
r3 = http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}, timeout=120)
@ -45,7 +45,7 @@ def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory,
body1 = r1.json()
assert r1.status_code == 200
names = [t["name"] for t in body1["tags"]]
assert "models" in names and "model_type:checkpoints" in names
assert "models" in names and "checkpoints" in names
# Create a short-lived asset under input with a unique custom tag
scope = f"tags-empty-usage-{uuid.uuid4().hex[:6]}"
@ -89,28 +89,28 @@ def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory,
def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset: dict):
aid = seeded_asset["id"]
# Add tags with duplicates while preserving source case.
payload_add = {"tags": ["NewTag", "unit-tests", "NewTag", "BETA"]}
# Add tags with duplicates and mixed case
payload_add = {"tags": ["NewTag", "unit-tests", "newtag", "BETA"]}
r1 = http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add, timeout=120)
b1 = r1.json()
assert r1.status_code == 200, b1
# stripped, deduplicated; 'unit-tests' was already present from the seed
assert set(b1["added"]) == {"NewTag", "BETA"}
# normalized, deduplicated; 'unit-tests' was already present from the seed
assert set(b1["added"]) == {"newtag", "beta"}
assert set(b1["already_present"]) == {"unit-tests"}
assert "NewTag" in b1["total_tags"] and "BETA" in b1["total_tags"]
assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"]
rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
g = rg.json()
assert rg.status_code == 200
tags_now = set(g["tags"])
assert {"NewTag", "BETA"}.issubset(tags_now)
assert {"newtag", "beta"}.issubset(tags_now)
# Remove a tag and a non-existent tag
payload_del = {"tags": ["NewTag", "does-not-exist"]}
payload_del = {"tags": ["newtag", "does-not-exist"]}
r2 = http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del, timeout=120)
b2 = r2.json()
assert r2.status_code == 200
assert set(b2["removed"]) == {"NewTag"}
assert set(b2["removed"]) == {"newtag"}
assert set(b2["not_present"]) == {"does-not-exist"}
# Verify remaining tags after deletion
@ -118,44 +118,8 @@ def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset
g2 = rg2.json()
assert rg2.status_code == 200
tags_later = set(g2["tags"])
assert "NewTag" not in tags_later
assert "BETA" in tags_later # still present
def test_add_system_looking_tags_allowed_as_labels(
http: requests.Session, api_base: str, seeded_asset: dict
):
aid = seeded_asset["id"]
response = http.post(
f"{api_base}/api/assets/{aid}/tags",
json={
"tags": [
"models",
"model_type:manual",
"model:true",
"models:foo",
"input:true",
"output:true",
"uploaded:true",
"temp:true",
"temporary",
]
},
timeout=120,
)
body = response.json()
assert response.status_code == 200, body
assert "models" in body["total_tags"]
assert "model_type:manual" in body["total_tags"]
assert "model:true" in body["total_tags"]
assert "models:foo" in body["total_tags"]
assert "input:true" in body["total_tags"]
assert "output:true" in body["total_tags"]
assert "uploaded:true" in body["total_tags"]
assert "temp:true" in body["total_tags"]
assert "temporary" in body["total_tags"]
assert "newtag" not in tags_later
assert "beta" in tags_later # still present
def test_tags_list_order_and_prefix(http: requests.Session, api_base: str, seeded_asset: dict):

View File

@ -1,14 +1,11 @@
import json
import uuid
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import requests
import pytest
from app.assets.api.schemas_in import UploadAssetSpec
from app.assets.api.schemas_out import Asset, AssetCreated
from helpers import get_asset_filename
def test_asset_created_inherits_hash_field():
@ -23,18 +20,9 @@ def test_asset_created_inherits_hash_field():
assert AssetCreated.model_fields["hash"].annotation == Asset.model_fields["hash"].annotation
def test_upload_asset_spec_ignores_subfolder_field():
spec = UploadAssetSpec.model_validate(
{"tags": ["input"], "subfolder": "pasted", "name": "image.png"}
)
assert "subfolder" not in UploadAssetSpec.model_fields
assert not hasattr(spec, "subfolder")
def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, make_asset_bytes):
name = "dup_a.safetensors"
tags = ["models", "model_type:checkpoints", "unit-tests", "alpha"]
tags = ["models", "checkpoints", "unit-tests", "alpha"]
meta = {"purpose": "dup"}
data = make_asset_bytes(name)
files = {"file": (name, data, "application/octet-stream")}
@ -55,8 +43,6 @@ def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, ma
assert a2["asset_hash"] == a1["asset_hash"]
assert a2["hash"] == a1["hash"]
assert a2["id"] != a1["id"] # new reference with same content
assert a2.get("file_path") is None
assert a2.get("display_name") is None
# Third upload with the same data but different name also creates new AssetReference
files = {"file": (name, data, "application/octet-stream")}
@ -67,14 +53,12 @@ def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, ma
assert a3["asset_hash"] == a1["asset_hash"]
assert a3["id"] != a1["id"]
assert a3["id"] != a2["id"]
assert a3.get("file_path") is None
assert a3.get("display_name") is None
def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_base: str):
# Seed a small file first
name = "fastpath_seed.safetensors"
tags = ["input", "unit-tests"]
tags = ["models", "checkpoints", "unit-tests"]
meta = {}
files = {"file": (name, b"B" * 1024, "application/octet-stream")}
form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)}
@ -85,10 +69,9 @@ def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_
assert b1["hash"] == h
# Now POST /api/assets with only hash and no file
hash_only_tags = ["models", "checkpoints", "unit-tests", "hash-labels"]
files = [
("hash", (None, h)),
("tags", (None, json.dumps(hash_only_tags))),
("tags", (None, json.dumps(tags))),
("name", (None, "fastpath_copy.safetensors")),
("user_metadata", (None, json.dumps({"purpose": "copy"}))),
]
@ -98,53 +81,6 @@ def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_
assert b2["created_new"] is False
assert b2["asset_hash"] == h
assert b2["hash"] == h
assert "models" in b2["tags"]
assert "checkpoints" in b2["tags"]
assert "uploaded" not in b2["tags"]
assert not any(tag.startswith("model_type:") for tag in b2["tags"])
assert b2.get("file_path") is None
assert b2.get("display_name") is None
rg = http.get(f"{api_base}/api/assets/{b2['id']}", timeout=120)
detail = rg.json()
assert rg.status_code == 200, detail
assert detail.get("file_path") is None
assert detail.get("display_name") is None
def test_create_from_hash_with_model_tags_does_not_synthesize_file_path(
http: requests.Session, api_base: str
):
seed_name = "from_hash_seed.safetensors"
seed_tags = ["models", "model_type:checkpoints", "unit-tests"]
files = {"file": (seed_name, b"D" * 1024, "application/octet-stream")}
form = {
"tags": json.dumps(seed_tags),
"name": seed_name,
"user_metadata": json.dumps({}),
}
seed_r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
seed = seed_r.json()
assert seed_r.status_code == 201, seed
payload = {
"hash": seed["asset_hash"],
"name": "from_hash_copy.safetensors",
"tags": ["models", "model_type:checkpoints", "unit-tests", "spoofed"],
}
created_r = http.post(api_base + "/api/assets/from-hash", json=payload, timeout=120)
created = created_r.json()
assert created_r.status_code == 201, created
assert created["created_new"] is False
assert created["asset_hash"] == seed["asset_hash"]
assert created.get("file_path") is None
assert created.get("display_name") is None
detail_r = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120)
detail = detail_r.json()
assert detail_r.status_code == 200, detail
assert detail.get("file_path") is None
assert detail.get("display_name") is None
def test_upload_fastpath_with_known_hash_and_file(
@ -152,7 +88,7 @@ def test_upload_fastpath_with_known_hash_and_file(
):
# Seed
files = {"file": ("seed.safetensors", b"C" * 128, "application/octet-stream")}
form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "fp"]), "name": "seed.safetensors", "user_metadata": json.dumps({})}
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "fp"]), "name": "seed.safetensors", "user_metadata": json.dumps({})}
r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
b1 = r1.json()
assert r1.status_code == 201, b1
@ -168,49 +104,11 @@ def test_upload_fastpath_with_known_hash_and_file(
assert b2["created_new"] is False
assert b2["asset_hash"] == h
assert b2["hash"] == h
assert "checkpoints" in b2["tags"]
assert "uploaded" not in b2["tags"]
assert not any(tag == "model_type:checkpoints" for tag in b2["tags"])
def test_duplicate_byte_upload_is_reference_only_and_does_not_need_destination(
http: requests.Session, api_base: str
):
data = b"duplicate-reference-only" * 64
seed_files = {"file": ("duplicate-seed.bin", data, "application/octet-stream")}
seed_form = {
"tags": json.dumps(["input", "unit-tests", "duplicate-seed"]),
"name": "duplicate-seed.bin",
"user_metadata": json.dumps({}),
}
seed_response = http.post(api_base + "/api/assets", data=seed_form, files=seed_files, timeout=120)
seed = seed_response.json()
assert seed_response.status_code == 201, seed
duplicate_files = {"file": ("duplicate-copy.bin", data, "application/octet-stream")}
duplicate_form = {
"tags": json.dumps(["not-a-destination", "unit-tests", "duplicate-copy"]),
"name": "duplicate-copy.bin",
"user_metadata": json.dumps({}),
}
duplicate_response = http.post(
api_base + "/api/assets", data=duplicate_form, files=duplicate_files, timeout=120
)
duplicate = duplicate_response.json()
assert duplicate_response.status_code == 200, duplicate
assert duplicate["created_new"] is False
assert duplicate["asset_hash"] == seed["asset_hash"]
assert "not-a-destination" in duplicate["tags"]
assert "uploaded" not in duplicate["tags"]
assert "input" not in duplicate["tags"]
assert duplicate.get("file_path") is None
assert duplicate.get("display_name") is None
def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base: str):
data = [
("tags", "models,model_type:checkpoints"),
("tags", "models,checkpoints"),
("tags", json.dumps(["unit-tests", "alpha"])),
("name", "merge.safetensors"),
("user_metadata", json.dumps({"u": 1})),
@ -226,72 +124,7 @@ def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base
detail = rg.json()
assert rg.status_code == 200, detail
tags = set(detail["tags"])
assert {"models", "model_type:checkpoints", "unit-tests", "alpha"}.issubset(tags)
@pytest.mark.parametrize(
(
"tags",
"extension",
"expected_prefix",
"expected_display_prefix",
),
[
(["input", "unit-tests"], ".png", "input", ""),
(
["models", "model_type:checkpoints", "unit-tests"],
".safetensors",
"models/checkpoints",
"checkpoints/",
),
],
)
def test_upload_response_includes_file_path_and_display_name(
tags: list[str],
extension: str,
expected_prefix: str,
expected_display_prefix: str,
http: requests.Session,
api_base: str,
make_asset_bytes,
):
scope = f"response-paths-{uuid.uuid4().hex[:6]}"
scoped_tags = [*tags, scope]
name = f"asset_response_path{extension}"
files = {"file": (name, make_asset_bytes(name, 1024), "application/octet-stream")}
form = {
"tags": json.dumps(scoped_tags),
"name": name,
"user_metadata": json.dumps({}),
}
created_r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
created = created_r.json()
assert created_r.status_code in (200, 201), created
stored_filename = get_asset_filename(created["asset_hash"], extension)
expected_suffix = stored_filename
expected_file_path = f"{expected_prefix}/{expected_suffix}"
expected_display_name = f"{expected_display_prefix}{expected_suffix}"
assert created["file_path"] == expected_file_path
assert created["display_name"] == expected_display_name
detail_r = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120)
detail = detail_r.json()
assert detail_r.status_code == 200, detail
assert detail["file_path"] == expected_file_path
assert detail["display_name"] == expected_display_name
list_r = http.get(
api_base + "/api/assets",
params={"include_tags": f"unit-tests,{scope}", "limit": "50"},
timeout=120,
)
listed = list_r.json()
assert list_r.status_code == 200, listed
match = next(a for a in listed["assets"] if a["id"] == created["id"])
assert match["file_path"] == expected_file_path
assert match["display_name"] == expected_display_name
assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags)
@pytest.mark.parametrize("root", ["input", "output"])
@ -359,55 +192,16 @@ def test_create_from_hash_endpoint_404(http: requests.Session, api_base: str):
assert body["error"]["code"] == "ASSET_NOT_FOUND"
def test_create_from_hash_accepts_arbitrary_system_looking_tags(
http: requests.Session, api_base: str
):
files = {"file": ("hash-seed.bin", b"hash-seed" * 64, "application/octet-stream")}
form = {
"tags": json.dumps(["input", "unit-tests", "hash-seed"]),
"name": "hash-seed.bin",
"user_metadata": json.dumps({}),
}
seed_response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
seed = seed_response.json()
assert seed_response.status_code == 201, seed
response = http.post(
api_base + "/api/assets/from-hash",
json={
"hash": seed["asset_hash"],
"name": "hash-copy.bin",
"tags": [
"models",
"model:true",
"models:foo",
"temporary:true",
"unit-tests",
"hash-copy",
],
},
timeout=120,
)
body = response.json()
assert response.status_code == 201, body
assert "models" in body["tags"]
assert "model:true" in body["tags"]
assert "models:foo" in body["tags"]
assert "temporary:true" in body["tags"]
assert "uploaded" not in body["tags"]
def test_upload_zero_byte_rejected(http: requests.Session, api_base: str):
files = {"file": ("empty.safetensors", b"", "application/octet-stream")}
form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "edge"]), "name": "empty.safetensors", "user_metadata": json.dumps({})}
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "empty.safetensors", "user_metadata": json.dumps({})}
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = r.json()
assert r.status_code == 400
assert body["error"]["code"] == "EMPTY_UPLOAD"
def test_upload_rejects_arbitrary_labels_without_required_destination_role(http: requests.Session, api_base: str):
def test_upload_invalid_root_tag_rejected(http: requests.Session, api_base: str):
files = {"file": ("badroot.bin", b"A" * 64, "application/octet-stream")}
form = {"tags": json.dumps(["not-a-root", "whatever"]), "name": "badroot.bin", "user_metadata": json.dumps({})}
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
@ -418,7 +212,7 @@ def test_upload_rejects_arbitrary_labels_without_required_destination_role(http:
def test_upload_user_metadata_must_be_json(http: requests.Session, api_base: str):
files = {"file": ("badmeta.bin", b"A" * 128, "application/octet-stream")}
form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "edge"]), "name": "badmeta.bin", "user_metadata": "{not json}"}
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "badmeta.bin", "user_metadata": "{not json}"}
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = r.json()
assert r.status_code == 400
@ -434,7 +228,7 @@ def test_upload_requires_multipart(http: requests.Session, api_base: str):
def test_upload_missing_file_and_hash(http: requests.Session, api_base: str):
files = [
("tags", (None, json.dumps(["models", "model_type:checkpoints", "unit-tests"]))),
("tags", (None, json.dumps(["models", "checkpoints", "unit-tests"]))),
("name", (None, "x.safetensors")),
]
r = http.post(api_base + "/api/assets", files=files, timeout=120)
@ -443,33 +237,17 @@ def test_upload_missing_file_and_hash(http: requests.Session, api_base: str):
assert body["error"]["code"] == "MISSING_FILE"
def test_upload_models_unknown_model_type(http: requests.Session, api_base: str):
def test_upload_models_unknown_category(http: requests.Session, api_base: str):
files = {"file": ("m.safetensors", b"A" * 128, "application/octet-stream")}
form = {"tags": json.dumps(["models", "model_type:no_such_category", "unit-tests"]), "name": "m.safetensors"}
form = {"tags": json.dumps(["models", "no_such_category", "unit-tests"]), "name": "m.safetensors"}
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = r.json()
assert r.status_code == 400, body
assert r.status_code == 400
assert body["error"]["code"] == "INVALID_BODY"
assert body["error"]["message"].startswith("unknown models category")
@pytest.mark.parametrize("model_type", ["configs", "custom_nodes"])
def test_upload_models_rejects_non_model_registered_folder(
model_type: str, http: requests.Session, api_base: str
):
files = {"file": ("not-a-model.py", b"A" * 128, "application/octet-stream")}
form = {
"tags": json.dumps(["models", f"model_type:{model_type}", "unit-tests"]),
"name": "not-a-model.py",
}
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = response.json()
assert response.status_code == 400, body
assert body["error"]["code"] == "INVALID_BODY"
def test_upload_models_requires_model_type(http: requests.Session, api_base: str):
def test_upload_models_requires_category(http: requests.Session, api_base: str):
files = {"file": ("nocat.safetensors", b"A" * 64, "application/octet-stream")}
form = {"tags": json.dumps(["models"]), "name": "nocat.safetensors", "user_metadata": json.dumps({})}
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
@ -478,152 +256,13 @@ def test_upload_models_requires_model_type(http: requests.Session, api_base: str
assert body["error"]["code"] == "INVALID_BODY"
def test_upload_extra_tags_are_labels_not_path_components(http: requests.Session, api_base: str):
def test_upload_tags_traversal_guard(http: requests.Session, api_base: str):
files = {"file": ("evil.safetensors", b"A" * 256, "application/octet-stream")}
form = {"tags": json.dumps(["models", "model_type:checkpoints", "unit-tests", "..", "zzz"]), "name": "evil.safetensors"}
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"]), "name": "evil.safetensors"}
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = r.json()
assert r.status_code == 201, body
assert ".." in body["tags"]
assert "zzz" in body["tags"]
assert "models" in body["tags"]
assert "model_type:checkpoints" in body["tags"]
@pytest.mark.parametrize(
("subfolder", "expected_tag", "unexpected_tags"),
[
("custom/session", None, {"custom", "session"}),
("pasted", "pasted", set()),
],
)
def test_upload_image_accepts_arbitrary_subfolder_but_only_known_values_become_tags(
http: requests.Session,
api_base: str,
comfy_tmp_base_dir: Path,
subfolder: str,
expected_tag: str | None,
unexpected_tags: set[str],
):
name = f"upload-image-{uuid.uuid4().hex}.png"
files = {"image": (name, b"image-upload" * 64, "image/png")}
form = {"type": "input", "subfolder": subfolder}
response = http.post(api_base + "/upload/image", data=form, files=files, timeout=120)
body = response.json()
assert response.status_code == 200, body
assert body["subfolder"] == subfolder
assert (comfy_tmp_base_dir / "input" / subfolder / body["name"]).exists()
asset = body["asset"]
tags = set(asset["tags"])
assert "input" in tags
assert "uploaded" in tags
if expected_tag:
assert expected_tag in tags
assert tags.isdisjoint(unexpected_tags)
def test_multipart_upload_accepts_system_looking_extra_labels(
http: requests.Session, api_base: str
):
files = {"file": ("relaxed-labels.bin", b"relaxed" * 64, "application/octet-stream")}
form = {
"tags": json.dumps(
[
"input",
"unit-tests",
"model:true",
"models:foo",
"temporary",
"uploaded:true",
]
),
"name": "relaxed-labels.bin",
"user_metadata": json.dumps({}),
}
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = response.json()
assert response.status_code == 201, body
assert "input" in body["tags"]
assert "model:true" in body["tags"]
assert "models:foo" in body["tags"]
assert "temporary" in body["tags"]
assert "uploaded:true" in body["tags"]
def test_multipart_upload_rejects_ambiguous_destination_roles(
http: requests.Session, api_base: str
):
files = {"file": ("ambiguous.bin", b"ambiguous" * 64, "application/octet-stream")}
form = {
"tags": json.dumps(["input", "output", "unit-tests"]),
"name": "ambiguous.bin",
"user_metadata": json.dumps({}),
}
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = response.json()
assert response.status_code == 400, body
assert body["error"]["code"] == "INVALID_BODY"
def test_multipart_upload_rejects_multiple_model_types_for_models_destination(
http: requests.Session, api_base: str
):
files = {"file": ("ambiguous-model.safetensors", b"ambiguous-model" * 64, "application/octet-stream")}
form = {
"tags": json.dumps(
["models", "model_type:checkpoints", "model_type:loras", "unit-tests"]
),
"name": "ambiguous-model.safetensors",
"user_metadata": json.dumps({}),
}
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = response.json()
assert response.status_code == 400, body
assert body["error"]["code"] == "INVALID_BODY"
@pytest.mark.parametrize(
("tags", "expected_root", "extension"),
[
(["input", "unit-tests", "upload-location-input"], "input", ".bin"),
(["output", "unit-tests", "upload-location-output"], "output", ".bin"),
(
["models", "model_type:checkpoints", "unit-tests", "upload-location-model"],
"models/checkpoints",
".safetensors",
),
],
)
def test_multipart_upload_role_selects_write_location(
http: requests.Session,
api_base: str,
comfy_tmp_base_dir: Path,
tags: list[str],
expected_root: str,
extension: str,
):
role = next(tag for tag in tags if tag in {"input", "models", "output"})
name = f"{role}-role-upload{extension}"
files = {"file": (name, f"{role}-role-bytes".encode() * 64, "application/octet-stream")}
form = {
"tags": json.dumps(tags),
"name": name,
"user_metadata": json.dumps({}),
}
response = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = response.json()
assert response.status_code == 201, body
stored_name = get_asset_filename(body["asset_hash"], extension)
expected_disk_path = comfy_tmp_base_dir / expected_root / stored_name
assert expected_disk_path.exists()
assert r.status_code == 400
assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
def test_upload_empty_tags_rejected(http: requests.Session, api_base: str):

View File

@ -0,0 +1,137 @@
"""Tests for pre-execution validation that a node is actually executable.
validate_prompt rejects a node whose declared entry point does not resolve to a
real method (a V1 FUNCTION typo, or a V3 node missing its execute override) before
any node runs, attributing the error to the offending node.
"""
import asyncio
import nodes
from comfy_api.latest import io
from execution import node_not_executable_reason, validate_prompt
class _GoodV1Node:
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "run"
OUTPUT_NODE = True
CATEGORY = "Test"
def run(self):
return (None,)
class _TypoV1Node:
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "invert" # method below is misspelled
OUTPUT_NODE = True
CATEGORY = "Test"
def invvert(self):
return (None,)
class _SideEffectInitV1Node:
"""Valid class-level method, but a constructor that must never run in validation."""
@classmethod
def INPUT_TYPES(cls):
return {"required": {}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "run"
OUTPUT_NODE = True
CATEGORY = "Test"
def __init__(self):
raise RuntimeError("__init__ must not run during validation")
def run(self):
return (None,)
def _v3_schema(node_id):
return io.Schema(
node_id=node_id,
display_name=node_id,
category="Test",
inputs=[],
outputs=[io.Image.Output()],
is_output_node=True,
)
class _GoodV3Node(io.ComfyNode):
@classmethod
def define_schema(cls):
return _v3_schema("GoodV3Node")
@classmethod
def execute(cls):
return io.NodeOutput(None)
class _TypoV3Node(io.ComfyNode):
@classmethod
def define_schema(cls):
return _v3_schema("TypoV3Node")
@classmethod
def exicute(cls): # typo: should be "execute"
return io.NodeOutput(None)
def _register(class_type, class_def):
nodes.NODE_CLASS_MAPPINGS[class_type] = class_def
def _validate(class_type):
prompt = {"1": {"class_type": class_type, "inputs": {}}}
return asyncio.run(validate_prompt("pid", prompt, None))
def test_good_node_passes():
_register("GoodV1Node", _GoodV1Node)
assert node_not_executable_reason(_GoodV1Node, "GoodV1Node") is None
valid, _, _, _ = _validate("GoodV1Node")
assert valid is True
def test_typo_node_rejected_with_node_error():
_register("TypoV1Node", _TypoV1Node)
valid, error, _, node_errors = _validate("TypoV1Node")
assert valid is False
assert error["type"] == "invalid_node_definition"
assert node_errors["1"]["class_type"] == "TypoV1Node"
assert node_errors["1"]["errors"][0]["type"] == "invalid_node_definition"
assert "invert" in node_errors["1"]["errors"][0]["details"]
def test_validation_does_not_instantiate_node():
"""A valid node is not constructed during validation, so __init__ never runs."""
_register("SideEffectInitV1Node", _SideEffectInitV1Node)
assert node_not_executable_reason(_SideEffectInitV1Node, "SideEffectInitV1Node") is None
valid, _, _, _ = _validate("SideEffectInitV1Node")
assert valid is True
def test_good_v3_node_passes():
_register("GoodV3Node", _GoodV3Node)
assert node_not_executable_reason(_GoodV3Node, "GoodV3Node") is None
valid, _, _, _ = _validate("GoodV3Node")
assert valid is True
def test_typo_v3_node_rejected_with_node_error():
_register("TypoV3Node", _TypoV3Node)
valid, error, _, node_errors = _validate("TypoV3Node")
assert valid is False
assert error["type"] == "invalid_node_definition"
assert node_errors["1"]["errors"][0]["type"] == "invalid_node_definition"

View File

@ -29,8 +29,6 @@ class TestFeatureFlags:
features = get_server_features()
assert "supports_preview_metadata" in features
assert features["supports_preview_metadata"] is True
assert "supports_model_type_tags" in features
assert features["supports_model_type_tags"] is True
assert "max_upload_size" in features
assert isinstance(features["max_upload_size"], (int, float))

View File

@ -12,8 +12,6 @@ class TestWebSocketFeatureFlags:
# Check expected server features
assert "supports_preview_metadata" in features
assert features["supports_preview_metadata"] is True
assert "supports_model_type_tags" in features
assert features["supports_model_type_tags"] is True
assert "max_upload_size" in features
assert isinstance(features["max_upload_size"], (int, float))
@ -77,5 +75,3 @@ class TestWebSocketFeatureFlags:
assert server_message["type"] == "feature_flags"
assert "supports_preview_metadata" in server_message["data"]
assert server_message["data"]["supports_preview_metadata"] is True
assert "supports_model_type_tags" in server_message["data"]
assert server_message["data"]["supports_model_type_tags"] is True