Compare commits

..

6 Commits

41 changed files with 288 additions and 3475 deletions

View File

@ -3,8 +3,8 @@ name: Backport Release
on:
workflow_dispatch:
inputs:
branch:
description: 'Source branch containing the backported commits (PR source branch into master)'
commit:
description: 'Full 40-char SHA of the tip commit of the backport source branch (the PR head commit that passed tests). The branch is resolved from this SHA and must be unique.'
required: true
type: string
@ -39,17 +39,72 @@ jobs:
git config user.name "fen-release[bot]"
git config user.email "fen-release[bot]@users.noreply.github.com"
- name: Validate source branch exists
- name: Resolve source branch from commit SHA
id: resolve
env:
SOURCE_BRANCH: ${{ inputs.branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
run: |
set -euo pipefail
git fetch origin "refs/heads/${SOURCE_BRANCH}:refs/remotes/origin/${SOURCE_BRANCH}"
if ! git show-ref --verify --quiet "refs/remotes/origin/${SOURCE_BRANCH}"; then
echo "::error::Source branch '${SOURCE_BRANCH}' not found on origin."
# Require a full 40-char lowercase-hex SHA. Short SHAs are ambiguous
# and we will be comparing this value against API responses (PR head
# SHA, ref tips) that always return the full form.
if [[ ! "${SOURCE_COMMIT}" =~ ^[0-9a-f]{40}$ ]]; then
echo "::error::Input commit '${SOURCE_COMMIT}' is not a full 40-char lowercase hex SHA."
exit 1
fi
# Fetch all remote branches so we can search for which one(s) point
# at this SHA. `actions/checkout` with fetch-depth: 0 fetches full
# history of the checked-out ref but does not necessarily populate
# every refs/remotes/origin/*, so do it explicitly.
git fetch --prune origin '+refs/heads/*:refs/remotes/origin/*'
# Verify the commit actually exists in this repo's object DB.
if ! git cat-file -e "${SOURCE_COMMIT}^{commit}" 2>/dev/null; then
echo "::error::Commit ${SOURCE_COMMIT} was not found in the repository."
exit 1
fi
# Find every remote branch whose tip == SOURCE_COMMIT. Exactly one
# branch must point at it. If zero, the commit isn't anyone's tip
# (likely stale, force-pushed past, or never the PR head). If more
# than one, the (branch -> SHA) mapping is ambiguous and we refuse
# to guess — the operator must give us a unique branch to release.
mapfile -t matching_branches < <(
git for-each-ref \
--format='%(refname:strip=3)' \
--points-at="${SOURCE_COMMIT}" \
refs/remotes/origin/ \
| grep -vx 'HEAD' || true
)
if [[ "${#matching_branches[@]}" -eq 0 ]]; then
echo "::error::No branch on origin has ${SOURCE_COMMIT} as its tip."
echo "::error::Either the branch was updated after you copied this SHA, or this commit was never the head of a branch."
exit 1
fi
if [[ "${#matching_branches[@]}" -gt 1 ]]; then
echo "::error::More than one branch on origin has ${SOURCE_COMMIT} as its tip; cannot pick one:"
for b in "${matching_branches[@]}"; do
echo "::error:: - ${b}"
done
echo "::error::Refusing to proceed with an ambiguous source branch."
exit 1
fi
source_branch="${matching_branches[0]}"
if [[ "${source_branch}" == "${DEFAULT_BRANCH}" ]]; then
echo "::error::Source branch must not be the default branch ('${DEFAULT_BRANCH}')."
exit 1
fi
echo "Resolved commit ${SOURCE_COMMIT} to branch '${source_branch}'."
echo "source_branch=${source_branch}" >> "$GITHUB_OUTPUT"
- name: Determine latest stable release
id: latest
env:
@ -102,23 +157,26 @@ jobs:
- name: Validate source branch is cut directly from the latest stable release
env:
SOURCE_BRANCH: ${{ inputs.branch }}
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
LATEST_TAG_SHA: ${{ steps.latest.outputs.latest_sha }}
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
run: |
set -euo pipefail
source_sha="$(git rev-parse "refs/remotes/origin/${SOURCE_BRANCH}")"
# Use the user-provided SHA directly rather than re-resolving the branch
# tip — the resolve step already proved the branch tip equals SOURCE_COMMIT,
# and pinning to the SHA here makes the rest of the job TOCTOU-safe against
# someone pushing to the branch mid-run.
source_sha="${SOURCE_COMMIT}"
# The source branch must be cut directly off the latest stable tag.
# "Cut directly off" means: walking first-parent from the source tip
# eventually reaches LATEST_TAG_SHA. This rejects branches that were
# cut from master after the tag (which would carry unrelated commits),
# while accepting a branch rooted at the tag with N backport commits
# on top (each of which may itself be a merge — first-parent walks
# through the mainline of the branch).
if ! git rev-list --first-parent "${source_sha}" \
| grep -qx "${LATEST_TAG_SHA}"; then
# Walking first-parent from the source tip must reach LATEST_TAG_SHA.
# We capture rev-list into a variable and grep against a here-string
# rather than piping `rev-list | grep -q`: under `set -o pipefail`,
# `grep -q` would exit on first match and SIGPIPE the still-streaming
# `rev-list`, propagating exit 141 as a spurious "not found".
first_parent_chain="$(git rev-list --first-parent "${source_sha}")"
if ! grep -Fxq "${LATEST_TAG_SHA}" <<< "${first_parent_chain}"; then
echo "::error::Source branch '${SOURCE_BRANCH}' is not cut from '${LATEST_TAG}'."
echo "::error::Its first-parent history does not include ${LATEST_TAG_SHA}."
exit 1
@ -153,10 +211,11 @@ jobs:
added_count="$(printf '%s\n' "${all_added}" | grep -c . || true)"
echo "Source branch is cut directly from ${LATEST_TAG} with ${added_count} commit(s) on top."
- name: Validate PR exists, is named correctly, and checks pass
- name: Validate PR exists, is open, named correctly, has latest commit, and checks pass
env:
GH_TOKEN: ${{ steps.app-token.outputs.token }}
SOURCE_BRANCH: ${{ inputs.branch }}
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
REPO: ${{ github.repository }}
run: |
@ -164,20 +223,22 @@ jobs:
expected_title="ComfyUI backport release ${NEW_VERSION}"
# Find open PRs from this branch into master
# Find open PRs from this branch into master. The --state open filter
# is load-bearing: a closed/merged PR with passing checks must not be
# accepted as authorization for a new release.
pr_json="$(
gh pr list \
--repo "${REPO}" \
--state open \
--head "${SOURCE_BRANCH}" \
--base master \
--json number,title,headRefOid \
--json number,title,headRefOid,state \
--limit 10
)"
pr_count="$(echo "${pr_json}" | jq 'length')"
if [[ "${pr_count}" -eq 0 ]]; then
echo "::error::No open PR found from '${SOURCE_BRANCH}' into 'master'."
echo "::error::No open PR found from '${SOURCE_BRANCH}' into 'master'. The PR must exist and be open."
exit 1
fi
@ -196,7 +257,19 @@ jobs:
exit 1
fi
echo "Found PR #${pr_number} titled '${expected_title}' (head ${pr_head_sha})."
# The PR's current head commit must equal the SHA the operator gave us.
# This is what closes the door on releasing stale code: if anyone has
# pushed to the branch since the operator validated tests passed, the
# PR head will have advanced past SOURCE_COMMIT and we abort. (The
# resolve step already proved the branch tip == SOURCE_COMMIT; this
# ties that same SHA to the PR that authorizes the release.)
if [[ "${pr_head_sha}" != "${SOURCE_COMMIT}" ]]; then
echo "::error::PR #${pr_number} head commit is ${pr_head_sha}, but the operator-provided commit is ${SOURCE_COMMIT}."
echo "::error::The PR has new commits since this release was authorized. Re-run with the new head SHA after verifying its checks."
exit 1
fi
echo "Found open PR #${pr_number} titled '${expected_title}' at head ${pr_head_sha} (matches operator-provided commit)."
# Verify all check runs on the head commit have completed successfully.
# A check is considered passing if conclusion is success, neutral, or skipped.
@ -238,7 +311,6 @@ jobs:
env:
GH_TOKEN: ${{ steps.app-token.outputs.token }}
REPO: ${{ github.repository }}
SOURCE_BRANCH: ${{ inputs.branch }}
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
LATEST_TAG_SHA: ${{ steps.latest.outputs.latest_sha }}
@ -274,7 +346,8 @@ jobs:
- name: Fast-forward merge source branch into release branch
env:
SOURCE_BRANCH: ${{ inputs.branch }}
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
run: |
set -euo pipefail
@ -285,12 +358,16 @@ jobs:
# that the source branch is rooted on the latest stable tag, and the
# release branch tip equals that same tag, this fast-forward should
# always succeed for a well-formed backport branch.
if ! git merge --ff-only "refs/remotes/origin/${SOURCE_BRANCH}"; then
echo "::error::Cannot fast-forward '${RELEASE_BRANCH}' to '${SOURCE_BRANCH}'. A merge commit would be required. Aborting."
#
# We merge the operator-provided SHA, not the branch ref, so a push to
# the branch in the window between resolve and now cannot smuggle new
# commits into the release.
if ! git merge --ff-only "${SOURCE_COMMIT}"; then
echo "::error::Cannot fast-forward '${RELEASE_BRANCH}' to ${SOURCE_COMMIT} (tip of '${SOURCE_BRANCH}'). A merge commit would be required. Aborting."
exit 1
fi
echo "Fast-forwarded '${RELEASE_BRANCH}' to tip of '${SOURCE_BRANCH}'."
echo "Fast-forwarded '${RELEASE_BRANCH}' to ${SOURCE_COMMIT} (tip of '${SOURCE_BRANCH}')."
- name: Bump version files
env:
@ -387,14 +464,20 @@ jobs:
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
SOURCE_BRANCH: ${{ inputs.branch }}
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
run: |
# SOURCE_BRANCH is empty if the resolve step never produced an output
# (e.g. the workflow failed in or before that step). Show a placeholder
# in that case so the summary table still renders cleanly.
source_branch_display="${SOURCE_BRANCH:-(unresolved)}"
{
echo "## Backport release"
echo ""
echo "| Field | Value |"
echo "|---|---|"
echo "| Source branch | \`${SOURCE_BRANCH}\` |"
echo "| Source commit | \`${SOURCE_COMMIT}\` |"
echo "| Source branch | \`${source_branch_display}\` |"
echo "| Previous stable | \`${LATEST_TAG}\` |"
echo "| New version | \`${NEW_VERSION}\` |"
echo "| Release branch | \`${RELEASE_BRANCH}\` |"

View File

@ -20,7 +20,7 @@
[website-url]: https://www.comfy.org/
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
[discord-url]: https://www.comfy.org/discord
[discord-url]: https://discord.com/invite/comfyorg
[twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI
[twitter-url]: https://x.com/ComfyUI

View File

@ -39,8 +39,6 @@ from app.assets.services import (
update_asset_metadata,
upload_from_temp_path,
)
from app.assets.services.cursor import InvalidCursorError
from app.assets.services.path_utils import compute_paths_for_response
from app.assets.services.tagging import list_tag_histogram
ROUTES = web.RouteTableDef()
@ -162,19 +160,10 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
preview_url = None
else:
preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata)
asset_content_hash = result.asset.hash if result.asset else None
if result.ref.file_path:
paths = compute_paths_for_response(result.ref.file_path)
file_path, display_name = paths if paths else (None, None)
else:
file_path, display_name = None, None
return schemas_out.Asset(
id=result.ref.id,
name=result.ref.name,
file_path=file_path,
display_name=display_name,
hash=asset_content_hash,
asset_hash=asset_content_hash,
asset_hash=result.asset.hash if result.asset else None,
size=int(result.asset.size_bytes) if result.asset else None,
mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags,
@ -183,7 +172,7 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
user_metadata=result.ref.user_metadata or {},
metadata=result.ref.system_metadata,
job_id=result.ref.job_id,
prompt_id=result.ref.job_id, # deprecated alias of job_id, kept for compatibility
prompt_id=result.ref.job_id, # deprecated: mirrors job_id for cloud compat
created_at=result.ref.created_at,
updated_at=result.ref.updated_at,
last_access_time=result.ref.last_access_time,
@ -220,38 +209,24 @@ async def list_assets_route(request: web.Request) -> web.Response:
order_candidate = (q.order or "desc").lower()
order = order_candidate if order_candidate in {"asc", "desc"} else "desc"
try:
result = list_assets_page(
owner_id=USER_MANAGER.get_request_user_id(request),
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
job_ids=q.job_ids,
limit=q.limit,
offset=q.offset,
sort=sort,
order=order,
after=q.after,
)
except InvalidCursorError as e:
return _build_error_response(400, "INVALID_CURSOR", str(e))
result = list_assets_page(
owner_id=USER_MANAGER.get_request_user_id(request),
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
offset=q.offset,
sort=sort,
order=order,
)
summaries = [_build_asset_response(item) for item in result.items]
# has_more semantics differ by mode:
# - cursor mode: a non-empty next_cursor means there are more results.
# - offset mode: derived from total - (offset + page size).
if q.after is not None:
has_more = result.next_cursor is not None
else:
has_more = (q.offset + len(summaries)) < result.total
payload = schemas_out.AssetsList(
assets=summaries,
total=result.total,
has_more=has_more,
next_cursor=result.next_cursor,
has_more=(q.offset + len(summaries)) < result.total,
)
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
@ -426,16 +401,12 @@ async def upload_asset(request: web.Request) -> web.Response:
)
if spec.tags and spec.tags[0] == "models":
# tag[1] may be the standalone category ("checkpoints") or the
# slash-joined shape ("checkpoints/flux/...") that
# `get_name_and_tags_from_asset_path` and cloud both emit. Match
# `resolve_destination_from_tags` by extracting the first segment.
category = spec.tags[1].split("/", 1)[0] if len(spec.tags) >= 2 else ""
if (
len(spec.tags) < 2
or category not in folder_paths.folder_names_and_paths
or spec.tags[1] not in folder_paths.folder_names_and_paths
):
delete_temp_file_if_exists(parsed.tmp_path)
category = spec.tags[1] if len(spec.tags) >= 2 else ""
return _build_error_response(
400, "INVALID_BODY", f"unknown models category '{category}'"
)

View File

@ -1,5 +1,4 @@
import json
import uuid
from dataclasses import dataclass
from typing import Any, Literal
@ -54,18 +53,12 @@ class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
name_contains: str | None = None
job_ids: list[str] = Field(default_factory=list, max_length=500)
# Accept either a JSON string (query param) or a dict
metadata_filter: dict[str, Any] | None = None
limit: conint(ge=1, le=500) = 20
offset: conint(ge=0) = 0
# Opaque keyset cursor. When supplied, `offset` is ignored. Cursor pagination
# is supported for sort values `created_at`, `updated_at`, `name`, `size`.
# Supplying `after` together with `sort=last_access_time` returns
# 400 INVALID_CURSOR; that sort only supports offset/limit.
after: str | None = None
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = (
"created_at"
@ -88,40 +81,6 @@ class ListAssetsQuery(BaseModel):
return out
return v
@field_validator("job_ids", mode="before")
@classmethod
def _split_and_validate_job_ids(cls, v):
    """Normalize the raw ``job_ids`` input into a de-duplicated list of UUID strings.

    Accepted shapes:
      * ``None``            -> empty list
      * ``"uuid1,uuid2"``   -> comma-separated string
      * ``["uuid1", "uuid2,uuid3"]`` -> list of strings, each of which may
        itself be comma-separated (covers repeated query parameters)

    Every token is parsed with ``uuid.UUID`` and re-rendered via ``str()``,
    which yields the lowercase hyphenated form. First-occurrence order is
    preserved and duplicates (after canonicalization) are dropped.

    Raises:
        ValueError: if the input is neither a string nor a list, if a list
            entry is not a string, or if any token is not a valid UUID.
    """
    if v is None:
        return []

    if isinstance(v, str):
        chunks = [v]
    elif isinstance(v, list):
        # Reject the first non-string entry before doing any splitting.
        for item in v:
            if not isinstance(item, str):
                raise ValueError(
                    f"job_ids entries must be strings, got {type(item).__name__}"
                )
        chunks = v
    else:
        raise ValueError(
            f"job_ids must be a string or list of strings, got {type(v).__name__}"
        )

    # Split every chunk on commas and discard blank tokens.
    tokens = [
        part.strip()
        for chunk in chunks
        for part in chunk.split(",")
        if part.strip()
    ]

    # A dict keeps insertion order, so it doubles as an ordered set here.
    canonical_ids: dict[str, None] = {}
    for s in tokens:
        try:
            normalized = str(uuid.UUID(s))
        except ValueError as e:
            raise ValueError(f"job_ids must be UUIDs: {s!r}") from e
        canonical_ids.setdefault(normalized, None)
    return list(canonical_ids)
@field_validator("metadata_filter", mode="before")
@classmethod
def _parse_metadata_json(cls, v):

View File

@ -9,10 +9,7 @@ class Asset(BaseModel):
``id`` here is the AssetReference id, not the content-addressed Asset id."""
id: str
name: str = Field(..., deprecated=True)
file_path: str | None = None
display_name: str | None = None
hash: str | None = None
name: str
asset_hash: str | None = None
size: int | None = None
mime_type: str | None = None
@ -43,8 +40,6 @@ class AssetsList(BaseModel):
assets: list[Asset]
total: int
has_more: bool
# Opaque cursor for the next page. Omitted when there are no more results.
next_cursor: str | None = None
class TagUsage(BaseModel):

View File

@ -264,21 +264,11 @@ def list_references_page(
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
metadata_filter: dict | None = None,
job_ids: Sequence[str] | None = None,
sort: str | None = None,
order: str | None = None,
after_cursor_value: object | None = None,
after_cursor_id: str | None = None,
) -> tuple[list[AssetReference], dict[str, list[str]], int]:
"""List references with pagination, filtering, and sorting.
When ``after_cursor_value``/``after_cursor_id`` are supplied the query uses
keyset pagination — ``offset`` is ignored and a WHERE clause selects rows
strictly after the given ``(sort_col, id)`` position in the active sort
direction. The cursor value must already be typed for the column
(datetime for time sorts, int for size, str for name); the caller decodes
the opaque cursor string and resolves to the typed value.
Returns (references, tag_map, total_count).
"""
base = (
@ -294,9 +284,6 @@ def list_references_page(
escaped, esc = escape_sql_like_string(name_contains)
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
if job_ids:
base = base.where(AssetReference.job_id.in_(list(job_ids)))
base = apply_tag_filters(base, include_tags, exclude_tags)
base = apply_metadata_filter(base, metadata_filter)
@ -310,31 +297,9 @@ def list_references_page(
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetReference.created_at)
descending = order == "desc"
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
# Keyset WHERE: (sort_col, id) strictly less-than / greater-than the cursor.
# Equivalent to: sort_col <op> v OR (sort_col = v AND id <op> cursor_id).
if after_cursor_value is not None and after_cursor_id is not None:
if descending:
keyset = sa.or_(
sort_col < after_cursor_value,
sa.and_(sort_col == after_cursor_value, AssetReference.id < after_cursor_id),
)
else:
keyset = sa.or_(
sort_col > after_cursor_value,
sa.and_(sort_col == after_cursor_value, AssetReference.id > after_cursor_id),
)
base = base.where(keyset)
# Secondary ORDER BY id (matching the primary direction) gives the keyset
# comparison a deterministic tiebreaker on duplicate sort_col values.
id_exp = AssetReference.id.desc() if descending else AssetReference.id.asc()
sort_exp = sort_col.desc() if descending else sort_col.asc()
base = base.order_by(sort_exp, id_exp).limit(limit)
if after_cursor_id is None:
base = base.offset(offset)
base = base.order_by(sort_exp).limit(limit).offset(offset)
count_stmt = (
select(sa.func.count())
@ -349,8 +314,6 @@ def list_references_page(
count_stmt = count_stmt.where(
AssetReference.name.ilike(f"%{escaped}%", escape=esc)
)
if job_ids:
count_stmt = count_stmt.where(AssetReference.job_id.in_(list(job_ids)))
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
@ -364,12 +327,7 @@ def list_references_page(
select(AssetReferenceTag.asset_reference_id, Tag.name)
.join(Tag, Tag.name == AssetReferenceTag.tag_name)
.where(AssetReferenceTag.asset_reference_id.in_(id_list))
# Preserve insertion order so the structural first tag (the root
# category like "models") stays in position 0 and the path-derived
# sub-path tag stays in position 1, matching cloud's behavior.
# tag_name is a deterministic tiebreaker when multiple tags share
# an added_at (same-batch insert via set_reference_tags).
.order_by(AssetReferenceTag.added_at.asc(), AssetReferenceTag.tag_name.asc())
.order_by(AssetReferenceTag.tag_name.asc())
)
for ref_id, tag_name in rows.all():
tag_map[ref_id].append(tag_name)
@ -397,8 +355,7 @@ def fetch_reference_asset_and_tags(
build_visible_owner_clause(owner_id),
)
.options(noload(AssetReference.tags))
# See list_references_page for the rationale behind ordering by added_at.
.order_by(AssetReferenceTag.added_at.asc(), Tag.name.asc())
.order_by(Tag.name.asc())
)
rows = session.execute(stmt).all()

View File

@ -1,5 +1,4 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Iterable, Sequence
import sqlalchemy as sa
@ -21,12 +20,7 @@ from app.assets.database.queries.common import (
build_visible_owner_clause,
iter_row_chunks,
)
from app.assets.helpers import (
escape_sql_like_string,
expand_bucket_prefixes,
get_utc_now,
normalize_tags,
)
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
@dataclass(frozen=True)
@ -50,26 +44,6 @@ class SetTagsResult:
total: list[str]
def _next_added_at_base(session: Session, reference_id: str) -> datetime:
    """Pick a starting ``added_at`` that sorts after every tag already on the reference.

    The result is the later of "now" and one microsecond past the newest
    existing ``added_at`` for ``reference_id``. Without this, two write
    batches on the same reference can land on the same microsecond when the
    wall clock is too coarse between back-to-back commits (notably on
    Windows); an ``ORDER BY added_at, tag_name`` read would then fall back to
    the alphabetical tiebreaker and put user-tier tags ahead of the path-tier
    tags they were meant to follow.

    Args:
        session: active SQLAlchemy session used to query existing tag rows.
        reference_id: id of the asset reference whose tags are being written.

    Returns:
        ``get_utc_now()`` when the reference has no tags yet; otherwise
        ``max(newest_existing + 1 microsecond, get_utc_now())``.
    """
    newest_stmt = sa.select(sa.func.max(AssetReferenceTag.added_at)).where(
        AssetReferenceTag.asset_reference_id == reference_id
    )
    newest_existing = session.execute(newest_stmt).scalar()
    current = get_utc_now()
    if newest_existing is not None:
        # NOTE(review): assumes the stored value and get_utc_now() are
        # comparable (both naive or both aware) - confirm against the model.
        return max(newest_existing + timedelta(microseconds=1), current)
    return current
def validate_tags_exist(session: Session, tags: list[str]) -> None:
"""Raise ValueError if any of the given tag names do not exist."""
existing_tag_names = set(
@ -103,13 +77,7 @@ def get_reference_tags(session: Session, reference_id: str) -> list[str]:
session.execute(
select(AssetReferenceTag.tag_name)
.where(AssetReferenceTag.asset_reference_id == reference_id)
# Match the response-path ordering used by
# list_references_page / fetch_reference_asset_and_tags so
# upload responses and subsequent GETs agree on tag order.
.order_by(
AssetReferenceTag.added_at.asc(),
AssetReferenceTag.tag_name.asc(),
)
.order_by(AssetReferenceTag.tag_name.asc())
)
).all()
]
@ -121,7 +89,7 @@ def set_reference_tags(
tags: Sequence[str],
origin: str = "manual",
) -> SetTagsResult:
desired = expand_bucket_prefixes(normalize_tags(tags))
desired = normalize_tags(tags)
current = set(get_reference_tags(session, reference_id))
@ -130,22 +98,15 @@ def set_reference_tags(
if to_add:
ensure_tags_exist(session, to_add, tag_type="user")
# Stagger added_at by microsecond per tag so the retrieval ORDER BY
# added_at preserves input order. Per-tag get_utc_now() calls can
# collide at microsecond resolution on fast machines, dropping the
# query to the tag_name alphabetical tiebreaker — same fix as in
# batch_insert_seed_assets. Read max(existing) so this batch sorts
# strictly after any prior batch on the same reference.
base_ts = _next_added_at_base(session, reference_id)
session.add_all(
[
AssetReferenceTag(
asset_reference_id=reference_id,
tag_name=t,
origin=origin,
added_at=base_ts + timedelta(microseconds=i),
added_at=get_utc_now(),
)
for i, t in enumerate(to_add)
for t in to_add
]
)
session.flush()
@ -175,7 +136,7 @@ def add_tags_to_reference(
if not ref:
raise ValueError(f"AssetReference {reference_id} not found")
norm = expand_bucket_prefixes(normalize_tags(tags))
norm = normalize_tags(tags)
if not norm:
total = get_reference_tags(session, reference_id=reference_id)
return AddTagsResult(added=[], already_present=[], total_tags=total)
@ -185,17 +146,10 @@ def add_tags_to_reference(
current = set(get_reference_tags(session, reference_id))
# Preserve the caller's insertion order rather than alphabetizing —
# the retrieval ORDER BY added_at + microsecond stagger only meaningfully
# preserves insertion order if "the order we insert in" actually matches
# the caller's intent.
want = set(norm)
to_add = [t for t in norm if t not in current]
to_add = sorted(want - current)
if to_add:
# See set_reference_tags for the rationale behind the per-tag stagger
# and the max(existing) seed.
base_ts = _next_added_at_base(session, reference_id)
with session.begin_nested() as nested:
try:
session.add_all(
@ -204,9 +158,9 @@ def add_tags_to_reference(
asset_reference_id=reference_id,
tag_name=t,
origin=origin,
added_at=base_ts + timedelta(microseconds=i),
added_at=get_utc_now(),
)
for i, t in enumerate(to_add)
for t in to_add
]
)
session.flush()

View File

@ -47,50 +47,6 @@ def normalize_tags(tags: list[str] | None) -> list[str]:
return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip()))
def _known_bucket_prefixes() -> set[str]:
"""Lowercased model-category names eligible for standalone-prefix
expansion. Tags whose first slash segment matches one of these get
the bucket inserted as a separate token, so FE filters like
``include_tags=models,checkpoints`` keep matching even when the
asset lives in a nested subfolder (`models/checkpoints/flux/foo`).
Bare user labels with slashes whose first segment is not a registered
bucket (e.g. ``my-org/team-a``) pass through unchanged.
"""
try:
import folder_paths
return {
name.lower()
for name in folder_paths.folder_names_and_paths.keys()
if name != "custom_nodes"
}
except Exception:
return set()
def expand_bucket_prefixes(tags: list[str]) -> list[str]:
    """Add a standalone bucket token after each slash-joined tag that starts with one.

    For every tag containing ``/``, the text before the first slash is looked
    up (case-insensitively) in the registered model categories. On a match
    the prefix is appended immediately after that tag, unless the same token
    is already present in the input or was already emitted. Caller order is
    preserved, and re-running the function on its own output adds nothing.

    Args:
        tags: tag names, in caller order.

    Returns:
        A new list; a plain copy of ``tags`` when it is empty or when no
        bucket names are known.
    """
    if not tags:
        return list(tags)

    buckets = _known_bucket_prefixes()
    if not buckets:
        return list(tags)

    # Tokens already present or already emitted; prevents duplicates.
    emitted = set(tags)
    expanded: list[str] = []
    for tag in tags:
        expanded.append(tag)
        head, slash, _tail = tag.partition("/")
        if not slash:
            continue
        if head.lower() in buckets and head not in emitted:
            expanded.append(head)
            emitted.add(head)
    return expanded
def validate_blake3_hash(s: str) -> str:
"""Validate and normalize a blake3 hash string.

View File

@ -33,7 +33,6 @@ from app.assets.services.file_utils import (
verify_file_unchanged,
)
from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash
from app.assets.services.image_dimensions import extract_image_dimensions
from app.assets.services.metadata_extract import extract_file_metadata
from app.assets.services.path_utils import (
compute_relative_filename,
@ -507,10 +506,6 @@ def enrich_asset(
if extract_metadata and metadata:
system_metadata = metadata.to_user_metadata()
if mime_type and mime_type.startswith("image/"):
dims = extract_image_dimensions(file_path, mime_type=mime_type)
if dims:
system_metadata.update(dims)
set_reference_system_metadata(session, reference_id, system_metadata)
if full_hash:

View File

@ -1,19 +1,8 @@
import contextlib
import mimetypes
import os
from datetime import timezone
from typing import Sequence
from app.assets.services.cursor import (
CursorPayload,
InvalidCursorError,
decode_cursor,
decode_cursor_int,
decode_cursor_time,
encode_cursor,
encode_cursor_from_time,
)
from app.assets.database.models import Asset
from app.assets.database.queries import (
@ -253,55 +242,17 @@ def get_asset_by_hash(asset_hash: str) -> AssetData | None:
return extract_asset_data(asset)
# Sort fields that support cursor pagination. `last_access_time` is not
# in this list — it falls back to offset/limit.
_CURSOR_SORT_FIELDS = ("created_at", "updated_at", "name", "size")
def list_assets_page(
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
job_ids: Sequence[str] | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
after: str | None = None,
) -> ListAssetsResult:
"""List assets with optional cursor pagination.
When ``after`` is supplied it overrides ``offset``. The cursor's sort field
must match ``sort`` and be in the cursor-supported allowlist; mismatches
raise InvalidCursorError so the handler can map to 400 INVALID_CURSOR.
"""
cursor_value: object | None = None
cursor_id: str | None = None
# Mint next_cursor on every page where the sort is cursor-supported, not
# only when the request itself arrived with a cursor. Otherwise a first
# request (no `after`) returns next_cursor=None and the client can never
# enter cursor mode.
mint_cursor = sort in _CURSOR_SORT_FIELDS
if after is not None:
if sort not in _CURSOR_SORT_FIELDS:
raise InvalidCursorError(
f"cursor pagination is not supported for sort={sort!r}"
)
payload = decode_cursor(after, _CURSOR_SORT_FIELDS, expected_order=order)
if payload.sort_field != sort:
raise InvalidCursorError(
f"cursor sort field {payload.sort_field!r} does not match request sort {sort!r}"
)
cursor_value, cursor_id = _resolve_cursor_value(payload), payload.id
# Over-fetch by one row so we can distinguish "exactly `limit` rows total
# remaining" from "more rows past this page" without a second query. Drop
# the sentinel before returning.
fetch_limit = limit + 1 if mint_cursor else limit
with create_session() as session:
refs, tag_map, total = list_references_page(
session,
@ -310,23 +261,12 @@ def list_assets_page(
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
job_ids=job_ids,
limit=fetch_limit,
limit=limit,
offset=offset,
sort=sort,
order=order,
after_cursor_value=cursor_value,
after_cursor_id=cursor_id,
)
next_cursor: str | None = None
if mint_cursor and len(refs) > limit:
# There's at least one more row past this page — mint a cursor from
# the last row of the page (i.e. index `limit - 1`, since we
# over-fetched), and drop the sentinel.
next_cursor = _encode_next_cursor(refs[limit - 1], sort, order)
refs = refs[:limit]
items: list[AssetSummaryData] = []
for ref in refs:
items.append(
@ -337,39 +277,7 @@ def list_assets_page(
)
)
return ListAssetsResult(items=items, total=total, next_cursor=next_cursor)
def _resolve_cursor_value(payload: CursorPayload) -> object:
    """Convert a decoded cursor payload into a value typed for its sort column.

    * ``size``                     -> result of ``decode_cursor_int``
    * ``created_at``/``updated_at`` -> result of ``decode_cursor_time`` with
      ``tzinfo`` removed
    * anything else (``name``)     -> the raw ``payload.value`` string
    """
    field = payload.sort_field
    if field == "size":
        return decode_cursor_int(payload)
    if field in ("created_at", "updated_at"):
        # The DB stores naive UTC; dropping tzinfo lets the comparison bind
        # against a `TIMESTAMP WITHOUT TIME ZONE` column with no offset shift.
        decoded = decode_cursor_time(payload)
        return decoded.replace(tzinfo=None)
    return payload.value
def _encode_next_cursor(ref, sort: str, order: str) -> str | None:
"""Mint a cursor pointing at *ref* for the given sort dimension.
Returns None when the boundary row carries a NULL sort value (e.g. an asset
record whose size_bytes hasn't been backfilled). Continuing pagination
across a NULL boundary is undefined under keyset ordering — better to
truncate cleanly here than to mint a cursor that mis-positions.
"""
if sort == "name":
return encode_cursor("name", ref.name, ref.id, order=order)
if sort == "size":
if ref.asset is None or ref.asset.size_bytes is None:
return None
return encode_cursor("size", str(ref.asset.size_bytes), ref.id, order=order)
# created_at / updated_at — DB datetimes are naive UTC; attach tz before encoding.
value = ref.created_at if sort == "created_at" else ref.updated_at
if value is None:
return None
return encode_cursor_from_time(sort, value.replace(tzinfo=timezone.utc), ref.id, order=order)
return ListAssetsResult(items=items, total=total)
def resolve_hash_to_path(

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import os
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta
from datetime import datetime
from typing import TYPE_CHECKING, Any, TypedDict
from sqlalchemy.orm import Session
@ -13,14 +13,13 @@ from app.assets.database.queries import (
bulk_insert_references_ignore_conflicts,
bulk_insert_tags_and_meta,
delete_assets_by_ids,
ensure_tags_exist,
get_existing_asset_ids,
get_reference_ids_by_ids,
get_references_by_paths_and_asset_ids,
get_unreferenced_unhashed_asset_ids,
restore_references_by_paths,
)
from app.assets.helpers import expand_bucket_prefixes, get_utc_now
from app.assets.helpers import get_utc_now
if TYPE_CHECKING:
from app.assets.services.metadata_extract import ExtractedMetadata
@ -234,20 +233,13 @@ def batch_insert_seed_assets(
if ref_id not in inserted_ref_ids:
continue
# Stagger added_at by microsecond per tag within a reference so
# the retrieval ORDER BY added_at preserves the input list order
# (the path-derived root category stays at position 0). Without
# this, every tag in a bulk-insert batch shares current_time and
# the tag_name tiebreaker sorts them alphabetically — putting the
# subpath tag ahead of "models" since "c"/"d"/"l" < "m".
ref_tags = expand_bucket_prefixes(ref_data["tags"])
for tag_idx, tag in enumerate(ref_tags):
for tag in ref_data["tags"]:
tag_rows.append(
{
"asset_reference_id": ref_id,
"tag_name": tag,
"origin": "automatic",
"added_at": current_time + timedelta(microseconds=tag_idx),
"added_at": current_time,
}
)
@ -269,16 +261,6 @@ def batch_insert_seed_assets(
}
)
if tag_rows:
# Bucket-prefix expansion may have introduced tags the caller did
# not register via the upstream tag_pool (e.g. `checkpoints` for a
# nested `checkpoints/flux/foo` path). Pre-register the full set so
# the AssetReferenceTag.tag_name FK is satisfied; the underlying
# insert is ON CONFLICT DO NOTHING so re-registration is idempotent.
ensure_tags_exist(
session, {row["tag_name"] for row in tag_rows}, tag_type="user"
)
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=metadata_rows)
return BulkInsertResult(

View File

@ -1,225 +0,0 @@
"""Opaque keyset-pagination cursor for /api/assets.
Payload JSON uses short keys to keep the encoded length small:
{"s": <sort_field>, "v": <value>, "id": <id>, "o": <order>}
The `o` key binds the cursor to the sort direction it was minted under,
so replaying a `desc` cursor against an `asc` request fails with
``INVALID_CURSOR`` rather than silently walking the wrong direction.
`o` is mandatory on every payload — a cursor without it is rejected as
malformed.
Encoding is base64url with no padding. JSON serialization escapes `<`,
`>`, `&`, U+2028, and U+2029 in encoded string values so asset names
containing those characters produce a stable, byte-identical wire form
across any compatible implementation of the same payload format.
Time values are serialized as Unix microseconds (UTC) — microsecond
precision is sufficient to round-trip the timestamps stored by the
database without rounding rows in the same millisecond bucket.
"""
from __future__ import annotations
import base64
import json
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Iterable, Optional
class InvalidCursorError(ValueError):
"""Raised when a pagination cursor cannot be encoded or decoded safely.
Covers malformed, oversized, and unsupported-sort-field cursors, as well as
a cursor whose sort order does not match the order of the request.
Map to a 400 response with code ``INVALID_CURSOR`` at the handler.
"""
# Wire-format length caps. Cursors are user-controlled, so caps protect the
# decode path from oversized allocations and downstream SQL predicates from
# unbounded strings.
#
# MAX_CURSOR_VALUE_LENGTH is 512 to fit the `AssetReference.name` column max
# (`String(512)`) — otherwise a long-named asset would mint a cursor the same
# server then refuses on the next request.
MAX_ENCODED_CURSOR_LENGTH = 1024
MAX_CURSOR_VALUE_LENGTH = 512
MAX_CURSOR_ID_LENGTH = 128
@dataclass(frozen=True)
class CursorPayload:
# Immutable contents of a decoded cursor. Every field is kept as a string;
# typed access goes through decode_cursor_time / decode_cursor_int.
sort_field: str  # sort column the cursor was minted for (wire key "s")
value: str  # string form of the boundary row's sort value (wire key "v")
id: str  # id of the boundary row (wire key "id")
order: str  # sort direction, "asc" or "desc" (wire key "o")
_VALID_ORDERS = ("asc", "desc")
def encode_cursor(sort_field: str, value: str, id: str, order: str = "desc") -> str:
    """Encode a cursor payload as a base64url (no-padding) string.

    `order` binds the cursor to the sort direction it was minted under so a
    later request with a flipped `order` query parameter is rejected with
    ``INVALID_CURSOR`` rather than silently walking the wrong direction.

    Args:
        sort_field: Sort column the cursor belongs to.
        value: String form of the boundary row's sort value.
        id: Non-empty id of the boundary row.
        order: ``"asc"`` or ``"desc"``.

    Returns:
        The opaque cursor string.

    Raises:
        InvalidCursorError: if any argument would produce a cursor that
            ``decode_cursor`` rejects (bad order, non-string field, empty or
            oversized id, oversized value, oversized encoded form).
    """
    if order not in _VALID_ORDERS:
        raise InvalidCursorError(f"order must be one of {_VALID_ORDERS}, got {order!r}")
    # Symmetric input validation: the encoder must reject anything the
    # decoder rejects, or the same server will mint cursors it then 400s on
    # the next request. The decoder requires s/v/id to be strings, so enforce
    # that here instead of serializing e.g. `null` or failing in len().
    if (
        not isinstance(sort_field, str)
        or not isinstance(value, str)
        or not isinstance(id, str)
    ):
        raise InvalidCursorError("sort_field, value and id must be strings")
    if not id:
        raise InvalidCursorError("id must be non-empty")
    if len(id) > MAX_CURSOR_ID_LENGTH:
        raise InvalidCursorError("id exceeds maximum length")
    if len(value) > MAX_CURSOR_VALUE_LENGTH:
        raise InvalidCursorError("value exceeds maximum length")

    payload = {"s": sort_field, "v": value, "id": id, "o": order}
    raw = json.dumps(payload, separators=(",", ":"), ensure_ascii=False)
    # Escape HTML-significant characters and the JS line/paragraph separators
    # (U+2028 / U+2029) so an asset name carrying any of them encodes to
    # identical bytes across runtimes. None of these characters appear in
    # JSON structural syntax, so a global substitution on the serialized
    # output can only touch encoded values. The separators are written as
    # code points so the source survives tooling that normalizes invisible
    # characters.
    raw = raw.translate(
        {
            ord("<"): "\\u003c",
            ord(">"): "\\u003e",
            ord("&"): "\\u0026",
            0x2028: "\\u2028",
            0x2029: "\\u2029",
        }
    )
    encoded = base64.urlsafe_b64encode(raw.encode("utf-8")).rstrip(b"=").decode("ascii")
    # Final wire-size guard: the per-field caps above are char-counted, but
    # the wire cap applies to the base64url of the UTF-8-encoded,
    # escape-expanded payload. A value full of multibyte or HTML-significant
    # characters (e.g. 512 accented letters or 512 "<") inflates well past
    # MAX_ENCODED_CURSOR_LENGTH even though it passes the char-count check.
    # Refuse to mint a cursor the decoder would reject on the next request.
    if len(encoded) > MAX_ENCODED_CURSOR_LENGTH:
        raise InvalidCursorError("encoded cursor exceeds maximum length")
    return encoded
def encode_cursor_from_time(sort_field: str, t: datetime, id: str, order: str = "desc") -> str:
    """Mint a cursor for a timestamp sort, stored as Unix microseconds.

    The datetime must be timezone-aware; it is converted to UTC before being
    serialized. A naive datetime raises ``ValueError`` so a caller cannot
    accidentally encode the local wall-clock reading of a UTC-stored value.
    """
    if t.tzinfo is None:
        raise ValueError("encode_cursor_from_time requires an aware datetime")
    as_utc = t.astimezone(timezone.utc)
    value = str(_datetime_to_unix_micros(as_utc))
    return encode_cursor(sort_field, value, id, order=order)
def decode_cursor(
    cursor: str,
    allowed_sort_fields: Iterable[str],
    expected_order: str | None = None,
) -> CursorPayload:
    """Parse and validate an opaque cursor.

    ``allowed_sort_fields`` is the endpoint's accepted sort-field list. A
    cursor carrying a field outside this set is rejected so a cursor minted
    for one column can't be replayed against another (e.g. a ``created_at``
    timestamp string compared against a ``name`` column). Passing no allowed
    fields rejects every cursor.

    ``expected_order`` (``"asc"``/``"desc"``), when supplied, must match the
    payload's ``o`` field. ``o`` is required on every payload; a cursor
    missing it is rejected as malformed.

    Returns:
        The validated payload with all fields still in string form.

    Raises:
        InvalidCursorError: on any oversized, undecodable, malformed, or
            mismatching cursor.
    """
    if len(cursor) > MAX_ENCODED_CURSOR_LENGTH:
        raise InvalidCursorError("cursor exceeds maximum length")
    try:
        # urlsafe_b64decode requires correct padding; we strip on encode, so
        # restore the trailing '=' pad here.
        padding = "=" * (-len(cursor) % 4)
        raw = base64.urlsafe_b64decode(cursor + padding)
    except ValueError as e:
        # binascii.Error is a ValueError subclass, so this also covers bad
        # base64 without reaching through base64's private binascii import.
        raise InvalidCursorError(f"encoding: {e}") from e
    try:
        decoded = json.loads(raw)
    except (json.JSONDecodeError, UnicodeDecodeError, RecursionError) as e:
        # RecursionError: the cursor is user-controlled and a payload made of
        # several hundred nested brackets can exhaust the interpreter's
        # recursion budget. Treat it as a bad cursor, not a server error.
        raise InvalidCursorError(f"payload: {e}") from e
    if not isinstance(decoded, dict):
        raise InvalidCursorError("payload: expected object")

    sort_field = decoded.get("s")
    value = decoded.get("v")
    id = decoded.get("id")
    order = decoded.get("o")
    if not isinstance(sort_field, str) or not isinstance(value, str) or not isinstance(id, str):
        raise InvalidCursorError("payload: missing or non-string s/v/id")
    if id == "":
        raise InvalidCursorError("missing id")
    if len(id) > MAX_CURSOR_ID_LENGTH:
        raise InvalidCursorError("id exceeds maximum length")
    if len(value) > MAX_CURSOR_VALUE_LENGTH:
        raise InvalidCursorError("value exceeds maximum length")
    if sort_field not in allowed_sort_fields:
        raise InvalidCursorError(f"unsupported sort field {sort_field!r}")
    if not isinstance(order, str):
        raise InvalidCursorError("missing or non-string o")
    if order not in _VALID_ORDERS:
        raise InvalidCursorError(f"unsupported order {order!r}")
    if expected_order is not None and order != expected_order:
        raise InvalidCursorError(
            f"cursor order {order!r} does not match request order {expected_order!r}"
        )
    return CursorPayload(sort_field=sort_field, value=value, id=id, order=order)
def decode_cursor_time(payload: Optional[CursorPayload]) -> datetime:
    """Interpret a cursor value as Unix microseconds and return it as a UTC datetime.

    Raises:
        InvalidCursorError: if *payload* is ``None``, the value is not an
            integer string, or the instant cannot be represented.
    """
    if payload is None:
        raise InvalidCursorError("nil cursor payload")

    try:
        unix_micros = int(payload.value)
    except ValueError as parse_error:
        raise InvalidCursorError(
            f"value is not a valid timestamp: {parse_error}"
        ) from parse_error

    try:
        moment = _unix_micros_to_datetime(unix_micros)
    except (OverflowError, OSError, ValueError) as range_error:
        # A crafted, out-of-range microsecond count makes the datetime
        # conversion blow up. Surface it as a 400-class error, not a 500.
        raise InvalidCursorError(
            f"value is out of representable range: {range_error}"
        ) from range_error
    return moment
def decode_cursor_int(payload: Optional[CursorPayload]) -> int:
    """Parse a cursor value as a base-10 integer within the signed 64-bit range.

    The cursor is user-controlled and this module's length caps exist to keep
    unbounded input out of downstream SQL predicates. A value outside the
    signed 64-bit range is therefore rejected here, mirroring the range guard
    in ``decode_cursor_time``.
    NOTE(review): assumes the integer sort columns are at most 64-bit; confirm.

    Raises:
        InvalidCursorError: if *payload* is ``None``, the value is not an
            integer string, or the integer is outside the signed 64-bit range.
    """
    if payload is None:
        raise InvalidCursorError("nil cursor payload")
    try:
        number = int(payload.value)
    except ValueError as e:
        raise InvalidCursorError(f"value is not a valid integer: {e}") from e
    if not -(2**63) <= number <= 2**63 - 1:
        raise InvalidCursorError("value is out of representable range: exceeds 64-bit integer")
    return number
_EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)
def _datetime_to_unix_micros(t: datetime) -> int:
"""Convert an aware UTC datetime to Unix microseconds (integer math)."""
delta = t - _EPOCH
return (delta.days * 86_400 + delta.seconds) * 1_000_000 + delta.microseconds
def _unix_micros_to_datetime(micros: int) -> datetime:
"""Convert Unix microseconds to a UTC datetime, preserving precision."""
seconds, micro_remainder = divmod(micros, 1_000_000)
return datetime.fromtimestamp(seconds, tz=timezone.utc).replace(microsecond=micro_remainder)

View File

@ -1,63 +0,0 @@
"""Image dimension extraction for asset ingest.
Reads only the image header via Pillow to capture width/height cheaply,
without a full pixel decode. Returns a metadata dict suitable for merging
into ``AssetReference.system_metadata``.
"""
from __future__ import annotations
import logging
from typing import Any
logger = logging.getLogger(__name__)
def extract_image_dimensions(
    file_path: str, mime_type: str | None = None
) -> dict[str, Any] | None:
    """Extract image dimensions for the file at ``file_path``.

    Extraction is best-effort: every failure mode returns ``None`` rather
    than raising, so a bad image never aborts the caller.

    Args:
        file_path: Absolute path to a file on disk.
        mime_type: Optional MIME type hint. When provided and not prefixed
            with ``image/``, extraction is skipped without touching the file.

    Returns:
        ``{"kind": "image", "width": W, "height": H}`` when the file is a
        recognizable image with positive dimensions, otherwise ``None``
        (non-image MIME type, Pillow not installed, unreadable or
        unidentifiable file, oversized image, or non-positive dimensions).

    The dict shape is intended to be merged into ``system_metadata`` so the
    asset response surfaces ``metadata.kind`` plus dimension fields for image
    assets. Forward-compatible: future media kinds (e.g. ``"video"`` with
    duration/fps) can extend this shape without schema changes.
    """
    if mime_type is not None and not mime_type.startswith("image/"):
        return None

    # Pillow is imported lazily so its absence only disables this feature.
    try:
        from PIL import Image, UnidentifiedImageError
    except ImportError:
        logger.debug(
            "Pillow not available; skipping image dimension extraction for %s",
            file_path,
        )
        return None

    try:
        with Image.open(file_path) as img:
            width, height = img.size
    except (
        OSError,
        UnidentifiedImageError,
        ValueError,
        # Raised by Image.open for images far above Pillow's pixel limit. It
        # derives from Exception directly, so it must be listed explicitly or
        # it escapes this best-effort helper.
        Image.DecompressionBombError,
    ) as exc:
        logger.debug(
            "Failed to read image dimensions from %s: %s", file_path, exc
        )
        return None

    if (
        not isinstance(width, int)
        or not isinstance(height, int)
        or width <= 0
        or height <= 0
    ):
        return None

    return {"kind": "image", "width": width, "height": height}

View File

@ -17,11 +17,9 @@ from app.assets.database.queries import (
get_reference_by_file_path,
get_reference_tags,
get_or_create_reference,
list_references_by_asset_id,
reference_exists,
remove_missing_tag_for_asset_id,
set_reference_metadata,
set_reference_system_metadata,
set_reference_tags,
update_asset_hash_and_mime,
upsert_asset,
@ -31,7 +29,6 @@ from app.assets.database.queries import (
from app.assets.helpers import get_utc_now, normalize_tags
from app.assets.services.bulk_ingest import batch_insert_seed_assets
from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.image_dimensions import extract_image_dimensions
from app.assets.services.path_utils import (
compute_relative_filename,
get_name_and_tags_from_asset_path,
@ -121,14 +118,6 @@ def _ingest_file_from_path(
user_metadata=user_metadata,
)
_maybe_store_image_dimensions(
session,
reference_id=reference_id,
file_path=locator,
mime_type=mime_type,
current_system_metadata=ref.system_metadata,
)
try:
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
@ -299,13 +288,6 @@ def _register_existing_asset(
user_metadata=new_meta,
)
_backfill_image_dimensions_from_siblings(
session,
asset_id=asset.id,
new_reference_id=ref.id,
current_system_metadata=ref.system_metadata,
)
if tags is not None:
set_reference_tags(
session,
@ -352,87 +334,6 @@ def _update_metadata_with_filename(
)
_IMAGE_DIMENSION_KEYS = ("kind", "width", "height")
def _maybe_store_image_dimensions(
session: Session,
reference_id: str,
file_path: str,
mime_type: str | None,
current_system_metadata: dict | None,
) -> None:
"""Populate ``kind``/``width``/``height`` on system_metadata for image refs.
Non-image MIME types are a no-op. Pre-existing keys (e.g. enricher-written
safetensors metadata, download provenance) are preserved by merge.
"""
if not mime_type or not mime_type.startswith("image/"):
return
dims = extract_image_dimensions(file_path, mime_type=mime_type)
if not dims:
return
current = current_system_metadata or {}
merged = dict(current)
merged.update(dims)
if merged != current:
set_reference_system_metadata(
session,
reference_id=reference_id,
system_metadata=merged,
)
def _backfill_image_dimensions_from_siblings(
session: Session,
asset_id: str,
new_reference_id: str,
current_system_metadata: dict | None,
) -> None:
"""Copy image dimension keys from any sibling reference of the same asset.
The from-hash path doesn't read the file bytes, so dimensions can't be
extracted there directly. When another reference of the same asset already
carries image dimensions, copy them onto the new reference so consumers
see consistent metadata regardless of how the asset was registered.
Best-effort: missing siblings, non-image siblings, or absent dimension
keys leave the target reference unchanged.
"""
current = current_system_metadata or {}
if current.get("kind") == "image" and "width" in current and "height" in current:
return
for sibling in list_references_by_asset_id(session, asset_id):
if sibling.id == new_reference_id:
continue
meta = sibling.system_metadata or {}
if meta.get("kind") != "image":
continue
width = meta.get("width")
height = meta.get("height")
if (
type(width) is not int
or type(height) is not int
or width <= 0
or height <= 0
):
continue
merged = dict(current)
merged["kind"] = "image"
merged["width"] = width
merged["height"] = height
if merged != current:
set_reference_system_metadata(
session,
reference_id=new_reference_id,
system_metadata=merged,
)
return
def _sanitize_filename(name: str | None, fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
return n if n else fallback

View File

@ -3,12 +3,11 @@ from pathlib import Path
from typing import Literal
import folder_paths
from app.assets.helpers import normalize_tags
_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"})
RootCategory = Literal["input", "output", "temp", "models"]
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build list of (folder_name, base_paths[]) for all model locations.
@ -28,51 +27,27 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs).
Accepts both the legacy one-tag-per-directory shape
(``["models", "diffusers", "Kolors", "text_encoder"]``) and the
slash-joined shape emitted by :func:`get_name_and_tags_from_asset_path`
(``["models", "diffusers/Kolors/text_encoder"]``). Hybrid shapes that
mix the two within a single call (e.g.
``["models", "diffusers", "Kolors/text_encoder"]``) are also
accepted: each entry after ``tags[0]`` is split on ``/`` and
concatenated, so the two shapes — and any mix of them — resolve to
the same destination. The same safety checks are applied to each
component after expansion.
"""
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
if not tags:
raise ValueError("tags must not be empty")
root = tags[0].lower()
# Expand any slash-joined entries into individual path components so
# the rest of the function can treat both tag shapes uniformly. Each
# component is also stripped, so " a / b " behaves like ["a", "b"].
expanded: list[str] = []
for t in tags[1:]:
for part in str(t).split("/"):
part = part.strip()
if part:
expanded.append(part)
if root == "models":
if not expanded:
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
category = expanded[0]
try:
bases = folder_paths.folder_names_and_paths[category][0]
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{category}'")
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{category}'")
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = expanded[1:]
raw_subdirs = tags[2:]
elif root == "input":
base_dir = os.path.abspath(folder_paths.get_input_directory())
raw_subdirs = expanded
raw_subdirs = tags[1:]
elif root == "output":
base_dir = os.path.abspath(folder_paths.get_output_directory())
raw_subdirs = expanded
raw_subdirs = tags[1:]
else:
raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'")
_sep_chars = frozenset(("/", "\\", os.sep))
@ -90,109 +65,35 @@ def validate_path_within_base(candidate: str, base: str) -> None:
raise ValueError("destination escapes base directory")
def compute_paths_for_response(
file_path: str,
) -> tuple[str, str | None] | None:
"""Compute (file_path, display_name) for an Asset response.
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
`file_path` is a logical locator under the asset namespace: `<root>/<rel>`
for input/output/temp assets and `<root>/<bucket>/<rel>` for model assets.
`display_name` is the path below that root or model bucket, suitable for UI
labels. Returns None when the absolute path is not under a known asset root.
For non-model paths, returns None.
"""
try:
root, bucket, rel = get_asset_root_bucket_and_filepath(file_path)
root_category, rel_path = get_asset_category_and_relative_path(file_path)
except ValueError:
return None
display_name = rel or None
if bucket is None:
response_file_path = f"{root}/{rel}" if rel else root
else:
response_file_path = f"{root}/{bucket}/{rel}" if rel else f"{root}/{bucket}"
return response_file_path, display_name
p = Path(rel_path)
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
if not parts:
return None
def compute_display_name(file_path: str) -> str | None:
"""Return the asset's `display_name`, or None for unknown paths."""
result = compute_paths_for_response(file_path)
return result[1] if result else None
def compute_file_path(file_path: str) -> str | None:
"""Return the asset's logical `file_path`, or None for unknown paths."""
result = compute_paths_for_response(file_path)
return result[0] if result else None
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the path relative to the asset root or model category, using forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
/.../input/sub/image.png -> "sub/image.png"
For unknown paths, returns None.
"""
return compute_display_name(file_path)
def get_asset_root_bucket_and_filepath(
    file_path: str,
) -> tuple[RootCategory, str | None, str]:
    """Split an absolute path into ``(root, bucket, path-under-bucket)``.

    The input, output and temp directories are tried first, in that order,
    and yield ``bucket=None``. Otherwise every configured model base is
    considered and the one with the longest absolute path wins; its folder
    name becomes the bucket. The relative part always uses ``/`` separators
    and is empty when the path is exactly the matched base.

    Raises:
        ValueError: path does not belong to any known root.
    """
    absolute = os.path.abspath(file_path)

    def _relative_to(parent: str) -> str:
        # Anchor the relative path at the filesystem root and take relpath
        # back from it, which strips any leading ".." components.
        anchored = os.path.join(os.sep, os.path.relpath(absolute, parent))
        rel = os.path.relpath(anchored, os.sep)
        if rel == ".":
            return ""
        return rel.replace(os.sep, "/")

    plain_roots = (
        ("input", folder_paths.get_input_directory),
        ("output", folder_paths.get_output_directory),
        ("temp", folder_paths.get_temp_directory),
    )
    for root_tag, getter in plain_roots:
        base = os.path.abspath(getter())
        if Path(absolute).is_relative_to(base):
            return root_tag, None, _relative_to(base)

    # models: prefer the deepest matching base to avoid ambiguity. A later
    # base replaces the current winner only when strictly longer.
    winner: tuple[int, str, str] | None = None
    for bucket, bases in get_comfy_models_folders():
        for candidate_base in bases:
            base_abs = os.path.abspath(candidate_base)
            if not Path(absolute).is_relative_to(base_abs):
                continue
            if winner is None or len(base_abs) > winner[0]:
                winner = (len(base_abs), bucket, _relative_to(base_abs))

    if winner is None:
        raise ValueError(
            f"Path is not within input, output, temp, or configured model bases: {file_path}"
        )
    return "models", winner[1], winner[2]
if root_category == "models":
# parts[0] is the category ("checkpoints", "vae", etc) drop it
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def get_asset_category_and_relative_path(
file_path: str,
) -> tuple[RootCategory, str]:
) -> tuple[Literal["input", "output", "temp", "models"], str]:
"""Determine which root category a file path belongs to.
Categories:
@ -259,21 +160,7 @@ def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return (name, tags) derived from a filesystem path.
- name: base filename with extension
- tags: [root_category] for paths with no parent subdirectories,
[root_category, slash_joined_subpath] otherwise. The parent subpath
(everything between the root category and the filename) is collapsed
into a single tag rather than emitted as one tag per directory, so
consumers can use ``tags[1]`` as a stable category identifier that
survives nested directory layouts (e.g. diffusers components).
The subpath is lowercased to match the canonicalization applied by
:func:`ensure_tags_exist`; without that, the
``asset_reference_tags.tag_name`` FK to the lowercased ``tags.name``
would fail for any path containing uppercase letters. The root
category is lowercase by construction in
:func:`get_asset_category_and_relative_path`, so no separate cast
is applied here. Consumers that need to look up providers keyed on
original-case paths should normalize their lookup key to lowercase.
- tags: [root_category] + parent folder names in order
Raises:
ValueError: path does not belong to any known root.
@ -283,7 +170,4 @@ def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
parent_parts = [
part for part in p.parent.parts if part not in (".", "..", p.anchor)
]
tags = [root_category]
if parent_parts:
tags.append("/".join(parent_parts).lower())
return p.name, list(dict.fromkeys(t.strip() for t in tags if t.strip()))
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))

View File

@ -71,7 +71,6 @@ class AssetSummaryData:
class ListAssetsResult:
items: list[AssetSummaryData]
total: int
next_cursor: str | None = None
@dataclass(frozen=True)

View File

@ -62,6 +62,8 @@ def get_comfy_package_versions():
def check_comfy_packages_versions():
"""Warn for every comfy* package whose installed version is below requirements.txt."""
from packaging.version import InvalidVersion, parse as parse_pep440
outdated_packages = []
for pkg in get_comfy_package_versions():
installed_str = pkg["installed"]
required_str = pkg["required"]
@ -73,19 +75,26 @@ def check_comfy_packages_versions():
logging.error(f"Failed to check {pkg['name']} version: {e}")
continue
if outdated:
app.logger.log_startup_warning(
f"""
outdated_packages.append((pkg["name"], installed_str, required_str))
else:
logging.info("{} version: {}".format(pkg["name"], installed_str))
if outdated_packages:
package_warnings = "\n".join(
f"Installed {name} version {installed} is lower than the recommended version {required}."
for name, installed, required in outdated_packages
)
app.logger.log_startup_warning(
f"""
________________________________________________________________________
WARNING WARNING WARNING WARNING WARNING
Installed {pkg["name"]} version {installed_str} is lower than the recommended version {required_str}.
{package_warnings}
{get_missing_requirements_message()}
________________________________________________________________________
""".strip()
)
else:
logging.info("{} version: {}".format(pkg["name"], installed_str))
)
REQUEST_TIMEOUT = 10 # seconds

View File

@ -1,73 +0,0 @@
"""Enrich executed-node output entries with asset id."""
import logging
import os
def enrich_output_with_assets(output_ui: dict) -> dict:
    """Add an asset ``id`` to file-type output entries when --enable-assets is set.

    With assets disabled, *output_ui* is returned as-is. Otherwise a new dict
    is built: non-list values are passed through, and each list is rebuilt
    entry by entry. An entry is enriched only if it is a dict with
    ``filename`` and ``type`` keys that resolves to an existing file inside
    the directory for its type. Any failure is logged and leaves that entry
    unchanged, so enrichment never blocks the WS message from sending.
    """
    from comfy.cli_args import args

    if not args.enable_assets:
        return output_ui

    import folder_paths
    from app.assets.services.ingest import register_file_in_place, DependencyMissingError
    from app.assets.database.queries.asset_reference import get_reference_by_file_path
    from app.database.db import create_session

    def _with_asset_id(item):
        """Return *item* with an ``id`` key added, or *item* itself if not enrichable."""
        if not isinstance(item, dict) or "filename" not in item or "type" not in item:
            return item
        try:
            base = folder_paths.get_directory_by_type(item["type"])
            if base is None:
                return item
            base_abs = os.path.abspath(base)
            abs_path = os.path.abspath(
                os.path.join(base_abs, item.get("subfolder") or "", item["filename"])
            )
            # commonpath itself raises ValueError for incomparable paths;
            # treat that the same as a path that leaves the base directory.
            try:
                escapes_base = os.path.commonpath([base_abs, abs_path]) != base_abs
            except ValueError:
                escapes_base = True
            if escapes_base:
                logging.warning("Asset enrichment skipped (path escapes base): %s", item.get("filename"))
                return item
            if not os.path.isfile(abs_path):
                return item

            # Try DB lookup first (cached node re-send); fall back to registering inline.
            asset_id = None
            with create_session() as session:
                known_ref = get_reference_by_file_path(session, abs_path)
                if known_ref is not None:
                    asset_id = known_ref.id
            if asset_id is None:
                registered = register_file_in_place(
                    abs_path=abs_path,
                    name=item["filename"],
                    tags=[item["type"]],
                )
                asset_id = registered.ref.id

            # Copy before adding the id so the caller's entry is not mutated.
            enriched_item = dict(item)
            enriched_item["id"] = asset_id
            return enriched_item
        except DependencyMissingError:
            logging.warning("Asset enrichment skipped (blake3 not available): %s", item.get("filename"))
        except Exception:
            logging.warning("Failed to enrich output entry with asset id: %s", item.get("filename"), exc_info=True)
        return item

    enriched = {}
    for key, entries in output_ui.items():
        if isinstance(entries, list):
            enriched[key] = [_with_asset_id(item) for item in entries]
        else:
            enriched[key] = entries
    return enriched

View File

@ -8,6 +8,82 @@ from comfy_api.latest import _io
MISSING = object()
class NotNode(io.ComfyNode):
    # Logic node: takes one input of any type and outputs its boolean negation.
    @classmethod
    def define_schema(cls):
        value_input = io.AnyType.Input("value")
        result_output = io.Boolean.Output()
        return io.Schema(
            node_id="ComfyNotNode",
            display_name="Not",
            category="utils/logic",
            description="Logical NOT operation. Returns true if the value is falsy. Uses Python's rules for truthiness.",
            search_aliases=["invert", "toggle", "negate", "flip boolean"],
            inputs=[value_input],
            outputs=[result_output],
        )

    @classmethod
    def execute(cls, value) -> io.NodeOutput:
        # Python truthiness decides the result: falsy input gives True.
        negated = not value
        return io.NodeOutput(negated)
class AndNode(io.ComfyNode):
    # Logic node: outputs whether every value of a growable input group is truthy.
    @classmethod
    def define_schema(cls):
        # Growable group of any-typed "value" inputs; the template sets min=1.
        growable_values = io.Autogrow.Input(
            "values",
            template=io.Autogrow.TemplatePrefix(
                input=io.AnyType.Input("value"),
                prefix="value",
                min=1,
            ),
        )
        return io.Schema(
            node_id="ComfyAndNode",
            display_name="And",
            category="utils/logic",
            description="Logical AND operation. Returns true if all of the values are truthy. Uses Python's rules for truthiness.",
            search_aliases=["all", "every"],
            inputs=[growable_values],
            outputs=[io.Boolean.Output()],
        )

    @classmethod
    def execute(cls, values: io.Autogrow.Type) -> io.NodeOutput:
        # `values` is a mapping; only its values matter for the result.
        every_truthy = all(values.values())
        return io.NodeOutput(every_truthy)
class OrNode(io.ComfyNode):
    # Logic node: outputs whether at least one value of a growable input group is truthy.
    @classmethod
    def define_schema(cls):
        # Growable group of any-typed "value" inputs; the template sets min=1.
        growable_values = io.Autogrow.Input(
            "values",
            template=io.Autogrow.TemplatePrefix(
                input=io.AnyType.Input("value"),
                prefix="value",
                min=1,
            ),
        )
        return io.Schema(
            node_id="ComfyOrNode",
            display_name="Or",
            category="utils/logic",
            description="Logical OR operation. Returns true if any of the values are truthy. Uses Python's rules for truthiness.",
            search_aliases=["any", "some"],
            inputs=[growable_values],
            outputs=[io.Boolean.Output()],
        )

    @classmethod
    def execute(cls, values: io.Autogrow.Type) -> io.NodeOutput:
        # `values` is a mapping; only its values matter for the result.
        some_truthy = any(values.values())
        return io.NodeOutput(some_truthy)
class SwitchNode(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -261,6 +337,9 @@ class LogicExtension(ComfyExtension):
return [
SwitchNode,
CustomComboNode,
NotNode,
AndNode,
OrNode,
# SoftSwitchNode,
# ConvertStringToComboNode,
# DCTestNode,

View File

@ -40,7 +40,6 @@ from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
from comfy_execution.utils import CurrentNodeContext
from comfy_execution.asset_enrichment import enrich_output_with_assets
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io, _io
from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
@ -419,15 +418,11 @@ def _is_intermediate_output(dynprompt, node_id):
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False)
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs):
if server.client_id is None:
return
cached_ui = cached.ui or {}
output = cached_ui.get("output", None)
if output:
output = enrich_output_with_assets(output)
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": output, "prompt_id": prompt_id }, server.client_id)
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, server.client_id)
if cached.ui is not None:
ui_outputs[node_id] = cached.ui
@ -567,7 +562,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
"output": output_ui
}
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": enrich_output_with_assets(output_ui), "prompt_id": prompt_id }, server.client_id)
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
if has_subgraph:
cached_outputs = []
new_node_ids = []

View File

@ -1517,22 +1517,6 @@ paths:
schema:
type: integer
default: 0
description: |
Offset-based pagination. Cursor pagination via `after` is preferred
for sequential walks (stable across concurrent inserts/deletes) but
`offset` remains fully supported for random access (jump-to-page
UIs, "showing items X–Y of N" displays). When both are supplied,
`after` wins and `offset` is ignored.
- name: after
in: query
schema:
type: string
description: |
Opaque cursor for keyset pagination. Pass the `next_cursor` value
from a previous response to fetch the next page. Stable across
inserts/deletes between pages. Supported with `sort` values
`created_at`, `updated_at`, `name`, and `size`. Malformed or
unsupported cursors return 400 with `INVALID_CURSOR`.
- name: include_tags
in: query
schema:
@ -1572,17 +1556,6 @@ paths:
type: string
enum: [asc, desc]
description: Sort direction
- name: job_ids
in: query
schema:
type: array
maxItems: 500
items:
type: string
format: uuid
style: form
explode: true
description: "Filter assets by associated job UUIDs. Accepts repeated query params (e.g. `?job_ids=a&job_ids=b`) or a single comma-separated value (`?job_ids=a,b`)."
- name: include_public
in: query
schema:
@ -1602,12 +1575,6 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/ListAssetsResponse"
"400":
description: Malformed query or cursor (e.g. `INVALID_CURSOR`)
content:
application/json:
schema:
$ref: "#/components/schemas/AssetsApiError"
post:
operationId: createAsset
tags: [assets]
@ -6643,18 +6610,7 @@ components:
description: Unique identifier for the asset
name:
type: string
deprecated: true
description: Name of the asset file
file_path:
type: string
nullable: true
x-runtime: [cloud, local]
description: "Logical asset locator under the namespace root. Not a unique reference key; use `id` for identity."
display_name:
type: string
nullable: true
x-runtime: [cloud, local]
description: "Human-facing display label for the asset. Not a unique reference key; use `id` for identity."
hash:
type: string
nullable: true
@ -6794,42 +6750,6 @@ components:
type: integer
has_more:
type: boolean
next_cursor:
type: string
description: |
Opaque cursor to fetch the next page. Pass back as the `after`
query parameter. Omitted when there are no more results.
AssetsApiError:
type: object
description: Error envelope returned by the assets API on 400 responses.
required:
- error
properties:
error:
type: object
required:
- code
- message
- details
properties:
code:
type: string
description: |
Machine-readable error code. `INVALID_CURSOR` is returned when the
`after` cursor is malformed, oversized, or its sort field does
not match the request's `sort`. `INVALID_QUERY` covers other
Pydantic validation failures.
enum: [INVALID_CURSOR, INVALID_QUERY]
message:
type: string
details:
type: object
description: |
Free-form, code-specific context. `INVALID_QUERY` populates this
with an `errors` array of Pydantic validation entries;
`INVALID_CURSOR` returns an empty object.
additionalProperties: true
TagInfo:
type: object
@ -6920,13 +6840,6 @@ components:
enum: [input, output, temp]
display_name:
type: string
id:
type: string
format: uuid
description: |
Asset reference UUID. Present only when the server is started with
`--enable-assets` and the file resolves to a registered asset.
Fetch the full asset via `GET /api/assets/{id}`.
NodeOutputs:
type: object
@ -8810,4 +8723,4 @@ components:
items:
$ref: "#/components/schemas/TaskEntry"
pagination:
$ref: "#/components/schemas/PaginationInfo"
$ref: "#/components/schemas/PaginationInfo"

View File

@ -236,8 +236,6 @@ def seeded_asset(request: pytest.FixtureRequest, http: requests.Session, api_bas
r = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
body = r.json()
assert r.status_code == 201, body
from helpers import assert_hash_fields_consistent
assert_hash_fields_consistent(body)
return body

View File

@ -26,26 +26,3 @@ def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None:
def get_asset_filename(asset_hash: str, extension: str) -> str:
    """Build a file name from an asset hash and a file extension.

    A leading ``blake3:`` prefix is stripped from ``asset_hash``; a hash
    without that prefix is used unchanged. ``extension`` is appended verbatim,
    so it must already contain the dot if one is wanted.
    """
    return asset_hash.removeprefix("blake3:") + extension
def assert_hash_fields_consistent(body: dict, expected_hash: str | None = None) -> None:
"""Assert hash and asset_hash invariants on an Asset response.
Both must be present or both absent (so a regression that drops only one
is caught). When present, they must equal each other and, if expected_hash
is provided, must equal that value.
"""
hash_present = "hash" in body
asset_hash_present = "asset_hash" in body
assert hash_present == asset_hash_present, (
f"hash and asset_hash must both be present or both absent: "
f"hash present={hash_present}, asset_hash present={asset_hash_present}"
)
if hash_present:
h = body["hash"]
ah = body["asset_hash"]
assert h == ah, f"hash and asset_hash must match: hash={h!r}, asset_hash={ah!r}"
if expected_hash is not None:
assert h == expected_hash, (
f"hash must equal expected: got {h!r}, expected {expected_hash!r}"
)

View File

@ -21,7 +21,6 @@ from app.assets.database.queries import (
get_reference_ids_by_ids,
ensure_tags_exist,
add_tags_to_reference,
set_reference_tags,
)
from app.assets.helpers import get_utc_now
@ -159,203 +158,6 @@ class TestListReferencesPage:
refs, _, _ = list_references_page(session, sort="name", order="asc")
assert refs[0].name == "large"
def test_job_ids_filter(self, session: Session):
asset = _make_asset(session, "hash1")
job_a = str(uuid.uuid4())
job_b = str(uuid.uuid4())
ref_a = _make_reference(session, asset, name="from_job_a")
ref_a.job_id = job_a
ref_b = _make_reference(session, asset, name="from_job_b")
ref_b.job_id = job_b
_make_reference(session, asset, name="no_job")
session.commit()
# Single job filter
refs, _, total = list_references_page(session, job_ids=[job_a])
assert total == 1
assert refs[0].name == "from_job_a"
# Multi-job filter (IN)
refs, _, total = list_references_page(session, job_ids=[job_a, job_b])
names = sorted(r.name for r in refs)
assert total == 2
assert names == ["from_job_a", "from_job_b"]
# Unknown job id matches nothing
refs, _, total = list_references_page(session, job_ids=[str(uuid.uuid4())])
assert total == 0
assert refs == []
# Empty/None means no filter -> all three references
refs, _, total = list_references_page(session, job_ids=[])
assert total == 3
refs, _, total = list_references_page(session, job_ids=None)
assert total == 3
def test_job_ids_combined_with_other_filters(self, session: Session):
asset = _make_asset(session, "hash1")
job_a = str(uuid.uuid4())
ref_match = _make_reference(session, asset, name="match.bin")
ref_match.job_id = job_a
ref_wrong_name = _make_reference(session, asset, name="other.bin")
ref_wrong_name.job_id = job_a
ref_wrong_job = _make_reference(session, asset, name="match.bin")
ref_wrong_job.job_id = str(uuid.uuid4())
session.commit()
refs, _, total = list_references_page(
session, job_ids=[job_a], name_contains="match"
)
assert total == 1
assert refs[0].id == ref_match.id
class TestTagRetrievalOrder:
"""End-to-end check: tags written through the public write paths come
back from the public read paths in insertion order rather than the
composite-PK alphabetical order SQLite would otherwise impose.
Each test deliberately picks tag names that would sort differently
under alphabetical vs insertion order, so an alphabetical regression
fails loudly.
"""
def _make_ref(self, session: Session) -> AssetReference:
asset = _make_asset(session, "h1")
return _make_reference(session, asset, name="x.bin")
def test_set_reference_tags_preserves_input_order_in_list(self, session: Session):
ref = self._make_ref(session)
# "checkpoints" < "models" alphabetically; if added_at stagger
# works, list_references_page returns insertion order.
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
session.commit()
_, tag_map, _ = list_references_page(session)
assert tag_map[ref.id] == ["models", "checkpoints"]
def test_set_reference_tags_preserves_input_order_in_fetch(self, session: Session):
ref = self._make_ref(session)
# Subpath tag sorts before "models" alphabetically.
set_reference_tags(
session,
reference_id=ref.id,
tags=["models", "diffusers/kolors/text_encoder"],
)
session.commit()
result = fetch_reference_asset_and_tags(session, ref.id)
assert result is not None
_, _, tags = result
# Bucket-prefix expansion appends the standalone `diffusers` token
# at path-tier (microsecond stagger) so FE set-membership filters
# match nested category paths.
assert tags == ["models", "diffusers/kolors/text_encoder", "diffusers"]
def test_add_tags_to_reference_lands_after_path_tags(self, session: Session):
ref = self._make_ref(session)
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
session.commit()
# "aaa-..." sorts before both path tags alphabetically. If added_at
# stagger is missing, alphabetic tiebreak would hoist it to tags[0].
add_tags_to_reference(
session, reference_id=ref.id, tags=["aaa-user-tag"], origin="manual"
)
session.commit()
_, tag_map, _ = list_references_page(session)
assert tag_map[ref.id] == ["models", "checkpoints", "aaa-user-tag"]
def test_multi_tag_batch_lands_after_path_tags(self, session: Session):
ref = self._make_ref(session)
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
session.commit()
# Three user tags inserted in non-alphabetical input order. Per-tag
# microsecond stagger should preserve at least the "user batch is
# after path tags" property; within the user batch insertion order
# is also preserved.
add_tags_to_reference(
session,
reference_id=ref.id,
tags=["zzz-z", "favorite", "experiment-q4"],
origin="manual",
)
session.commit()
_, tag_map, _ = list_references_page(session)
tags = tag_map[ref.id]
assert tags[0:2] == ["models", "checkpoints"]
assert set(tags[2:]) == {"zzz-z", "favorite", "experiment-q4"}
def test_user_batch_lands_after_path_batch_under_clock_collision(
self, session: Session, monkeypatch: pytest.MonkeyPatch
):
"""Windows-specific race: when two back-to-back commits share the
same datetime.now() microsecond, the path-tier and user-tier
added_at values used to collide and alphabetic tiebreak would
hoist user tags ahead of path tags. The fix reads
max(existing_added_at) for the reference and seeds the next batch
past it, deterministically restoring insertion order.
This test simulates the collision by pinning get_utc_now() so the
platform-dependent race becomes a platform-independent failure.
"""
ref = self._make_ref(session)
from datetime import datetime
from app.assets.database import queries as queries_pkg
from app.assets.database.queries import tags as tags_module
frozen = datetime(2026, 1, 1, 0, 0, 0)
monkeypatch.setattr(tags_module, "get_utc_now", lambda: frozen)
monkeypatch.setattr(queries_pkg, "get_utc_now", lambda: frozen, raising=False)
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
session.commit()
# Same frozen timestamp — without the max(existing) seed, the
# user batch would share added_at with the path batch and
# `aaa-user-tag` would sort to position 0 via the alphabetic
# tiebreaker.
add_tags_to_reference(
session, reference_id=ref.id, tags=["aaa-user-tag"], origin="manual"
)
session.commit()
_, tag_map, _ = list_references_page(session)
assert tag_map[ref.id] == ["models", "checkpoints", "aaa-user-tag"]
def test_remove_then_add_does_not_disrupt_path_tag_positions(
self, session: Session
):
ref = self._make_ref(session)
set_reference_tags(
session,
reference_id=ref.id,
tags=["models", "loras/my/custom/path"],
)
session.commit()
add_tags_to_reference(session, reference_id=ref.id, tags=["temp-tag"])
session.commit()
from app.assets.database.queries import remove_tags_from_reference
remove_tags_from_reference(session, reference_id=ref.id, tags=["temp-tag"])
session.commit()
add_tags_to_reference(session, reference_id=ref.id, tags=["second-tag"])
session.commit()
_, tag_map, _ = list_references_page(session)
# `loras` is expanded from the nested category path; user-added
# tags trail behind it via the microsecond stagger.
assert tag_map[ref.id] == [
"models",
"loras/my/custom/path",
"loras",
"second-tag",
]
class TestFetchReferenceAssetAndTags:
def test_returns_none_for_nonexistent(self, session: Session):

View File

@ -1,112 +0,0 @@
"""Keyset-pagination tiebreaker tests for list_references_page.
When multiple rows share the same primary sort value (e.g. four assets
created in the same microsecond), the secondary `ORDER BY id` is what keeps
keyset pagination from losing or repeating rows. This file exercises that
branch directly against an in-memory SQLite session — engineering identical
timestamps via HTTP is unreliable enough that we work at the query layer.
"""
import uuid
from datetime import datetime
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference
from app.assets.database.queries.asset_reference import list_references_page
def _make_ref(session: Session, created_at: datetime, name: str, owner: str = "") -> AssetReference:
    """Add a fresh Asset and one AssetReference to it, and return the reference.

    The same ``created_at`` is used for the reference's ``created_at``,
    ``updated_at`` and ``last_access_time`` so tests fully control the sort
    key. The reference is added to the session but not committed here.
    """
    # A random hash keeps each asset distinct.
    asset = Asset(hash=f"blake3:{uuid.uuid4().hex}", size_bytes=1024)
    session.add(asset)
    # Flush before reading asset.id below (presumably assigned on flush).
    session.flush()
    ref = AssetReference(
        id=str(uuid.uuid4()),
        asset_id=asset.id,
        owner_id=owner,
        name=name,
        file_path=f"/tmp/{name}",
        created_at=created_at,
        updated_at=created_at,
        last_access_time=created_at,
        is_missing=False,
    )
    session.add(ref)
    return ref
@pytest.mark.parametrize("order", ["desc", "asc"])
def test_tiebreaker_walks_duplicate_sort_values(session: Session, order: str):
    """Four rows with the SAME created_at must paginate cleanly under cursor
    mode — no row dropped, no row repeated, despite the primary sort column
    being non-discriminating.
    """
    shared_ts = datetime(2024, 5, 20, 12, 0, 0)  # naive UTC, like the DB stores
    refs = [_make_ref(session, shared_ts, f"tie_{i}.png") for i in range(4)]
    session.commit()

    # With every created_at equal, the id ordering alone decides the sequence.
    expected_ids = sorted([r.id for r in refs], reverse=(order == "desc"))

    # Walk the cursor by hand with page size 2: two full pages are expected,
    # then an empty third page ends the walk.
    seen: list[str] = []
    after_value = None
    after_id = None
    for _ in range(4):  # generous loop bound; the walk should end on the third request
        page, _tag_map, _total = list_references_page(
            session,
            limit=2,
            sort="created_at",
            order=order,
            after_cursor_value=after_value,
            after_cursor_id=after_id,
        )
        if not page:
            break
        seen.extend(p.id for p in page)
        # Use the last row's (created_at, id) as the next cursor input.
        last = page[-1]
        after_value, after_id = last.created_at, last.id
        if len(page) < 2:
            break

    assert seen == expected_ids, (
        f"keyset tiebreaker failed for order={order}: expected {expected_ids}, got {seen}"
    )
def test_tiebreaker_no_duplicates_under_mixed_collisions(session: Session):
    """Some rows share a timestamp, some don't. The cursor must still walk
    every row exactly once regardless of where ties sit relative to a
    page boundary."""
    t1 = datetime(2024, 5, 20, 12, 0, 0)
    t2 = datetime(2024, 5, 20, 12, 0, 1)
    layout = [t1, t1, t1, t2, t2]  # three rows at t1, two at t2
    refs = [_make_ref(session, ts, f"mix_{i}.png") for i, ts in enumerate(layout)]
    session.commit()

    all_ids = {r.id for r in refs}
    seen_set: set[str] = set()
    after_value = None
    after_id = None
    for _ in range(6):  # loop bound only guards against a non-terminating walk
        page, _, _ = list_references_page(
            session,
            limit=2,
            sort="created_at",
            order="desc",
            after_cursor_value=after_value,
            after_cursor_id=after_id,
        )
        if not page:
            break
        for p in page:
            assert p.id not in seen_set, f"duplicate row {p.id} appeared in cursor walk"
            seen_set.add(p.id)
        # The last row's (created_at, id) seeds the next request.
        last = page[-1]
        after_value, after_id = last.created_at, last.id
        if len(page) < 2:
            break

    assert seen_set == all_ids, f"missing rows: expected {all_ids}, got {seen_set}"

View File

@ -1,60 +0,0 @@
"""Schema-level unit tests for ListAssetsQuery (no DB required)."""
import uuid
import pytest
from pydantic import ValidationError
from app.assets.api.schemas_in import ListAssetsQuery
class TestJobIdsValidator:
    """Validation of the ``job_ids`` field when building a ListAssetsQuery."""

    def test_csv_string_parses_and_canonicalizes(self):
        a = "AAAAAAAA-BBBB-CCCC-DDDD-EEEEEEEEEEEE"
        b = "11111111-2222-3333-4444-555555555555"
        q = ListAssetsQuery.model_validate({"job_ids": f"{a},{b}"})
        # Canonicalized to lowercase
        assert q.job_ids == [a.lower(), b]

    def test_repeated_query_params_as_list(self):
        a = "11111111-1111-1111-1111-111111111111"
        b = "22222222-2222-2222-2222-222222222222"
        q = ListAssetsQuery.model_validate({"job_ids": [a, b]})
        assert q.job_ids == [a, b]

    def test_dedup_preserves_first_seen_order(self):
        a = "11111111-1111-1111-1111-111111111111"
        b = "22222222-2222-2222-2222-222222222222"
        q = ListAssetsQuery.model_validate({"job_ids": [a, b, a]})
        assert q.job_ids == [a, b]

    def test_default_empty(self):
        q = ListAssetsQuery.model_validate({})
        assert q.job_ids == []

    def test_invalid_uuid_rejected(self):
        with pytest.raises(ValidationError) as exc:
            ListAssetsQuery.model_validate({"job_ids": "not-a-uuid"})
        assert "must be UUIDs" in str(exc.value)

    def test_non_string_list_item_rejected(self):
        with pytest.raises(ValidationError) as exc:
            ListAssetsQuery.model_validate(
                {"job_ids": ["11111111-1111-1111-1111-111111111111", 42]}
            )
        assert "must be strings" in str(exc.value)

    def test_non_string_non_list_value_rejected(self):
        with pytest.raises(ValidationError) as exc:
            ListAssetsQuery.model_validate({"job_ids": {"bad": "shape"}})
        assert "must be a string or list of strings" in str(exc.value)

    def test_max_length_enforced(self):
        # One entry over the 500-item boundary exercised by the next test.
        too_many = [str(uuid.uuid4()) for _ in range(501)]
        with pytest.raises(ValidationError) as exc:
            ListAssetsQuery.model_validate({"job_ids": too_many})
        assert exc.value.errors()[0]["type"] == "too_long"

    def test_max_length_boundary_accepted(self):
        at_cap = [str(uuid.uuid4()) for _ in range(500)]
        q = ListAssetsQuery.model_validate({"job_ids": at_cap})
        assert len(q.job_ids) == 500

View File

@ -160,120 +160,6 @@ class TestAddTagsToReference:
add_tags_to_reference(session, reference_id="nonexistent", tags=["x"])
class TestBucketPrefixExpansion:
"""The standalone bucket token must appear in the asset's tag set for
nested category paths so FE filters like
`include_tags=models,checkpoints` continue to match.
"""
def test_set_reference_tags_inserts_bucket_for_nested_path(
self, session: Session
):
asset = _make_asset(session, "hash-nested")
ref = _make_reference(session, asset)
result = set_reference_tags(
session,
reference_id=ref.id,
tags=["models", "checkpoints/flux"],
)
session.commit()
assert set(result.total) == {"models", "checkpoints/flux", "checkpoints"}
stored = get_reference_tags(session, reference_id=ref.id)
# tag[1] keeps the slash-joined positional contract; the standalone
# bucket lands after it via path-tier microsecond stagger so user
# tags remain at the tail.
assert stored[:3] == ["models", "checkpoints/flux", "checkpoints"]
def test_set_reference_tags_idempotent_on_replay(self, session: Session):
asset = _make_asset(session, "hash-replay")
ref = _make_reference(session, asset)
set_reference_tags(
session,
reference_id=ref.id,
tags=["models", "checkpoints/flux"],
)
# Replay with the same caller-supplied set; expansion is already
# baked in, so nothing should be added or removed.
result = set_reference_tags(
session,
reference_id=ref.id,
tags=["models", "checkpoints/flux"],
)
session.commit()
assert result.added == []
assert result.removed == []
assert set(result.total) == {"models", "checkpoints/flux", "checkpoints"}
def test_add_tags_to_reference_expands_bucket(self, session: Session):
asset = _make_asset(session, "hash-add")
ref = _make_reference(session, asset)
result = add_tags_to_reference(
session,
reference_id=ref.id,
tags=["loras/style/v2"],
)
session.commit()
assert set(result.added) == {"loras/style/v2", "loras"}
stored = get_reference_tags(session, reference_id=ref.id)
assert "loras" in stored
assert "loras/style/v2" in stored
def test_add_tags_does_not_duplicate_existing_bucket(self, session: Session):
asset = _make_asset(session, "hash-dedupe")
ref = _make_reference(session, asset)
add_tags_to_reference(
session, reference_id=ref.id, tags=["models", "checkpoints"]
)
result = add_tags_to_reference(
session, reference_id=ref.id, tags=["checkpoints/flux"]
)
session.commit()
# `checkpoints` was already there from the first add; only the
# slash-joined token is genuinely new.
assert result.added == ["checkpoints/flux"]
assert "checkpoints" in result.already_present
def test_flat_category_is_unaffected(self, session: Session):
asset = _make_asset(session, "hash-flat")
ref = _make_reference(session, asset)
result = set_reference_tags(
session,
reference_id=ref.id,
tags=["models", "checkpoints"],
)
session.commit()
assert set(result.total) == {"models", "checkpoints"}
assert get_reference_tags(session, reference_id=ref.id) == [
"models",
"checkpoints",
]
def test_unknown_prefix_passes_through(self, session: Session):
asset = _make_asset(session, "hash-user")
ref = _make_reference(session, asset)
# `my-org` isn't a registered bucket — the slash-joined user tag
# should not trigger bucket expansion.
result = set_reference_tags(
session,
reference_id=ref.id,
tags=["my-org/team-a"],
)
session.commit()
assert result.total == ["my-org/team-a"]
class TestRemoveTagsFromReference:
def test_removes_tags(self, session: Session):
asset = _make_asset(session, "hash1")

View File

@ -4,7 +4,7 @@ from pathlib import Path
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag
from app.assets.database.models import Asset, AssetReference
from app.assets.services.bulk_ingest import SeedAssetSpec, batch_insert_seed_assets
@ -102,82 +102,6 @@ class TestBatchInsertSeedAssets:
assert asset.mime_type == expected_mime, f"Expected {expected_mime} for {filename}, got {asset.mime_type}"
class TestBucketPrefixExpansionOnIngest:
"""Path-scanning ingest must persist the standalone bucket token for
nested category paths so the FE set-membership filter
(`include_tags=models,checkpoints`) matches assets organized into
subfolders (`models/checkpoints/flux/foo.safetensors`).
"""
def test_nested_path_inserts_standalone_bucket(
self, session: Session, temp_dir: Path
):
file_path = temp_dir / "flux.safetensors"
file_path.write_bytes(b"content")
specs: list[SeedAssetSpec] = [
{
"abs_path": str(file_path),
"size_bytes": 7,
"mtime_ns": 1234567890000000000,
"info_name": "flux",
# Shape emitted by get_name_and_tags_from_asset_path for a
# nested model path.
"tags": ["models", "checkpoints/flux"],
"fname": "flux.safetensors",
"metadata": None,
"hash": None,
"mime_type": "application/safetensors",
}
]
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
assert result.inserted_refs == 1
ref = session.query(AssetReference).filter_by(name="flux").one()
stored = [
row.tag_name
for row in session.query(AssetReferenceTag)
.filter_by(asset_reference_id=ref.id)
.order_by(AssetReferenceTag.added_at.asc())
.all()
]
assert stored == ["models", "checkpoints/flux", "checkpoints"]
def test_flat_path_remains_two_tags(
self, session: Session, temp_dir: Path
):
file_path = temp_dir / "vanilla.safetensors"
file_path.write_bytes(b"content")
specs: list[SeedAssetSpec] = [
{
"abs_path": str(file_path),
"size_bytes": 7,
"mtime_ns": 1234567890000000000,
"info_name": "vanilla",
"tags": ["models", "checkpoints"],
"fname": "vanilla.safetensors",
"metadata": None,
"hash": None,
"mime_type": "application/safetensors",
}
]
batch_insert_seed_assets(session, specs=specs, owner_id="")
ref = session.query(AssetReference).filter_by(name="vanilla").one()
stored = {
row.tag_name
for row in session.query(AssetReferenceTag)
.filter_by(asset_reference_id=ref.id)
.all()
}
# Dedupe means flat layouts don't pick up a redundant `checkpoints`
# row — tag[1] already serves both positional and set-membership.
assert stored == {"models", "checkpoints"}
class TestMetadataExtraction:
def test_extracts_mime_type_for_model_files(self, temp_dir: Path):
"""Verify metadata extraction returns correct mime_type for model files."""

View File

@ -1,354 +0,0 @@
"""Tests for app.assets.services.cursor.
The byte-identity fixtures below pin the wire format so a parallel
implementation in another runtime can mint exchange-compatible cursors
for the same payload. Drift here would break frontend pagination against
any compatible backend.
"""
from __future__ import annotations
import base64
from datetime import datetime, timedelta, timezone
import pytest
from app.assets.services.cursor import (
MAX_CURSOR_ID_LENGTH,
MAX_CURSOR_VALUE_LENGTH,
MAX_ENCODED_CURSOR_LENGTH,
CursorPayload,
InvalidCursorError,
decode_cursor,
decode_cursor_int,
decode_cursor_time,
encode_cursor,
encode_cursor_from_time,
)
ALLOWED = ("created_at", "updated_at", "name", "size")
class TestRoundTrip:
    """encode_cursor followed by decode_cursor must give back the same fields."""

    @pytest.mark.parametrize(
        "sort_field, value, ref_id",
        [
            ("created_at", "1716200000000000", "a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7"),
            ("size", "1024", "asset-123"),
            ("name", "my-asset.png", "asset-abc"),
            ("name", "résumé.txt", "asset-uni"),
        ],
    )
    def test_encode_decode(self, sort_field, value, ref_id):
        encoded = encode_cursor(sort_field, value, ref_id)
        assert encoded != ""
        payload = decode_cursor(encoded, ALLOWED)
        assert payload.sort_field == sort_field
        assert payload.value == value
        assert payload.id == ref_id
class TestTimeCursor:
    """Time-valued cursors: encoding from datetimes and decoding back."""

    def test_microsecond_precision_preserved(self):
        # Pick a time with non-zero microseconds — encoding at ms would lose the µs.
        ts = datetime(2024, 5, 20, 12, 53, 20, 123456, tzinfo=timezone.utc)
        encoded = encode_cursor_from_time("created_at", ts, "id-1")
        payload = decode_cursor(encoded, ALLOWED)
        # Value must be a microsecond integer string, not a millisecond one.
        assert payload.value == "1716209600123456"
        decoded = decode_cursor_time(payload)
        assert decoded == ts

    def test_decode_returns_utc(self):
        payload = CursorPayload(sort_field="created_at", value="1716200000123456", id="id-1", order="desc")
        decoded = decode_cursor_time(payload)
        assert decoded.tzinfo == timezone.utc

    def test_naive_datetime_rejected_on_encode(self):
        naive = datetime(2024, 5, 20, 12, 0, 0)
        with pytest.raises(ValueError):
            encode_cursor_from_time("created_at", naive, "id-1")

    def test_non_integer_value_rejected_on_decode(self):
        with pytest.raises(InvalidCursorError):
            decode_cursor_time(CursorPayload("created_at", "not-a-number", "id-1", "desc"))

    def test_none_payload_rejected(self):
        with pytest.raises(InvalidCursorError):
            decode_cursor_time(None)

    def test_non_utc_aware_normalized(self):
        # Same instant, different timezone — must encode to the same micros.
        utc_ts = datetime(2024, 5, 20, 12, 0, 0, tzinfo=timezone.utc)
        offset_ts = utc_ts.astimezone(timezone(timedelta(hours=-5)))
        assert encode_cursor_from_time("created_at", utc_ts, "x") == encode_cursor_from_time(
            "created_at", offset_ts, "x"
        )
class TestIntCursor:
    """Integer-valued cursors: decode_cursor_int parsing and rejection paths."""

    def test_decode_int(self):
        assert decode_cursor_int(CursorPayload("size", "1024", "id-1", "desc")) == 1024

    def test_decode_int_rejects_non_int(self):
        with pytest.raises(InvalidCursorError):
            decode_cursor_int(CursorPayload("size", "abc", "id-1", "desc"))

    def test_decode_int_rejects_none(self):
        with pytest.raises(InvalidCursorError):
            decode_cursor_int(None)
class TestInvalidInputs:
    """decode_cursor must raise InvalidCursorError for every malformed cursor here."""

    def test_oversized_cursor(self):
        oversized = "a" * (MAX_ENCODED_CURSOR_LENGTH + 1)
        with pytest.raises(InvalidCursorError, match="maximum length"):
            decode_cursor(oversized, ALLOWED)

    def test_not_base64(self):
        with pytest.raises(InvalidCursorError):
            decode_cursor("not base64!!!", ALLOWED)

    def test_not_json(self):
        encoded = base64.urlsafe_b64encode(b"definitely not json").rstrip(b"=").decode("ascii")
        with pytest.raises(InvalidCursorError):
            decode_cursor(encoded, ALLOWED)

    def test_empty_id(self):
        # Encoder rejects empty id symmetrically with the decoder, so build the
        # payload manually to exercise the decoder's missing-id branch.
        raw = b'{"s":"created_at","v":"1","id":"","o":"desc"}'
        encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
        with pytest.raises(InvalidCursorError, match="missing id"):
            decode_cursor(encoded, ALLOWED)

    def test_oversized_id(self):
        # Encoder enforces the cap symmetrically; hand-build to exercise decode.
        big_id = "a" * (MAX_CURSOR_ID_LENGTH + 1)
        raw = ('{"s":"created_at","v":"1","id":"' + big_id + '","o":"desc"}').encode("ascii")
        encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
        with pytest.raises(InvalidCursorError, match="id exceeds maximum length"):
            decode_cursor(encoded, ALLOWED)

    def test_oversized_value(self):
        # Encoder enforces the cap symmetrically; hand-build to exercise decode.
        big_v = "v" * (MAX_CURSOR_VALUE_LENGTH + 1)
        raw = ('{"s":"created_at","v":"' + big_v + '","id":"id-1","o":"desc"}').encode("ascii")
        encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
        with pytest.raises(InvalidCursorError, match="value exceeds maximum length"):
            decode_cursor(encoded, ALLOWED)

    def test_unsupported_sort_field(self):
        encoded = encode_cursor("execution_time", "1", "id-1")
        with pytest.raises(InvalidCursorError, match="unsupported sort field"):
            decode_cursor(encoded, ALLOWED)

    def test_no_allowed_fields_rejects_everything(self):
        encoded = encode_cursor("created_at", "1", "id-1")
        with pytest.raises(InvalidCursorError):
            decode_cursor(encoded, ())

    def test_non_dict_payload_rejected(self):
        encoded = base64.urlsafe_b64encode(b'["array","not","dict"]').rstrip(b"=").decode("ascii")
        with pytest.raises(InvalidCursorError, match="expected object"):
            decode_cursor(encoded, ALLOWED)
class TestEncodeAtCapsFits:
    """A cursor built from fields at their per-field caps must fit the wire cap."""

    def test_max_field_lengths_fit_wire_cap(self):
        # Worst-case payload: value and id at their per-field caps, with a long
        # sort field name. The encoded cursor must fit within MAX_ENCODED_CURSOR_LENGTH
        # so the wire cap cannot reject a cursor the encoder mints at the per-field caps.
        value = "v" * MAX_CURSOR_VALUE_LENGTH
        cursor_id = "i" * MAX_CURSOR_ID_LENGTH
        sort_field = "very_long_sort_field_name"
        encoded = encode_cursor(sort_field, value, cursor_id)
        assert len(encoded) <= MAX_ENCODED_CURSOR_LENGTH
        payload = decode_cursor(encoded, (sort_field,))
        assert payload.value == value
        assert payload.id == cursor_id
class TestDatetimeOverflow:
    """Crafted cursors with extreme micros must map to InvalidCursorError,
    not OverflowError/OSError leaking as 500.
    """

    @pytest.mark.parametrize(
        "micros_str",
        [
            # ~10^21 µs; year 9999 (datetime.MAXYEAR) is only about 2.5e17 µs
            # after the epoch, so this is far outside the representable range.
            "999999999999999999999",
            "-999999999999999999999",  # symmetric negative — pre-epoch overflow
        ],
    )
    def test_out_of_range_micros_rejected(self, micros_str):
        # The string itself decodes fine; only the conversion to a datetime
        # is expected to fail.
        encoded = encode_cursor("created_at", micros_str, "asset-x")
        payload = decode_cursor(encoded, ALLOWED)
        with pytest.raises(InvalidCursorError):
            decode_cursor_time(payload)
class TestEncoderDecoderSymmetry:
    """Whatever the decoder rejects, the encoder must reject too; otherwise
    one server mints a cursor and then answers 400 when it is sent back.
    """

    def test_long_name_within_cap_round_trips(self):
        """Asset names may be up to 512 chars (`String(512)`); a value exactly
        at the cursor value cap has to encode and decode cleanly so a newly
        minted cursor is always accepted on the following request."""
        name_at_cap = MAX_CURSOR_VALUE_LENGTH * "n"
        token = encode_cursor("name", name_at_cap, "asset-x")
        decoded = decode_cursor(token, ALLOWED)
        assert decoded.value == name_at_cap

    def test_encoder_rejects_empty_id(self):
        """An empty id is refused at encode time."""
        with pytest.raises(InvalidCursorError, match="id must be non-empty"):
            encode_cursor("created_at", "1", "")

    def test_encoder_rejects_oversized_id(self):
        """An id one char over MAX_CURSOR_ID_LENGTH is refused at encode time."""
        oversized_id = "a" * (MAX_CURSOR_ID_LENGTH + 1)
        with pytest.raises(InvalidCursorError, match="id exceeds maximum length"):
            encode_cursor("created_at", "1", oversized_id)

    def test_encoder_rejects_oversized_value(self):
        """A value one char over MAX_CURSOR_VALUE_LENGTH is refused at encode time."""
        oversized_value = "v" * (MAX_CURSOR_VALUE_LENGTH + 1)
        with pytest.raises(InvalidCursorError, match="value exceeds maximum length"):
            encode_cursor("name", oversized_value, "id-1")

    def test_encoder_rejects_multibyte_value_over_wire_cap(self):
        """The per-field cap counts characters, but the wire cap applies after
        serialization. A name of 512 multibyte characters (e.g. 'é' = 2 bytes
        in UTF-8) must be refused by the encoder instead of becoming a cursor
        that the next request answers with 400."""
        multibyte_name = "é" * MAX_CURSOR_VALUE_LENGTH
        with pytest.raises(InvalidCursorError, match="encoded cursor exceeds maximum length"):
            encode_cursor("name", multibyte_name, "asset-multibyte")

    def test_encoder_rejects_escape_heavy_value_over_wire_cap(self):
        """Escape expansion triggers the same wire-cap problem: every `<` is
        written as the six-byte sequence `\\u003c`, so 512 of them exceed the
        encoded cap although the character count is within the field limit."""
        escape_heavy_name = "<" * MAX_CURSOR_VALUE_LENGTH
        with pytest.raises(InvalidCursorError, match="encoded cursor exceeds maximum length"):
            encode_cursor("name", escape_heavy_name, "asset-escape")
class TestOrderBinding:
    """The sort direction is part of the cursor and is checked on decode."""

    @staticmethod
    def _tokenize(raw_json: bytes) -> str:
        """Base64url-encode a hand-written payload and strip the padding."""
        return base64.urlsafe_b64encode(raw_json).rstrip(b"=").decode("ascii")

    def test_order_baked_into_payload(self):
        token = encode_cursor("created_at", "1", "id-1", order="asc")
        decoded = decode_cursor(token, ALLOWED)
        assert decoded.order == "asc"

    def test_mismatched_order_rejected(self):
        token = encode_cursor("created_at", "1", "id-1", order="desc")
        with pytest.raises(InvalidCursorError, match="does not match request order"):
            decode_cursor(token, ALLOWED, expected_order="asc")

    def test_matching_order_accepted(self):
        token = encode_cursor("created_at", "1", "id-1", order="desc")
        decoded = decode_cursor(token, ALLOWED, expected_order="desc")
        assert decoded.order == "desc"

    def test_invalid_order_token_rejected_on_encode(self):
        with pytest.raises(ValueError):
            encode_cursor("created_at", "1", "id-1", order="sideways")

    def test_invalid_order_token_rejected_on_decode(self):
        # The encoder refuses an illegal order, so the payload with a bad
        # `o` value is written by hand.
        token = self._tokenize(b'{"s":"name","v":"x","id":"id-1","o":"sideways"}')
        with pytest.raises(InvalidCursorError, match="unsupported order"):
            decode_cursor(token, ALLOWED)

    def test_cursor_without_order_rejected(self):
        """`o` is required. A payload lacking it is treated as malformed
        instead of being walked in whichever direction the current request
        happens to specify."""
        token = self._tokenize(b'{"s":"name","v":"x","id":"id-1"}')
        with pytest.raises(InvalidCursorError, match="missing or non-string o"):
            decode_cursor(token, ALLOWED, expected_order="desc")
class TestHtmlSignificantCharEscaping:
    """An asset name containing `<`, `>`, `&`, U+2028, or U+2029 must encode
    to the same escaped wire bytes as any compatible implementation of the
    same payload format. Drift here breaks cross-runtime byte-identity for
    those characters.
    """

    @pytest.mark.parametrize(
        "value, escaped_substring",
        [
            ("foo<bar>.png", "\\u003c"),  # `<` escaped
            ("foo<bar>.png", "\\u003e"),  # `>` escaped
            ("foo&bar.png", "\\u0026"),
            # The separators are written as explicit escapes: a literal
            # U+2028/U+2029 is invisible in most editors and is easily lost
            # in copy/paste, which would leave the case testing nothing.
            ("foo\u2028bar.png", "\\u2028"),  # JS line separator
            ("foo\u2029bar.png", "\\u2029"),  # JS paragraph separator
        ],
    )
    def test_html_significant_chars_escaped(self, value, escaped_substring):
        """The serialized payload must contain the expected `\\uXXXX` escape."""
        encoded = encode_cursor("name", value, "id-1")
        # Cursors are emitted without base64 padding; restore it before decoding.
        decoded_bytes = base64.urlsafe_b64decode(encoded + "=" * (-len(encoded) % 4))
        assert escaped_substring in decoded_bytes.decode("ascii"), (
            f"Expected {escaped_substring!r} in serialized payload, got: {decoded_bytes!r}"
        )

    def test_value_round_trips_through_escape(self):
        """Encoding then decoding a value with `<>&` should yield the original
        string — the escape only affects the wire form, not the decoded value."""
        original = "foo<&>bar.png"
        encoded = encode_cursor("name", original, "id-1")
        payload = decode_cursor(encoded, ALLOWED)
        assert payload.value == original
class TestByteIdentityFixtures:
    """Pin the wire format so silent drift is impossible.

    Each fixture compares the decoded JSON payload for exact equality with a
    literal string. Any change to key order, escaping, separator whitespace,
    or any other byte makes the test fail loudly instead of letting this
    encoder diverge from other consumers of the same payload format.
    """

    @pytest.mark.parametrize(
        "sort_field, value, asset_id, order, expected_payload",
        [
            (
                "created_at",
                "1716200000000000",
                "a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7",
                "desc",
                '{"s":"created_at","v":"1716200000000000","id":"a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7","o":"desc"}',
            ),
            (
                "size",
                "1024",
                "asset-123",
                "asc",
                '{"s":"size","v":"1024","id":"asset-123","o":"asc"}',
            ),
            (
                "name",
                "my-asset.png",
                "asset-abc",
                "desc",
                '{"s":"name","v":"my-asset.png","id":"asset-abc","o":"desc"}',
            ),
            (
                "name",
                "foo<bar>&baz.png",
                "asset-html",
                "desc",
                # `<`, `>`, `&` are written as \u003c, \u003e, \u0026 in the value.
                '{"s":"name","v":"foo\\u003cbar\\u003e\\u0026baz.png","id":"asset-html","o":"desc"}',
            ),
        ],
    )
    def test_encoded_payload_shape_pinned(self, sort_field, value, asset_id, order, expected_payload):
        token = encode_cursor(sort_field, value, asset_id, order=order)
        # Cursors carry no base64 padding; add it back so the decoder accepts them.
        padding = "=" * (-len(token) % 4)
        actual_payload = base64.urlsafe_b64decode(token + padding).decode("utf-8")
        assert actual_payload == expected_payload, (
            f"wire format drifted for sort={sort_field!r}, value={value!r}:\n"
            f" expected: {expected_payload!r}\n"
            f" actual: {actual_payload!r}"
        )

View File

@ -1,86 +0,0 @@
"""Tests for the image_dimensions service."""
from __future__ import annotations
from pathlib import Path
import pytest
from PIL import Image
from app.assets.services.image_dimensions import extract_image_dimensions
def _make_png(path: Path, size: tuple[int, int]) -> Path:
    """Write a solid-colour RGB PNG of ``size`` to ``path`` and return ``path``."""
    Image.new("RGB", size, color=(123, 45, 67)).save(path, format="PNG")
    return path
def _make_jpeg(path: Path, size: tuple[int, int]) -> Path:
    """Write a solid-colour RGB JPEG (quality 80) of ``size`` to ``path`` and return ``path``."""
    Image.new("RGB", size, color=(10, 20, 30)).save(path, format="JPEG", quality=80)
    return path
class TestExtractImageDimensions:
    """Behaviour of extract_image_dimensions for images, non-images and bad files."""

    def test_extracts_png_dimensions(self, tmp_path: Path):
        png = _make_png(tmp_path / "rect.png", (320, 240))
        dims = extract_image_dimensions(str(png), mime_type="image/png")
        assert dims == {"kind": "image", "width": 320, "height": 240}

    def test_extracts_jpeg_dimensions(self, tmp_path: Path):
        jpeg = _make_jpeg(tmp_path / "shot.jpg", (1920, 1080))
        dims = extract_image_dimensions(str(jpeg), mime_type="image/jpeg")
        assert dims == {"kind": "image", "width": 1920, "height": 1080}

    def test_works_when_mime_type_is_none(self, tmp_path: Path):
        png = _make_png(tmp_path / "no_mime.png", (50, 100))
        dims = extract_image_dimensions(str(png), mime_type=None)
        assert dims == {"kind": "image", "width": 50, "height": 100}

    def test_skips_non_image_mime_without_touching_file(self, tmp_path: Path):
        # The file is deliberately never created: a non-image MIME type is
        # expected to return before the path is opened.
        absent_path = tmp_path / "model.safetensors"
        dims = extract_image_dimensions(
            str(absent_path),
            mime_type="application/octet-stream",
        )
        assert dims is None

    @pytest.mark.parametrize(
        "mime",
        ["application/json", "text/plain", "video/mp4", "audio/mpeg"],
    )
    def test_skips_all_non_image_mime_types(self, tmp_path: Path, mime: str):
        blob = tmp_path / "file.bin"
        blob.write_bytes(b"\x00\x01\x02")
        assert extract_image_dimensions(str(blob), mime_type=mime) is None

    def test_returns_none_for_missing_file(self, tmp_path: Path):
        absent_path = tmp_path / "does_not_exist.png"
        dims = extract_image_dimensions(str(absent_path), mime_type="image/png")
        assert dims is None

    def test_returns_none_for_corrupt_image(self, tmp_path: Path):
        bogus = tmp_path / "corrupt.png"
        bogus.write_bytes(b"not actually a png file")
        dims = extract_image_dimensions(str(bogus), mime_type="image/png")
        assert dims is None

    def test_returns_none_for_empty_file(self, tmp_path: Path):
        empty = tmp_path / "empty.png"
        empty.write_bytes(b"")
        dims = extract_image_dimensions(str(empty), mime_type="image/png")
        assert dims is None

View File

@ -4,12 +4,10 @@ from pathlib import Path
from unittest.mock import patch
import pytest
from PIL import Image
from sqlalchemy.orm import Session as SASession, Session
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, Tag
from app.assets.database.queries import get_reference_tags
from app.assets.helpers import get_utc_now
from app.assets.services.ingest import (
_ingest_file_from_path,
_register_existing_asset,
@ -17,11 +15,6 @@ from app.assets.services.ingest import (
)
def _make_png(path: Path, size: tuple[int, int]) -> Path:
    """Save a solid-colour RGB PNG of ``size`` at ``path`` and hand ``path`` back."""
    image = Image.new("RGB", size, color=(80, 120, 200))
    image.save(path, format="PNG")
    return path
class TestIngestFileFromPath:
def test_creates_asset_and_reference(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "test_file.bin"
@ -286,203 +279,4 @@ class TestIngestExistingFileTagFK:
ref_tags = sess.query(AssetReferenceTag).all()
ref_tag_names = {rt.tag_name for rt in ref_tags}
assert "output" in ref_tag_names
class TestIngestImageDimensions:
    """system_metadata should carry {kind, width, height} for image assets."""

    @staticmethod
    def _load_reference(session: Session, reference_id):
        """Return the AssetReference row with ``reference_id`` (None if absent)."""
        return session.query(AssetReference).filter_by(id=reference_id).first()

    def test_image_asset_emits_dimensions(
        self, mock_create_session, temp_dir: Path, session: Session
    ):
        image_path = _make_png(temp_dir / "shot.png", (640, 480))
        ingest_result = _ingest_file_from_path(
            abs_path=str(image_path),
            asset_hash="blake3:img1",
            size_bytes=image_path.stat().st_size,
            mtime_ns=1234567890000000000,
            mime_type="image/png",
        )
        reference = self._load_reference(session, ingest_result.reference_id)
        expected_metadata = {
            "kind": "image",
            "width": 640,
            "height": 480,
        }
        assert reference.system_metadata == expected_metadata

    def test_non_image_asset_leaves_system_metadata_empty(
        self, mock_create_session, temp_dir: Path, session: Session
    ):
        blob_path = temp_dir / "model.safetensors"
        blob_path.write_bytes(b"not an image")
        ingest_result = _ingest_file_from_path(
            abs_path=str(blob_path),
            asset_hash="blake3:safetensors1",
            size_bytes=blob_path.stat().st_size,
            mtime_ns=1234567890000000000,
            mime_type="application/octet-stream",
        )
        reference = self._load_reference(session, ingest_result.reference_id)
        # Either "never set" or "set to an empty dict" is acceptable.
        assert reference.system_metadata in (None, {})

    def test_preserves_existing_system_metadata_keys(
        self, mock_create_session, temp_dir: Path, session: Session
    ):
        image_path = _make_png(temp_dir / "annotated.png", (100, 200))

        # Initial ingest, after which a sentinel key is written into
        # system_metadata to stand in for an earlier enricher write.
        first_result = _ingest_file_from_path(
            abs_path=str(image_path),
            asset_hash="blake3:img-merge",
            size_bytes=image_path.stat().st_size,
            mtime_ns=1234567890000000000,
            mime_type="image/png",
        )
        reference = self._load_reference(session, first_result.reference_id)
        reference.system_metadata = {
            **(reference.system_metadata or {}),
            "source_url": "https://example/x.png",
        }
        session.commit()

        # Ingest the same path again (new mtime) to run the merge code path
        # over metadata that already contains the sentinel.
        _ingest_file_from_path(
            abs_path=str(image_path),
            asset_hash="blake3:img-merge",
            size_bytes=image_path.stat().st_size,
            mtime_ns=1234567890000000001,
            mime_type="image/png",
        )
        session.refresh(reference)

        merged = reference.system_metadata
        assert merged["kind"] == "image"
        assert merged["width"] == 100
        assert merged["height"] == 200
        assert merged["source_url"] == "https://example/x.png"
class TestRegisterExistingAssetBackfill:
    """The from-hash path back-fills dimensions from a sibling reference."""

    def _add_reference(
        self,
        session: Session,
        asset: Asset,
        name: str,
        system_metadata: dict | None = None,
    ) -> AssetReference:
        """Insert (and flush) an AssetReference for ``asset``.

        ``system_metadata`` defaults to an empty dict when omitted. All three
        timestamps are set to the same "now" value. The row is flushed but not
        committed; callers commit themselves.
        """
        now = get_utc_now()
        ref = AssetReference(
            asset_id=asset.id,
            name=name,
            owner_id="",
            created_at=now,
            updated_at=now,
            last_access_time=now,
            system_metadata=system_metadata or {},
        )
        session.add(ref)
        session.flush()
        return ref

    def test_backfills_dimensions_from_sibling_image_reference(
        self, mock_create_session, session: Session
    ):
        """A sibling reference with image metadata donates its dimensions."""
        asset = Asset(hash="blake3:shared", size_bytes=2048, mime_type="image/png")
        session.add(asset)
        session.flush()
        self._add_reference(
            session,
            asset,
            name="original.png",
            system_metadata={"kind": "image", "width": 800, "height": 600},
        )
        session.commit()

        result = _register_existing_asset(
            asset_hash="blake3:shared",
            name="from_hash.png",
            owner_id="user-x",
        )

        ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
        assert ref.system_metadata.get("kind") == "image"
        assert ref.system_metadata.get("width") == 800
        assert ref.system_metadata.get("height") == 600

    def test_no_backfill_when_sibling_has_no_image_metadata(
        self, mock_create_session, session: Session
    ):
        """A sibling without kind=image metadata must not donate anything."""
        asset = Asset(hash="blake3:nodims", size_bytes=2048, mime_type="image/png")
        session.add(asset)
        session.flush()
        self._add_reference(
            session,
            asset,
            name="original.png",
            system_metadata={"base_model": "flux"},  # no kind=image
        )
        session.commit()

        result = _register_existing_asset(
            asset_hash="blake3:nodims",
            name="from_hash.png",
            owner_id="user-x",
        )

        ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
        meta = ref.system_metadata or {}
        assert "kind" not in meta
        assert "width" not in meta
        assert "height" not in meta

    def test_no_backfill_when_no_sibling_exists(
        self, mock_create_session, session: Session
    ):
        """With no other reference to copy from, system_metadata stays empty."""
        asset = Asset(hash="blake3:lonely", size_bytes=1024, mime_type="image/png")
        session.add(asset)
        session.commit()

        result = _register_existing_asset(
            asset_hash="blake3:lonely",
            name="solo.png",
            owner_id="user-x",
        )

        ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
        assert ref.system_metadata in (None, {})

    def test_backfill_preserves_caller_supplied_keys(
        self, mock_create_session, session: Session
    ):
        """Back-filled dimension keys coexist with an extra key written afterwards.

        NOTE(review): this test does not make a second
        ``_register_existing_asset`` call, so it does not exercise a merge
        over pre-existing metadata. It checks only that a key added after the
        back-fill persists alongside the back-filled dimensions.
        """
        asset = Asset(hash="blake3:preserve", size_bytes=2048, mime_type="image/png")
        session.add(asset)
        session.flush()
        self._add_reference(
            session,
            asset,
            name="original.png",
            system_metadata={"kind": "image", "width": 1024, "height": 768},
        )
        session.commit()

        result = _register_existing_asset(
            asset_hash="blake3:preserve",
            name="from_hash.png",
            owner_id="user-x",
        )
        ref = session.query(AssetReference).filter_by(id=result.ref.id).first()

        # Add a sentinel key (e.g. a download-provenance source_url) on top of
        # whatever the back-fill wrote, then persist it.
        ref.system_metadata = {**(ref.system_metadata or {}), "source_url": "https://example/p"}
        session.commit()
        # Reload from the database so the assertions check persisted state,
        # not the value assigned in memory a few lines above.
        session.refresh(ref)

        assert ref.system_metadata.get("source_url") == "https://example/p"
        assert ref.system_metadata.get("kind") == "image"
        assert ref.system_metadata.get("width") == 1024
        assert ref.system_metadata.get("height") == 768
assert "my-job" in ref_tag_names

View File

@ -6,13 +6,7 @@ from unittest.mock import patch
import pytest
from app.assets.services.path_utils import (
compute_display_name,
compute_file_path,
get_asset_category_and_relative_path,
get_name_and_tags_from_asset_path,
resolve_destination_from_tags,
)
from app.assets.services.path_utils import get_asset_category_and_relative_path
@pytest.fixture
@ -44,50 +38,6 @@ def fake_dirs():
}
@pytest.fixture
def fake_dirs_multi_bucket():
    """Variant fixture with multiple model buckets (checkpoints + diffusers + loras).

    Creates the directory tree inside a temporary root, patches
    ``folder_paths`` and ``get_comfy_models_folders`` in
    ``app.assets.services.path_utils`` to point at it, and yields a dict
    mapping each bucket name to its directory.
    """
    with tempfile.TemporaryDirectory() as root:
        base = Path(root)
        models_root = base / "models"
        dirs = {
            "input": base / "input",
            "output": base / "output",
            "temp": base / "temp",
            "checkpoints": models_root / "checkpoints",
            "diffusers": models_root / "diffusers",
            "loras": models_root / "loras",
        }
        for directory in dirs.values():
            directory.mkdir(parents=True)

        model_folders = [
            (bucket, [str(dirs[bucket])])
            for bucket in ("checkpoints", "diffusers", "loras")
        ]

        with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
            mock_fp.get_input_directory.return_value = str(dirs["input"])
            mock_fp.get_output_directory.return_value = str(dirs["output"])
            mock_fp.get_temp_directory.return_value = str(dirs["temp"])
            with patch(
                "app.assets.services.path_utils.get_comfy_models_folders",
                return_value=model_folders,
            ):
                yield dirs
class TestGetAssetCategoryAndRelativePath:
def test_input_file(self, fake_dirs):
f = fake_dirs["input"] / "photo.png"
@ -129,185 +79,3 @@ class TestGetAssetCategoryAndRelativePath:
def test_unknown_path_raises(self, fake_dirs):
with pytest.raises(ValueError, match="not within"):
get_asset_category_and_relative_path("/some/random/path.png")
class TestGetNameAndTagsFromAssetPath:
    """The parent subpath is collapsed into one slash-joined tag.

    That lets consumers treat ``tags[1]`` as a stable category identifier
    no matter how deeply the file is nested inside its bucket.
    """

    def test_flat_input(self, fake_dirs_multi_bucket):
        target = fake_dirs_multi_bucket["input"] / "photo.png"
        target.touch()
        result = get_name_and_tags_from_asset_path(str(target))
        name, tags = result
        assert name == "photo.png"
        assert tags == ["input"]

    def test_flat_output(self, fake_dirs_multi_bucket):
        target = fake_dirs_multi_bucket["output"] / "result_00001.png"
        target.touch()
        result = get_name_and_tags_from_asset_path(str(target))
        name, tags = result
        assert name == "result_00001.png"
        assert tags == ["output"]

    def test_flat_models_checkpoint(self, fake_dirs_multi_bucket):
        target = fake_dirs_multi_bucket["checkpoints"] / "flux.safetensors"
        target.touch()
        result = get_name_and_tags_from_asset_path(str(target))
        name, tags = result
        assert name == "flux.safetensors"
        assert tags == ["models", "checkpoints"]

    def test_diffusers_nested_subpath_slash_joined(self, fake_dirs_multi_bucket):
        """Diffusers components sit in nested directories; the whole subpath
        has to become a single tag so the model category can be read from
        tags[1] at any nesting depth.

        The subpath is lowercased, matching the canonicalization that
        :func:`ensure_tags_exist` applies on the write side; otherwise the
        asset_reference_tags.tag_name FK to tags.name would fail for paths
        containing uppercase letters.
        """
        component_dir = fake_dirs_multi_bucket["diffusers"].joinpath("Kolors", "text_encoder")
        component_dir.mkdir(parents=True)
        target = component_dir / "model.safetensors"
        target.touch()
        result = get_name_and_tags_from_asset_path(str(target))
        name, tags = result
        assert name == "model.safetensors"
        assert tags == ["models", "diffusers/kolors/text_encoder"]

    def test_deep_lora_user_subpath_slash_joined(self, fake_dirs_multi_bucket):
        """Directories a user creates beneath a model bucket are likewise
        folded into one tag, not emitted as one tag per directory."""
        user_dir = fake_dirs_multi_bucket["loras"].joinpath("my", "custom", "path")
        user_dir.mkdir(parents=True)
        target = user_dir / "v0001.safetensors"
        target.touch()
        result = get_name_and_tags_from_asset_path(str(target))
        name, tags = result
        assert name == "v0001.safetensors"
        assert tags == ["models", "loras/my/custom/path"]
class TestResolveDestinationFromTags:
    """resolve_destination_from_tags has to understand two tag shapes: the
    legacy one-tag-per-directory form and the newer slash-joined form. That
    way an upload that reuses tags just read from /api/assets lands in the
    correct on-disk destination.
    """

    @pytest.fixture
    def resolve_dirs(self):
        """Build a temp directory tree, patch ``folder_paths`` to use it, and
        yield a dict mapping bucket name to directory."""
        with tempfile.TemporaryDirectory() as root:
            base = Path(root)
            models_root = base / "models"
            dirs = {
                "input": base / "input",
                "output": base / "output",
                "checkpoints": models_root / "checkpoints",
                "diffusers": models_root / "diffusers",
                "loras": models_root / "loras",
            }
            for directory in dirs.values():
                directory.mkdir(parents=True)

            with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
                mock_fp.get_input_directory.return_value = str(dirs["input"])
                mock_fp.get_output_directory.return_value = str(dirs["output"])
                mock_fp.folder_names_and_paths = {
                    category: ([str(dirs[category])], None)
                    for category in ("checkpoints", "diffusers", "loras")
                }
                yield dirs

    def test_models_flat_category(self, resolve_dirs):
        resolved = resolve_destination_from_tags(["models", "checkpoints"])
        base, subdirs = resolved
        assert base == str(resolve_dirs["checkpoints"])
        assert subdirs == []

    def test_models_slash_joined_new_shape(self, resolve_dirs):
        # This is the tag shape get_name_and_tags_from_asset_path now produces.
        tags = ["models", "diffusers/kolors/text_encoder"]
        base, subdirs = resolve_destination_from_tags(tags)
        assert base == str(resolve_dirs["diffusers"])
        assert subdirs == ["kolors", "text_encoder"]

    def test_models_legacy_one_tag_per_dir(self, resolve_dirs):
        # The older shape has to resolve to exactly the same destination.
        tags = ["models", "diffusers", "kolors", "text_encoder"]
        base, subdirs = resolve_destination_from_tags(tags)
        assert base == str(resolve_dirs["diffusers"])
        assert subdirs == ["kolors", "text_encoder"]

    def test_models_loras_slash_joined(self, resolve_dirs):
        tags = ["models", "loras/my/custom/path"]
        base, subdirs = resolve_destination_from_tags(tags)
        assert base == str(resolve_dirs["loras"])
        assert subdirs == ["my", "custom", "path"]

    def test_input_no_subdir(self, resolve_dirs):
        resolved = resolve_destination_from_tags(["input"])
        base, subdirs = resolved
        assert base == str(resolve_dirs["input"])
        assert subdirs == []

    def test_input_slash_joined_subdir(self, resolve_dirs):
        resolved = resolve_destination_from_tags(["input", "portraits/2026"])
        base, subdirs = resolved
        assert base == str(resolve_dirs["input"])
        assert subdirs == ["portraits", "2026"]

    def test_output_slash_joined_subdir(self, resolve_dirs):
        resolved = resolve_destination_from_tags(["output", "runs/abc"])
        base, subdirs = resolved
        assert base == str(resolve_dirs["output"])
        assert subdirs == ["runs", "abc"]

    def test_unknown_category_rejected(self, resolve_dirs):
        bad_tags = ["models", "not_a_real_category"]
        with pytest.raises(ValueError, match="unknown model category"):
            resolve_destination_from_tags(bad_tags)

    def test_unknown_category_via_slash_joined(self, resolve_dirs):
        # Even in slash-joined form, the leading segment must be a registered category.
        bad_tags = ["models", "bogus/sub/path"]
        with pytest.raises(ValueError, match="unknown model category 'bogus'"):
            resolve_destination_from_tags(bad_tags)

    def test_traversal_in_subdir_rejected(self, resolve_dirs):
        traversal_tags = ["models", "checkpoints/..", "evil"]
        with pytest.raises(ValueError, match="invalid path component"):
            resolve_destination_from_tags(traversal_tags)
class TestResponsePaths:
    """compute_file_path / compute_display_name for known and unknown locations."""

    def test_input_file_path_and_display_name_include_subfolder(self, fake_dirs):
        folder = fake_dirs["input"].joinpath("some", "folder")
        folder.mkdir(parents=True)
        image = folder / "image.png"
        image.touch()
        image_str = str(image)
        assert compute_file_path(image_str) == "input/some/folder/image.png"
        assert compute_display_name(image_str) == "some/folder/image.png"

    def test_model_file_path_includes_bucket_display_name_drops_it(self, fake_dirs):
        folder = fake_dirs["models"] / "flux"
        folder.mkdir()
        model = folder / "model.safetensors"
        model.touch()
        model_str = str(model)
        assert compute_file_path(model_str) == "models/checkpoints/flux/model.safetensors"
        assert compute_display_name(model_str) == "flux/model.safetensors"

    def test_unknown_path_returns_none(self, fake_dirs):
        outside = "/some/random/path.png"
        assert compute_file_path(outside) is None
        assert compute_display_name(outside) is None

View File

@ -32,7 +32,7 @@ def test_seed_asset_removed_when_file_is_deleted(
# Verify it is visible via API and carries no hash (seed)
r1 = http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests/syncseed", "name_contains": name},
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
timeout=120,
)
body1 = r1.json()
@ -40,9 +40,7 @@ def test_seed_asset_removed_when_file_is_deleted(
# there should be exactly one with that name
matches = [a for a in body1.get("assets", []) if a.get("name") == name]
assert matches
# Seed assets have no hash; exclude_none drops both keys from the response
assert "asset_hash" not in matches[0]
assert "hash" not in matches[0]
assert matches[0].get("asset_hash") is None
asset_info_id = matches[0]["id"]
# Remove the underlying file and sync again
@ -54,7 +52,7 @@ def test_seed_asset_removed_when_file_is_deleted(
# It should disappear (AssetInfo and seed Asset gone)
r2 = http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests/syncseed", "name_contains": name},
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
timeout=120,
)
body2 = r2.json()
@ -334,7 +332,7 @@ def test_fastpass_removes_stale_state_row_no_missing(
rl = http.get(
api_base + "/api/assets",
params={"include_tags": f"unit-tests/{scope}"},
params={"include_tags": f"unit-tests,{scope}"},
timeout=120,
)
bl = rl.json()

View File

@ -21,8 +21,6 @@ def test_create_from_hash_success(
b1 = r1.json()
assert r1.status_code == 201, b1
assert b1["asset_hash"] == h
assert b1["hash"] == h
assert b1["hash"] == b1["asset_hash"]
assert b1["created_new"] is False
aid = b1["id"]
@ -41,7 +39,6 @@ def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asse
detail = rg.json()
assert rg.status_code == 200, detail
assert detail["id"] == aid
assert detail["hash"] == detail["asset_hash"]
assert "user_metadata" in detail
assert "filename" in detail["user_metadata"]
@ -100,7 +97,6 @@ def test_delete_upon_reference_count(
copy = r2.json()
assert r2.status_code == 201, copy
assert copy["asset_hash"] == src_hash
assert copy["hash"] == src_hash
assert copy["created_new"] is False
# Soft-delete original reference (default) -> asset identity must remain
@ -143,7 +139,6 @@ def test_update_asset_fields(http: requests.Session, api_base: str, seeded_asset
body = ru.json()
assert ru.status_code == 200, body
assert body["name"] == payload["name"]
assert body["hash"] == body["asset_hash"]
assert body["tags"] == original_tags # tags unchanged
assert body["user_metadata"]["purpose"] == "updated"
# filename should still be present and normalized by server
@ -285,24 +280,16 @@ def test_metadata_filename_is_set_for_seed_asset_without_hash(
trigger_sync_seed_assets(http, api_base)
# Scanner emits tags as ``[root, "<dir1>/<dir2>/..."]`` — the second tag
# is the slash-joined parent subpath. For ``<root>/unit-tests/<scope>/a/b/<name>``
# the second tag is ``"unit-tests/<scope>/a/b"``.
r1 = http.get(
api_base + "/api/assets",
params={
"include_tags": f"unit-tests/{scope}/a/b",
"name_contains": name,
},
params={"include_tags": f"unit-tests,{scope}", "name_contains": name},
timeout=120,
)
body = r1.json()
assert r1.status_code == 200, body
matches = [a for a in body.get("assets", []) if a.get("name") == name]
assert matches, "Seed asset should be visible after sync"
# Seed assets have no hash; exclude_none drops both keys from the response
assert "asset_hash" not in matches[0]
assert "hash" not in matches[0]
assert matches[0].get("asset_hash") is None # still a seed
aid = matches[0]["id"]
r2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)

View File

@ -1,69 +0,0 @@
"""Unit tests for app.assets.helpers."""
from app.assets.helpers import expand_bucket_prefixes
class TestExpandBucketPrefixes:
    """Expected outputs of expand_bucket_prefixes for representative tag lists."""

    def test_flat_category_unchanged(self):
        # `checkpoints` already stands alone as a token; nothing to add.
        tags = ["models", "checkpoints"]
        expected = ["models", "checkpoints"]
        assert expand_bucket_prefixes(tags) == expected

    def test_nested_category_inserts_bucket(self):
        # Path-derived shape for `models/checkpoints/flux/foo.safetensors`.
        # The standalone bucket must be present so the FE set-membership
        # filter (`include_tags=models,checkpoints`) matches the asset.
        tags = ["models", "checkpoints/flux"]
        expected = ["models", "checkpoints/flux", "checkpoints"]
        assert expand_bucket_prefixes(tags) == expected

    def test_deeply_nested_only_first_segment_expands(self):
        # Only the first slash segment is emitted as a standalone token;
        # intermediate path segments carry no routing significance.
        tags = ["models", "diffusers/kolors/text_encoder"]
        expected = ["models", "diffusers/kolors/text_encoder", "diffusers"]
        assert expand_bucket_prefixes(tags) == expected

    def test_unknown_prefix_does_not_expand(self):
        # A free-form label containing slashes passes through untouched when
        # its first segment is not a registered bucket.
        tags = ["models", "my-org/team-a"]
        expected = ["models", "my-org/team-a"]
        assert expand_bucket_prefixes(tags) == expected

    def test_idempotent(self):
        # Once the bucket is in the list, applying the helper again changes nothing.
        once = expand_bucket_prefixes(["models", "checkpoints/flux"])
        twice = expand_bucket_prefixes(once)
        assert twice == once

    def test_does_not_duplicate_existing_bucket(self):
        # A standalone bucket supplied by the caller must not be added twice.
        tags = ["models", "checkpoints/flux", "checkpoints"]
        expected = ["models", "checkpoints/flux", "checkpoints"]
        assert expand_bucket_prefixes(tags) == expected

    def test_preserves_caller_order(self):
        # User tags that follow path tags stay after them. The inserted bucket
        # token goes directly after its slash-joined parent so the
        # microsecond stagger lands it at path-tier before user-tier.
        tags = ["models", "loras/style", "favorite", "v2"]
        expected = ["models", "loras/style", "loras", "favorite", "v2"]
        assert expand_bucket_prefixes(tags) == expected

    def test_empty_input(self):
        no_tags = []
        assert expand_bucket_prefixes(no_tags) == []

    def test_input_root_with_subpath_no_expansion(self):
        # `portraits` is not a registered model category, so the input
        # subpath stays opaque (the FE filter has no checkpoint-loader
        # analogue for input subfolders).
        tags = ["input", "portraits/2026"]
        expected = ["input", "portraits/2026"]
        assert expand_bucket_prefixes(tags) == expected

View File

@ -1,349 +0,0 @@
"""Integration tests for cursor-based pagination on GET /api/assets.
These tests exercise the handler/service/query path end-to-end;
cursor-encoding-level tests live in
tests-unit/assets_test/services/test_cursor.py.
"""
import pytest
import requests
def _seed(asset_factory, make_asset_bytes, count: int, tag: str) -> list[str]:
names = [f"cursor_{i:02d}.safetensors" for i in range(count)]
for n in names:
asset_factory(
n,
["models", "checkpoints", "unit-tests", tag],
{},
make_asset_bytes(n, size=2048),
)
return sorted(names)
def test_cursor_pages_all_items_in_order(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
    """Walking pages of two via ``next_cursor`` yields every seeded name in sorted order."""
    expected_names = _seed(asset_factory, make_asset_bytes, count=5, tag="cursor-walk")
    base_query = {
        "include_tags": "unit-tests,cursor-walk",
        "sort": "name",
        "order": "asc",
        "limit": "2",
    }

    collected: list[str] = []
    cursor: str | None = None
    page_count = 0
    while True:
        query = dict(base_query)
        if cursor is not None:
            query["after"] = cursor
        resp = http.get(api_base + "/api/assets", params=query, timeout=120)
        assert resp.status_code == 200, resp.text
        body = resp.json()
        collected.extend(item["name"] for item in body["assets"])
        page_count += 1
        cursor = body.get("next_cursor")
        if cursor is None:
            break
        # A page that hands out a cursor must also report more data.
        assert body["has_more"] is True
        assert page_count < 10, "guard against runaway cursor loop"

    assert collected == expected_names, f"expected {expected_names}, got {collected}"
    # `body` is the final page here: no further data and no cursor key at all.
    assert body["has_more"] is False
    assert "next_cursor" not in body
def test_cursor_invalid_returns_400(http: requests.Session, api_base: str):
    """A syntactically bogus ``after`` value yields 400 with code INVALID_CURSOR."""
    bogus_query = {"after": "not-a-real-cursor", "sort": "created_at"}
    resp = http.get(api_base + "/api/assets", params=bogus_query, timeout=120)
    assert resp.status_code == 400, resp.text
    payload = resp.json()
    error_code = payload["error"]["code"]
    assert error_code == "INVALID_CURSOR"
def test_cursor_sort_mismatch_returns_400(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
    """A cursor minted for ``sort=name`` must be rejected when replayed with ``sort=created_at``."""
    _seed(asset_factory, make_asset_bytes, count=2, tag="cursor-mismatch")
    url = f"{api_base}/api/assets"

    # Obtain a genuine cursor from a name-sorted listing.
    name_sorted_query = {
        "include_tags": "unit-tests,cursor-mismatch",
        "sort": "name",
        "order": "asc",
        "limit": "1",
    }
    first = http.get(url, params=name_sorted_query, timeout=120)
    assert first.status_code == 200
    cursor = first.json()["next_cursor"]
    assert cursor is not None

    # Same cursor, different sort field: the server must answer INVALID_CURSOR.
    replay = http.get(url, params={"after": cursor, "sort": "created_at"}, timeout=120)
    assert replay.status_code == 400, replay.text
    assert replay.json()["error"]["code"] == "INVALID_CURSOR"
def test_cursor_wins_over_offset(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
    """When both ``after`` and ``offset`` are supplied, the cursor decides the page."""
    expected = _seed(asset_factory, make_asset_bytes, count=4, tag="cursor-vs-offset")
    url = f"{api_base}/api/assets"
    listing_query = {
        "include_tags": "unit-tests,cursor-vs-offset",
        "sort": "name",
        "order": "asc",
        "limit": "1",
    }

    # The first page yields a cursor positioned just after the first item.
    first = http.get(url, params=listing_query, timeout=120)
    assert first.status_code == 200, first.text
    cursor = first.json()["next_cursor"]
    assert cursor is not None

    # Send the cursor together with a huge offset; the offset must be ignored.
    combined_query = {**listing_query, "after": cursor, "offset": "999"}
    second = http.get(url, params=combined_query, timeout=120)
    assert second.status_code == 200
    page = second.json()
    # The second sorted name is expected, not an empty page 999 rows further on.
    returned = [item["name"] for item in page["assets"]]
    assert returned == [expected[1]]
def test_next_cursor_absent_when_no_more_results(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
    """A page that already contains every match must omit ``next_cursor`` and report ``has_more`` False."""
    _seed(asset_factory, make_asset_bytes, count=2, tag="cursor-exhaust")
    query = {
        "include_tags": "unit-tests,cursor-exhaust",
        "sort": "name",
        "order": "asc",
        "limit": "50",
    }
    resp = http.get(f"{api_base}/api/assets", params=query, timeout=120)
    assert resp.status_code == 200, resp.text
    page = resp.json()
    assert page["has_more"] is False
    assert "next_cursor" not in page
def test_cursor_pagination_first_page_mints_cursor(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
    """A request without ``after`` must still hand back ``next_cursor`` when
    more rows exist; otherwise a client could never start paginating.
    """
    _seed(asset_factory, make_asset_bytes, count=3, tag="cursor-first-page")
    query = {
        "include_tags": "unit-tests,cursor-first-page",
        "sort": "name",
        "order": "asc",
        "limit": "2",
    }
    resp = http.get(f"{api_base}/api/assets", params=query, timeout=120)
    assert resp.status_code == 200, resp.text
    page = resp.json()
    assert page["has_more"] is True
    assert page.get("next_cursor"), "first page must mint a cursor when more rows exist"
def test_cursor_no_spurious_cursor_when_page_size_equals_remainder(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
    """With 4 rows and limit=2 the second page is full but also final, so it
    must not advertise a cursor for a non-existent third page.
    """
    _seed(asset_factory, make_asset_bytes, count=4, tag="cursor-exact-multiple")
    url = f"{api_base}/api/assets"
    first_query = {
        "include_tags": "unit-tests,cursor-exact-multiple",
        "sort": "name",
        "order": "asc",
        "limit": "2",
    }

    # First page: must hand out a cursor.
    first = http.get(url, params=first_query, timeout=120)
    assert first.status_code == 200, first.text
    cursor = first.json()["next_cursor"]
    assert cursor is not None

    # Second page: consumes the remaining two rows and ends the walk.
    second = http.get(url, params={**first_query, "after": cursor}, timeout=120)
    assert second.status_code == 200, second.text
    page = second.json()
    assert len(page["assets"]) == 2
    assert page["has_more"] is False
    assert "next_cursor" not in page
@pytest.mark.parametrize("sort_field", ["created_at", "updated_at", "size"])
def test_cursor_walks_for_non_name_sorts(sort_field, http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
    """Walk a descending cursor listing for each non-name sort field.

    For every field the walk must visit each seeded asset exactly once. A
    strict order is asserted only for ``size``: the seeded sizes grow with the
    index, so the descending order is known, whereas timestamp sorts may tie
    when rows are inserted faster than the clock resolution.
    """
    scope_tag = f"cursor-{sort_field}"
    seeded = []
    for index in range(4):
        filename = f"cursor_{sort_field}_{index:02d}.safetensors"
        # Size is 2048 + index, i.e. strictly increasing in seeding order.
        asset_factory(
            filename,
            ["models", "checkpoints", "unit-tests", scope_tag],
            {},
            make_asset_bytes(filename, size=2048 + index),
        )
        seeded.append(filename)

    url = f"{api_base}/api/assets"
    base_query = {
        "include_tags": f"unit-tests,cursor-{sort_field}",
        "sort": sort_field,
        "order": "desc",
        "limit": "2",
    }
    visited: list[str] = []
    cursor: str | None = None
    page_count = 0
    while True:
        query = dict(base_query)
        if cursor is not None:
            query["after"] = cursor
        resp = http.get(url, params=query, timeout=120)
        assert resp.status_code == 200, resp.text
        page = resp.json()
        visited += [item["name"] for item in page["assets"]]
        cursor = page.get("next_cursor")
        page_count += 1
        if cursor is None:
            break
        assert page_count < 10, "guard against runaway cursor loop"

    # A row served on two different pages shows up as a duplicate here.
    assert len(visited) == len(set(visited)), (
        f"cursor walk repeated rows for sort={sort_field}: {visited}"
    )
    # Nothing may be skipped either.
    assert set(visited) == set(seeded), (
        f"missing items for sort={sort_field}: expected {set(seeded)}, got {set(visited)}"
    )
    # Only `size` has an ordering that does not depend on clock granularity.
    if sort_field == "size":
        largest_first = list(reversed(seeded))
        assert visited == largest_first, (
            f"size cursor walked out of order: got {visited}, expected {largest_first}"
        )
def test_cursor_order_mismatch_returns_400(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
    """Replaying a ``desc`` cursor with ``order=asc`` must be a 400
    INVALID_CURSOR rather than a listing in the wrong direction."""
    _seed(asset_factory, make_asset_bytes, count=3, tag="cursor-order-flip")
    url = f"{api_base}/api/assets"
    desc_query = {
        "include_tags": "unit-tests,cursor-order-flip",
        "sort": "name",
        "order": "desc",
        "limit": "1",
    }
    first = http.get(url, params=desc_query, timeout=120)
    assert first.status_code == 200, first.text
    cursor = first.json()["next_cursor"]
    assert cursor is not None

    # Identical query except for the flipped direction plus the desc-minted cursor.
    flipped_query = {**desc_query, "order": "asc", "after": cursor}
    replay = http.get(url, params=flipped_query, timeout=120)
    assert replay.status_code == 400, replay.text
    assert replay.json()["error"]["code"] == "INVALID_CURSOR"
def test_cursor_invalid_cursor_at_microsecond_boundary(http: requests.Session, api_base: str):
    """A hand-built cursor whose timestamp value is absurdly large must be
    answered with 400 INVALID_CURSOR, not a 500."""
    import base64
    import json

    # The value below is 21 nines (about 10^21). Read as microseconds since
    # the epoch that is tens of millions of years, far past datetime.MAXYEAR
    # (9999).
    # NOTE(review): the cursor field meanings (s=sort, o=order, v=value) are
    # inferred from the original comments; confirm against the cursor codec.
    # Both "o" in the payload and order= in the query are supplied so the
    # request is not rejected earlier for a missing order.
    fields = {"s": "created_at", "o": "desc", "v": "999999999999999999999", "id": "asset-x"}
    encoded = base64.urlsafe_b64encode(
        json.dumps(fields, separators=(",", ":")).encode("utf-8")
    )
    # Strip base64 padding before sending.
    cursor = encoded.rstrip(b"=").decode("ascii")

    query = {"after": cursor, "sort": "created_at", "order": "desc"}
    resp = http.get(f"{api_base}/api/assets", params=query, timeout=120)
    assert resp.status_code == 400, resp.text
    assert resp.json()["error"]["code"] == "INVALID_CURSOR"
def test_cursor_pagination_stable_after_delete(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
    """Deleting an already-served row must not shift what the cursor returns next."""
    expected = _seed(asset_factory, make_asset_bytes, count=4, tag="cursor-delete")
    url = f"{api_base}/api/assets"
    listing_query = {
        "include_tags": "unit-tests,cursor-delete",
        "sort": "name",
        "order": "asc",
        "limit": "2",
    }

    # First page: the first two names plus a cursor.
    first = http.get(url, params=listing_query, timeout=120)
    assert first.status_code == 200
    first_page = first.json()
    first_names = [item["name"] for item in first_page["assets"]]
    cursor = first_page["next_cursor"]
    assert cursor is not None
    assert first_names == expected[:2]

    # Remove a row that page 1 already returned. The cursor should resume
    # from where it was minted rather than from a recomputed position.
    doomed_id = first_page["assets"][0]["id"]
    deletion = http.delete(f"{api_base}/api/assets/{doomed_id}", timeout=120)
    assert deletion.status_code in (200, 204), deletion.text

    # Second page through the cursor: the remaining two names.
    second = http.get(url, params={**listing_query, "after": cursor}, timeout=120)
    assert second.status_code == 200, second.text
    second_names = [item["name"] for item in second.json()["assets"]]
    assert second_names == expected[2:]

View File

@ -3,7 +3,6 @@ import uuid
import pytest
import requests
from helpers import assert_hash_fields_consistent
def test_list_assets_paging_and_sort(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
@ -27,10 +26,6 @@ def test_list_assets_paging_and_sort(http: requests.Session, api_base: str, asse
got1 = [a["name"] for a in b1["assets"]]
assert got1 == sorted(names)[:2]
assert b1["has_more"] is True
# Populated assets in list responses must carry both `hash` and `asset_hash` consistently
for asset in b1["assets"]:
assert_hash_fields_consistent(asset)
assert "hash" in asset, "populated asset must emit hash on list endpoint"
r2 = http.get(
api_base + "/api/assets",

View File

@ -29,10 +29,7 @@ def create_seed_file(comfy_tmp_base_dir: Path):
def find_asset(http: requests.Session, api_base: str):
"""Query API for assets matching scope and optional name."""
def _find(scope: str, name: str | None = None) -> list[dict]:
# Scanner now emits tags as ``[root, "<dir1>/<dir2>/..."]`` rather than
# one tag per directory. For files at ``<root>/unit-tests/<scope>/...``
# the second tag is exactly ``"unit-tests/<scope>"``.
params = {"include_tags": f"unit-tests/{scope}"}
params = {"include_tags": f"unit-tests,{scope}"}
if name:
params["name_contains"] = name
r = http.get(f"{api_base}/api/assets", params=params, timeout=120)
@ -141,7 +138,4 @@ def test_special_chars_in_path_escaped_correctly(
trigger_sync_seed_assets(http, api_base)
trigger_sync_seed_assets(http, api_base)
# Scanner emits the full parent subpath as a single slash-joined tag, so
# the lookup tag is ``unit-tests/<scope>`` even when <scope> itself
# contains a slash (parent + special-char dirname).
assert find_asset(scope, fp.name), "Asset with special chars should survive"
assert find_asset(scope.split("/")[0], fp.name), "Asset with special chars should survive"

View File

@ -5,21 +5,6 @@ from concurrent.futures import ThreadPoolExecutor
import requests
import pytest
from app.assets.api.schemas_out import Asset, AssetCreated
from helpers import get_asset_filename
def test_asset_created_inherits_hash_field():
"""AssetCreated must inherit `hash` from Asset so POST /api/assets responses emit it.
Schema-level guard: integration tests cover the wire shape, but inheritance
drift (e.g. AssetCreated ever being redefined to no longer extend Asset)
would silently drop `hash` from a major endpoint without this check.
"""
assert "hash" in Asset.model_fields
assert "hash" in AssetCreated.model_fields
assert AssetCreated.model_fields["hash"].annotation == Asset.model_fields["hash"].annotation
def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, make_asset_bytes):
name = "dup_a.safetensors"
@ -32,7 +17,6 @@ def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, ma
a1 = r1.json()
assert r1.status_code == 201, a1
assert a1["created_new"] is True
assert a1["hash"] == a1["asset_hash"]
# Second upload with the same data and name creates a new AssetReference (duplicates allowed)
# Returns 200 because Asset already exists, but a new AssetReference is created
@ -42,7 +26,6 @@ def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, ma
a2 = r2.json()
assert r2.status_code in (200, 201), a2
assert a2["asset_hash"] == a1["asset_hash"]
assert a2["hash"] == a1["hash"]
assert a2["id"] != a1["id"] # new reference with same content
# Third upload with the same data but different name also creates new AssetReference
@ -67,7 +50,6 @@ def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_
b1 = r1.json()
assert r1.status_code == 201, b1
h = b1["asset_hash"]
assert b1["hash"] == h
# Now POST /api/assets with only hash and no file
files = [
@ -81,15 +63,6 @@ def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_
assert r2.status_code == 200, b2 # fast path returns 200 with created_new == False
assert b2["created_new"] is False
assert b2["asset_hash"] == h
assert b2["hash"] == h
assert b2.get("file_path") is None
assert b2.get("display_name") is None
rg = http.get(f"{api_base}/api/assets/{b2['id']}", timeout=120)
detail = rg.json()
assert rg.status_code == 200, detail
assert detail.get("file_path") is None
assert detail.get("display_name") is None
def test_upload_fastpath_with_known_hash_and_file(
@ -102,7 +75,6 @@ def test_upload_fastpath_with_known_hash_and_file(
b1 = r1.json()
assert r1.status_code == 201, b1
h = b1["asset_hash"]
assert b1["hash"] == h
# Send both file and hash of existing content -> server must drain file and create from hash (200)
files = {"file": ("ignored.bin", b"ignored" * 10, "application/octet-stream")}
@ -112,7 +84,6 @@ def test_upload_fastpath_with_known_hash_and_file(
assert r2.status_code == 200, b2
assert b2["created_new"] is False
assert b2["asset_hash"] == h
assert b2["hash"] == h
def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base: str):
@ -136,54 +107,6 @@ def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base
assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags)
@pytest.mark.parametrize(
("tags", "extension", "expected_prefix", "expected_display_prefix"),
[
(["input", "unit-tests"], ".png", "input", ""),
(["models", "checkpoints", "unit-tests"], ".safetensors", "models/checkpoints", ""),
],
)
def test_upload_response_includes_file_path_and_display_name(
tags: list[str],
extension: str,
expected_prefix: str,
expected_display_prefix: str,
http: requests.Session,
api_base: str,
asset_factory,
make_asset_bytes,
):
scope = f"response-paths-{uuid.uuid4().hex[:6]}"
scoped_tags = [*tags, scope]
name = f"asset_response_path{extension}"
created = asset_factory(name, scoped_tags, {}, make_asset_bytes(name, 1024))
stored_filename = get_asset_filename(created["asset_hash"], extension)
expected_suffix = f"unit-tests/{scope}/{stored_filename}"
expected_file_path = f"{expected_prefix}/{expected_suffix}"
expected_display_name = f"{expected_display_prefix}{expected_suffix}"
assert created["file_path"] == expected_file_path
assert created["display_name"] == expected_display_name
detail_r = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120)
detail = detail_r.json()
assert detail_r.status_code == 200, detail
assert detail["file_path"] == expected_file_path
assert detail["display_name"] == expected_display_name
list_r = http.get(
api_base + "/api/assets",
params={"include_tags": f"unit-tests,{scope}", "limit": "50"},
timeout=120,
)
listed = list_r.json()
assert list_r.status_code == 200, listed
match = next(a for a in listed["assets"] if a["id"] == created["id"])
assert match["file_path"] == expected_file_path
assert match["display_name"] == expected_display_name
@pytest.mark.parametrize("root", ["input", "output"])
def test_concurrent_upload_identical_bytes_different_names(
root: str,
@ -219,8 +142,6 @@ def test_concurrent_upload_identical_bytes_different_names(
assert r1.status_code in (200, 201), b1
assert r2.status_code in (200, 201), b2
assert b1["asset_hash"] == b2["asset_hash"]
assert b1["hash"] == b2["hash"]
assert b1["hash"] == b1["asset_hash"]
assert b1["id"] != b2["id"]
created_flags = sorted([bool(b1.get("created_new")), bool(b2.get("created_new"))])

View File

@ -1,135 +0,0 @@
"""HTTP-layer smoke test: user-added tags via POST /api/assets/{id}/tags
land after path tags when read back via GET /api/assets.
Exercises the full route handler -> service -> query path that the unit
tests at tests-unit/assets_test/queries/test_asset_info.py only cover at
the service layer.
"""
import json
import pytest
import requests
@pytest.fixture
def smoke_asset(http: requests.Session, api_base: str):
    """Yield the JSON body of a freshly uploaded checkpoint asset tagged
    models/checkpoints/unit-tests/smoke, then delete it (with content)."""
    filename = "smoke_user_tag.safetensors"
    upload = http.post(
        f"{api_base}/api/assets",
        files={"file": (filename, b"S" * 4096, "application/octet-stream")},
        data={
            "tags": json.dumps(["models", "checkpoints", "unit-tests", "smoke"]),
            "name": filename,
            "user_metadata": json.dumps({}),
        },
        timeout=120,
    )
    assert upload.status_code == 201, upload.text
    created = upload.json()
    yield created
    # Teardown: the response of the delete is deliberately not inspected.
    http.delete(
        f"{api_base}/api/assets/{created['id']}?delete_content=true", timeout=30
    )
def _fetch_asset_tags(http, api_base, ref_id):
r = http.get(f"{api_base}/api/assets/{ref_id}", timeout=30)
assert r.status_code == 200, r.text
return r.json()["tags"]
def test_user_tag_lands_after_path_tags_via_http(
    http: requests.Session, api_base: str, smoke_asset: dict
):
    """A single user tag added over HTTP must be appended after the path tags."""
    ref_id = smoke_asset["id"]

    # Before any user tag exists, the first two tags are the path tags in upload order.
    before = _fetch_asset_tags(http, api_base, ref_id)
    assert before[:2] == ["models", "checkpoints"]

    # "aaa-user-tag" would sort first if tags were ordered alphabetically.
    added = http.post(
        f"{api_base}/api/assets/{ref_id}/tags",
        json={"tags": ["aaa-user-tag"]},
        timeout=30,
    )
    assert added.status_code in (200, 201), added.text

    after = _fetch_asset_tags(http, api_base, ref_id)
    # Path tags keep the front positions; the user tag sits at the very end.
    assert after[0] == "models"
    assert after[1] == "checkpoints"
    assert "aaa-user-tag" in after
    assert after[-1] == "aaa-user-tag"
def test_user_tag_batch_lands_after_path_tags_via_http(
    http: requests.Session, api_base: str, smoke_asset: dict
):
    """Several user tags added in one request must all follow the path tags."""
    ref_id = smoke_asset["id"]
    upload_tags = ("models", "checkpoints", "unit-tests", "smoke")
    user_tags = ["zzz-z", "favorite", "aaa-experiment"]

    # The batch is deliberately not in alphabetical order. According to the
    # original author, a microsecond stagger in set_reference_tags /
    # add_tags_to_reference is what keeps "aaa-experiment" from moving to the
    # front. NOTE(review): that mechanism is not visible from this test.
    added = http.post(
        f"{api_base}/api/assets/{ref_id}/tags",
        json={"tags": user_tags},
        timeout=30,
    )
    assert added.status_code in (200, 201), added.text

    after = _fetch_asset_tags(http, api_base, ref_id)
    assert after[0] == "models"
    assert after[1] == "checkpoints"
    # Everything beyond the four upload tags should contain the user tags.
    tail = after[len(upload_tags):]
    assert set(tail) >= {"zzz-z", "favorite", "aaa-experiment"}
    # Under alphabetical ordering 'aaa-experiment' would precede both path tags.
    user_pos = after.index("aaa-experiment")
    assert user_pos > after.index("models")
    assert user_pos > after.index("checkpoints")
@pytest.fixture
def nested_checkpoint_asset(http: requests.Session, api_base: str):
    """Yield the JSON body of an asset uploaded with the slash-joined tag
    shape ``["models", "checkpoints/flux"]``, then delete it (with content).
    """
    filename = "nested_checkpoint.safetensors"
    upload = http.post(
        f"{api_base}/api/assets",
        files={"file": (filename, b"S" * 4096, "application/octet-stream")},
        data={
            "tags": json.dumps(["models", "checkpoints/flux"]),
            "name": filename,
            "user_metadata": json.dumps({}),
        },
        timeout=120,
    )
    assert upload.status_code == 201, upload.text
    created = upload.json()
    yield created
    # Teardown: the response of the delete is deliberately not inspected.
    http.delete(
        f"{api_base}/api/assets/{created['id']}?delete_content=true", timeout=30
    )
def test_nested_checkpoint_satisfies_fe_set_filter(
    http: requests.Session, api_base: str, nested_checkpoint_asset: dict
):
    """An asset uploaded under ``checkpoints/flux`` must still be found by the
    two-token filter ``include_tags=models`` + ``include_tags=checkpoints``.
    """
    ref_id = nested_checkpoint_asset["id"]

    # Position 1 keeps the slash-joined tag as uploaded; position 2 carries
    # the bare "checkpoints" bucket that the filter below matches on.
    stored_tags = _fetch_asset_tags(http, api_base, ref_id)
    assert stored_tags[:3] == ["models", "checkpoints/flux", "checkpoints"]

    # Repeated include_tags parameters: both tokens must be present on the asset.
    listing = http.get(
        f"{api_base}/api/assets",
        params=[("include_tags", "models"), ("include_tags", "checkpoints")],
        timeout=30,
    )
    assert listing.status_code == 200, listing.text
    matching_ids = {item["id"] for item in listing.json()["assets"]}
    assert ref_id in matching_ids

View File

@ -1,245 +0,0 @@
"""Tests for enrich_output_with_assets in comfy_execution/asset_enrichment.py."""
import os
import tempfile
import types
import unittest
from unittest.mock import MagicMock, patch
def _make_args(enable_assets: bool):
a = types.SimpleNamespace()
a.enable_assets = enable_assets
return a
def _make_db_ref(ref_id="ref-id-1"):
ref = MagicMock()
ref.id = ref_id
return ref
def _make_register_result(ref_id="ref-id-2"):
result = MagicMock()
result.ref.id = ref_id
return result
# Platform-appropriate absolute base. tempfile.gettempdir() returns C:\... on
# Windows and /tmp on POSIX, so containment via commonpath behaves naturally.
_DEFAULT_BASE = os.path.join(__import__("tempfile").gettempdir(), "asset-enrichment-test-base")
def _call(output_ui, *, enable_assets=True, file_exists=True, db_ref=None, register_result=None, directory=_DEFAULT_BASE):
    """Run ``enrich_output_with_assets(output_ui)`` against mocked collaborators.

    The module under test is reloaded while ``sys.modules`` is patched, so it
    binds to the stand-ins built here.

    Args:
        output_ui: Value passed straight to ``enrich_output_with_assets``.
        enable_assets: Value of the mocked ``comfy.cli_args.args.enable_assets``.
        file_exists: Return value of the patched ``os.path.isfile``.
        db_ref: Return value of the mocked ``get_reference_by_file_path``.
        register_result: Return value of the mocked ``register_file_in_place``;
            a default from ``_make_register_result()`` is used when falsy.
        directory: Return value of the mocked ``folder_paths.get_directory_by_type``.

    Returns:
        Whatever ``enrich_output_with_assets`` returns.
    """
    # Context-manager stand-in returned by the mocked create_session().
    session_cm = MagicMock()
    session_cm.__enter__ = MagicMock(return_value=MagicMock())
    session_cm.__exit__ = MagicMock(return_value=False)

    cli_args_stub = MagicMock(args=_make_args(enable_assets))
    folder_paths_stub = MagicMock(get_directory_by_type=MagicMock(return_value=directory))
    ingest_stub = MagicMock(
        register_file_in_place=MagicMock(return_value=register_result or _make_register_result()),
        DependencyMissingError=type("DependencyMissingError", (Exception,), {}),
    )
    queries_stub = MagicMock(
        get_reference_by_file_path=MagicMock(return_value=db_ref),
    )
    db_stub = MagicMock(create_session=MagicMock(return_value=session_cm))

    stubs = {
        "comfy.cli_args": cli_args_stub,
        "folder_paths": folder_paths_stub,
        "app.assets.services.ingest": ingest_stub,
        "app.assets.database.queries.asset_reference": queries_stub,
        "app.database.db": db_stub,
    }
    # os.path.isfile is the only filesystem call replaced. abspath/join stay
    # real so the containment check operates on genuine platform paths.
    with patch.dict("sys.modules", stubs):
        with patch("os.path.isfile", return_value=file_exists):
            import importlib
            import comfy_execution.asset_enrichment as mod
            importlib.reload(mod)
            return mod.enrich_output_with_assets(output_ui)
class TestEnrichOutputWithAssets(unittest.TestCase):
    """Unit tests for ``enrich_output_with_assets``.

    Most cases run through the module-level ``_call`` helper. The cases that
    need their own ``register_file_in_place`` stand-in (a failing callable, or
    a mock inspected afterwards) use ``_enrich_with_register``, which replaces
    mock scaffolding that was previously copied into each of those tests.
    """

    @staticmethod
    def _enrich_with_register(output, register):
        """Run enrichment with ``register`` installed as ``register_file_in_place``.

        Assets are enabled, the DB lookup always misses (returns None), the
        base directory is ``_DEFAULT_BASE`` and ``os.path.isfile`` always
        returns True.

        Args:
            output: Value passed to ``enrich_output_with_assets``.
            register: Callable or mock used as ``register_file_in_place``.

        Returns:
            The result of ``enrich_output_with_assets(output)``.
        """
        session_cm = MagicMock()
        session_cm.__enter__ = MagicMock(return_value=MagicMock())
        session_cm.__exit__ = MagicMock(return_value=False)

        mocked_modules = {
            "comfy.cli_args": MagicMock(args=_make_args(True)),
            "folder_paths": MagicMock(get_directory_by_type=MagicMock(return_value=_DEFAULT_BASE)),
            "app.assets.services.ingest": MagicMock(
                register_file_in_place=register,
                DependencyMissingError=type("DependencyMissingError", (Exception,), {}),
            ),
            "app.assets.database.queries.asset_reference": MagicMock(
                get_reference_by_file_path=MagicMock(return_value=None),
            ),
            "app.database.db": MagicMock(create_session=MagicMock(return_value=session_cm)),
        }
        # os.path.abspath is NOT patched: real path resolution is required for
        # the containment check exercised by the traversal tests.
        with patch.dict("sys.modules", mocked_modules), \
                patch("os.path.isfile", return_value=True):
            import importlib
            import comfy_execution.asset_enrichment as mod
            importlib.reload(mod)
            return mod.enrich_output_with_assets(output)

    def test_disabled_returns_unchanged(self):
        """With assets disabled no ``id`` is injected."""
        output = {"images": [{"filename": "a.png", "subfolder": "", "type": "output"}]}
        result = _call(output, enable_assets=False)
        self.assertNotIn("id", result["images"][0])

    def test_non_list_value_passed_through(self):
        """A non-list value under an output key is returned as-is."""
        output = {"text": "hello"}
        result = _call(output)
        self.assertEqual(result["text"], "hello")

    def test_entry_without_filename_unchanged(self):
        """Entries lacking ``filename`` are not enriched."""
        output = {"latent": [{"subfolder": "", "type": "output"}]}
        result = _call(output)
        self.assertNotIn("id", result["latent"][0])

    def test_entry_without_type_unchanged(self):
        """Entries lacking ``type`` are not enriched."""
        output = {"data": [{"filename": "a.png", "subfolder": ""}]}
        result = _call(output)
        self.assertNotIn("id", result["data"][0])

    def test_file_not_on_disk_unchanged(self):
        """Entries whose file does not exist on disk are not enriched."""
        output = {"images": [{"filename": "missing.png", "subfolder": "", "type": "output"}]}
        result = _call(output, file_exists=False)
        self.assertNotIn("id", result["images"][0])

    def test_unknown_type_returns_none_directory_unchanged(self):
        """When the directory lookup yields None the entry is not enriched."""
        output = {"images": [{"filename": "a.png", "subfolder": "", "type": "unknown"}]}
        result = _call(output, directory=None)
        self.assertNotIn("id", result["images"][0])

    def test_db_hit_injects_id(self):
        """An existing DB reference supplies the ``id`` and nothing else."""
        db_ref = _make_db_ref(ref_id="db-ref")
        output = {"images": [{"filename": "a.png", "subfolder": "", "type": "output"}]}
        result = _call(output, db_ref=db_ref)
        img = result["images"][0]
        self.assertEqual(img["id"], "db-ref")
        # Only id is injected: no asset_hash, name, preview_url or size.
        for absent_key in ("asset_hash", "name", "preview_url", "size"):
            self.assertNotIn(absent_key, img)

    def test_db_miss_falls_back_to_register(self):
        """Without a DB reference the id comes from ``register_file_in_place``."""
        reg = _make_register_result(ref_id="inline-ref")
        output = {"images": [{"filename": "new.png", "subfolder": "", "type": "output"}]}
        result = _call(output, db_ref=None, register_result=reg)
        img = result["images"][0]
        self.assertEqual(img["id"], "inline-ref")
        self.assertNotIn("asset_hash", img)
        self.assertNotIn("name", img)

    def test_original_entry_not_mutated(self):
        """The caller's entry dict must not gain an ``id`` key."""
        orig = {"filename": "a.png", "subfolder": "", "type": "output"}
        output = {"images": [orig]}
        _call(output)
        self.assertNotIn("id", orig)

    def test_enrichment_error_does_not_block_sibling_entries(self):
        """A failure on one entry must not stop the next entry being enriched."""
        call_count = [0]
        good_reg = _make_register_result(ref_id="good-ref")

        def register_side_effect(abs_path, name, tags):
            # First call fails, every later call succeeds.
            call_count[0] += 1
            if call_count[0] == 1:
                raise RuntimeError("boom")
            return good_reg

        output = {
            "images": [
                {"filename": "bad.png", "subfolder": "", "type": "output"},
                {"filename": "good.png", "subfolder": "", "type": "output"},
            ]
        }
        result = self._enrich_with_register(output, register_side_effect)

        imgs = result["images"]
        self.assertNotIn("id", imgs[0])
        self.assertEqual(imgs[1]["id"], "good-ref")

    def test_multiple_output_keys_all_enriched(self):
        """Entries under every output key are enriched."""
        output = {
            "images": [{"filename": "a.png", "subfolder": "", "type": "output"}],
            "videos": [{"filename": "b.mp4", "subfolder": "", "type": "output"}],
        }
        result = _call(output)
        self.assertIn("id", result["images"][0])
        self.assertIn("id", result["videos"][0])

    def test_none_entry_in_list_unchanged(self):
        """A ``None`` list element is preserved while its siblings are enriched."""
        output = {"images": [None, {"filename": "a.png", "subfolder": "", "type": "output"}]}
        result = _call(output)
        self.assertIsNone(result["images"][0])
        self.assertIn("id", result["images"][1])

    def test_path_traversal_subfolder_skipped(self):
        """A ``subfolder`` that escapes the base directory is neither enriched nor registered."""
        register_mock = MagicMock(return_value=_make_register_result())
        output = {"images": [{"filename": "passwd", "subfolder": "../../etc", "type": "output"}]}

        result = self._enrich_with_register(output, register_mock)

        self.assertNotIn("id", result["images"][0])
        register_mock.assert_not_called()

    def test_absolute_filename_skipped(self):
        """An absolute ``filename`` is neither enriched nor registered."""
        register_mock = MagicMock(return_value=_make_register_result())
        # os.path.join discards earlier components when a later one is
        # absolute, so this filename would otherwise replace the base directory.
        absolute_filename = os.path.abspath(os.sep + "etc" + os.sep + "passwd")
        output = {"images": [{"filename": absolute_filename, "subfolder": "", "type": "output"}]}

        result = self._enrich_with_register(output, register_mock)

        self.assertNotIn("id", result["images"][0])
        register_mock.assert_not_called()
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()