Compare commits

..

13 Commits

Author SHA1 Message Date
560e6ee5c1 Add job_ids filter for GET api assets (#13998)
Amp-Thread-ID: https://ampcode.com/threads/T-019e4ca5-b71a-7168-8f56-58b2325f34c3
Co-authored-by: Amp <amp@ampcode.com>
2026-05-22 10:53:50 +12:00
00c88a4634 Add cursor pagination for GET api assets (#14014)
Amp-Thread-ID: https://ampcode.com/threads/T-019e4ca5-b71a-7168-8f56-58b2325f34c3
Co-authored-by: Amp <amp@ampcode.com>
2026-05-22 10:53:01 +12:00
916b33c795 Add asset file_path and display_name response fields (#14005)
Amp-Thread-ID: https://ampcode.com/threads/T-019e4ca5-b71a-7168-8f56-58b2325f34c3
Co-authored-by: Amp <amp@ampcode.com>
2026-05-22 10:52:20 +12:00
35943e01ab Collapse nested asset paths into slash-joined tags (#13994)
Amp-Thread-ID: https://ampcode.com/threads/T-019e4ca5-b71a-7168-8f56-58b2325f34c3
Co-authored-by: Amp <amp@ampcode.com>
2026-05-22 10:50:45 +12:00
d6925bd55e Extract image dimensions into asset metadata (#13991)
Amp-Thread-ID: https://ampcode.com/threads/T-019e4ca5-b71a-7168-8f56-58b2325f34c3
Co-authored-by: Amp <amp@ampcode.com>
2026-05-22 10:50:37 +12:00
634dad6777 Include asset id in executed WebSocket messages (#13862)
Amp-Thread-ID: https://ampcode.com/threads/T-019e4ca5-b71a-7168-8f56-58b2325f34c3
Co-authored-by: Amp <amp@ampcode.com>
2026-05-22 10:50:30 +12:00
4da4b3738d Emit hash alongside asset_hash on Asset responses (#13739)
Amp-Thread-ID: https://ampcode.com/threads/T-019e4ca5-b71a-7168-8f56-58b2325f34c3
Co-authored-by: Amp <amp@ampcode.com>
2026-05-22 10:50:22 +12:00
32e58393b8 Add backport release workflow. (#14038) 2026-05-21 14:49:55 -07:00
b293f8cefd [Partner Nodes] add widget for automatic upscaling for the ByteDance2Reference node (#14032)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-21 11:58:03 -07:00
2ca1480f91 chore: update workflow templates to v0.9.82 (#14034) 2026-05-21 11:48:20 -07:00
6ecf5eca7a [Partner Nodes] add OpenRouter LLM node (#14007)
* [Partner Nodes] add reasoning widget to Anthropic node

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* [Partner Nodes] add new OpenRouterLLM node

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* [Partner Nodes] fix passing images to Grok LLM

Signed-off-by: bigcat88 <bigcat88@icloud.com>

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-21 11:36:11 -07:00
03e511862e Fix reshaping lora application (#14031)
* ModelPatcherDyanmic: purge stale vbar allocs on force cast

* ModelPatcherDynamic: restore backups before load

If doing a clean reload, mutative changes (lora application) could be
applied on-top of the already loaded weight. Restore from backup
unconditionally so that the new load is clean.
2026-05-21 09:47:16 -07:00
aab41a9ddb fix(lanczos): correct dimension transposition for single-channel tensors (#12679) 2026-05-21 23:47:20 +08:00
51 changed files with 4496 additions and 125 deletions

401
.github/workflows/backport_release.yaml vendored Normal file
View File

@ -0,0 +1,401 @@
name: Backport Release
on:
workflow_dispatch:
inputs:
branch:
description: 'Source branch containing the backported commits (PR source branch into master)'
required: true
type: string
permissions:
contents: read
pull-requests: read
checks: read
jobs:
backport-release:
name: Create backport release
runs-on: ubuntu-latest
environment: backport release
steps:
- name: Generate GitHub App token
id: app-token
uses: actions/create-github-app-token@bcd2ba49218906704ab6c1aa796996da409d3eb1
with:
app-id: ${{ secrets.FEN_RELEASE_APP_ID }}
private-key: ${{ secrets.FEN_RELEASE_PRIVATE_KEY }}
- name: Checkout repository
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
with:
token: ${{ steps.app-token.outputs.token }}
fetch-depth: 0
fetch-tags: true
- name: Configure git
run: |
git config user.name "fen-release[bot]"
git config user.email "fen-release[bot]@users.noreply.github.com"
- name: Validate source branch exists
env:
SOURCE_BRANCH: ${{ inputs.branch }}
run: |
set -euo pipefail
git fetch origin "refs/heads/${SOURCE_BRANCH}:refs/remotes/origin/${SOURCE_BRANCH}"
if ! git show-ref --verify --quiet "refs/remotes/origin/${SOURCE_BRANCH}"; then
echo "::error::Source branch '${SOURCE_BRANCH}' not found on origin."
exit 1
fi
- name: Determine latest stable release
id: latest
env:
GH_TOKEN: ${{ steps.app-token.outputs.token }}
run: |
set -euo pipefail
# List all tags matching vMAJOR.MINOR.PATCH and pick the highest by numeric
# comparison of each component. We DO NOT use `sort -V` because it treats
# v0.19.99 as higher than v0.20.1.
latest_tag="$(
git tag --list 'v[0-9]*.[0-9]*.[0-9]*' \
| grep -E '^v[0-9]+\.[0-9]+\.[0-9]+$' \
| awk -F'[v.]' '{ printf "%010d %010d %010d %s\n", $2, $3, $4, $0 }' \
| sort -k1,1n -k2,2n -k3,3n \
| tail -n1 \
| awk '{print $4}'
)"
if [[ -z "${latest_tag}" ]]; then
echo "::error::No stable release tags (vMAJOR.MINOR.PATCH) were found."
exit 1
fi
# Parse components
ver="${latest_tag#v}"
major="${ver%%.*}"
rest="${ver#*.}"
minor="${rest%%.*}"
patch="${rest#*.}"
new_patch=$((patch + 1))
new_version="v${major}.${minor}.${new_patch}"
release_branch="release/v${major}.${minor}"
latest_sha="$(git rev-list -n 1 "refs/tags/${latest_tag}")"
echo "latest_tag=${latest_tag}" >> "$GITHUB_OUTPUT"
echo "latest_sha=${latest_sha}" >> "$GITHUB_OUTPUT"
echo "major=${major}" >> "$GITHUB_OUTPUT"
echo "minor=${minor}" >> "$GITHUB_OUTPUT"
echo "patch=${patch}" >> "$GITHUB_OUTPUT"
echo "new_version=${new_version}" >> "$GITHUB_OUTPUT"
echo "new_version_no_v=${major}.${minor}.${new_patch}" >> "$GITHUB_OUTPUT"
echo "release_branch=${release_branch}" >> "$GITHUB_OUTPUT"
echo "Latest stable release: ${latest_tag} (${latest_sha})"
echo "New version will be: ${new_version}"
echo "Release branch: ${release_branch}"
- name: Validate source branch is cut directly from the latest stable release
env:
SOURCE_BRANCH: ${{ inputs.branch }}
LATEST_TAG_SHA: ${{ steps.latest.outputs.latest_sha }}
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
run: |
set -euo pipefail
source_sha="$(git rev-parse "refs/remotes/origin/${SOURCE_BRANCH}")"
# The source branch must be cut directly off the latest stable tag.
# "Cut directly off" means: walking first-parent from the source tip
# eventually reaches LATEST_TAG_SHA. This rejects branches that were
# cut from master after the tag (which would carry unrelated commits),
# while accepting a branch rooted at the tag with N backport commits
# on top (each of which may itself be a merge — first-parent walks
# through the mainline of the branch).
if ! git rev-list --first-parent "${source_sha}" \
| grep -qx "${LATEST_TAG_SHA}"; then
echo "::error::Source branch '${SOURCE_BRANCH}' is not cut from '${LATEST_TAG}'."
echo "::error::Its first-parent history does not include ${LATEST_TAG_SHA}."
exit 1
fi
# Additionally, every commit added on top of the tag (the set we are
# about to publish) must itself be a descendant of the tag along
# first-parent — i.e. no sibling commits from master sneak in via a
# non-first-parent path. Enforce by requiring that the symmetric
# difference is empty in one direction: commits in source that are
# NOT first-parent-reachable from source starting at the tag.
# We do this by intersecting:
# A = commits reachable from source but not from tag (full DAG)
# B = commits on the first-parent chain from source down to tag
# and requiring A == B.
all_added="$(git rev-list "${LATEST_TAG_SHA}..${source_sha}" | sort)"
first_parent_added="$(
git rev-list --first-parent "${LATEST_TAG_SHA}..${source_sha}" | sort
)"
if [[ "${all_added}" != "${first_parent_added}" ]]; then
echo "::error::Source branch '${SOURCE_BRANCH}' contains commits not on its first-parent chain from '${LATEST_TAG}'."
echo "::error::This usually means the branch was cut from master (not from the tag) or contains a merge from master."
echo "Commits reachable but not on first-parent chain:"
comm -23 <(printf '%s\n' "${all_added}") <(printf '%s\n' "${first_parent_added}") \
| while read -r sha; do
echo " $(git log -1 --format='%h %s' "${sha}")"
done
exit 1
fi
added_count="$(printf '%s\n' "${all_added}" | grep -c . || true)"
echo "Source branch is cut directly from ${LATEST_TAG} with ${added_count} commit(s) on top."
- name: Validate PR exists, is named correctly, and checks pass
env:
GH_TOKEN: ${{ steps.app-token.outputs.token }}
SOURCE_BRANCH: ${{ inputs.branch }}
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
REPO: ${{ github.repository }}
run: |
set -euo pipefail
expected_title="ComfyUI backport release ${NEW_VERSION}"
# Find open PRs from this branch into master
pr_json="$(
gh pr list \
--repo "${REPO}" \
--state open \
--head "${SOURCE_BRANCH}" \
--base master \
--json number,title,headRefOid \
--limit 10
)"
pr_count="$(echo "${pr_json}" | jq 'length')"
if [[ "${pr_count}" -eq 0 ]]; then
echo "::error::No open PR found from '${SOURCE_BRANCH}' into 'master'."
exit 1
fi
# Pick the PR matching the expected title
pr_number="$(echo "${pr_json}" | jq -r --arg t "${expected_title}" '
map(select(.title == $t)) | .[0].number // empty
')"
pr_head_sha="$(echo "${pr_json}" | jq -r --arg t "${expected_title}" '
map(select(.title == $t)) | .[0].headRefOid // empty
')"
if [[ -z "${pr_number}" ]]; then
echo "::error::No open PR from '${SOURCE_BRANCH}' into 'master' is titled '${expected_title}'."
echo "Found PRs:"
echo "${pr_json}" | jq -r '.[] | " #\(.number): \(.title)"'
exit 1
fi
echo "Found PR #${pr_number} titled '${expected_title}' (head ${pr_head_sha})."
# Verify all check runs on the head commit have completed successfully.
# A check is considered passing if conclusion is success, neutral, or skipped.
checks_json="$(
gh api \
--paginate \
"repos/${REPO}/commits/${pr_head_sha}/check-runs" \
--jq '.check_runs[] | {name: .name, status: .status, conclusion: .conclusion}'
)"
if [[ -z "${checks_json}" ]]; then
echo "::error::No check runs found on PR head commit ${pr_head_sha}."
exit 1
fi
echo "Check runs on ${pr_head_sha}:"
echo "${checks_json}" | jq -s '.'
failing="$(echo "${checks_json}" | jq -s '
map(select(
.status != "completed"
or (.conclusion as $c
| ["success","neutral","skipped"]
| index($c) | not)
))
')"
failing_count="$(echo "${failing}" | jq 'length')"
if [[ "${failing_count}" -gt 0 ]]; then
echo "::error::One or more checks have not passed on PR head commit ${pr_head_sha}:"
echo "${failing}" | jq -r '.[] | " - \(.name): status=\(.status) conclusion=\(.conclusion)"'
exit 1
fi
echo "All checks have passed on ${pr_head_sha}."
- name: Prepare release branch
id: prepare
env:
GH_TOKEN: ${{ steps.app-token.outputs.token }}
REPO: ${{ github.repository }}
SOURCE_BRANCH: ${{ inputs.branch }}
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
LATEST_TAG_SHA: ${{ steps.latest.outputs.latest_sha }}
PATCH: ${{ steps.latest.outputs.patch }}
run: |
set -euo pipefail
# Try to fetch the release branch. If patch == 0, it shouldn't exist yet
# and we'll create it from the latest stable tag. If patch > 0, it must
# already exist and its tip must equal the latest stable tag commit (i.e.
# the previous patch release).
if git ls-remote --exit-code --heads origin "${RELEASE_BRANCH}" >/dev/null 2>&1; then
echo "Release branch '${RELEASE_BRANCH}' already exists on origin."
git fetch origin "refs/heads/${RELEASE_BRANCH}:refs/remotes/origin/${RELEASE_BRANCH}"
git checkout -B "${RELEASE_BRANCH}" "refs/remotes/origin/${RELEASE_BRANCH}"
current_tip="$(git rev-parse HEAD)"
if [[ "${current_tip}" != "${LATEST_TAG_SHA}" ]]; then
echo "::error::Release branch '${RELEASE_BRANCH}' tip (${current_tip}) is not at the latest stable release '${LATEST_TAG}' (${LATEST_TAG_SHA})."
echo "::error::Refusing to release on top of a divergent branch."
exit 1
fi
echo "branch_existed=true" >> "$GITHUB_OUTPUT"
else
if [[ "${PATCH}" != "0" ]]; then
echo "::error::Release branch '${RELEASE_BRANCH}' does not exist on origin, but the latest stable release '${LATEST_TAG}' has patch=${PATCH} (>0). This is inconsistent."
exit 1
fi
echo "Release branch '${RELEASE_BRANCH}' does not exist. Creating from ${LATEST_TAG}."
git checkout -B "${RELEASE_BRANCH}" "refs/tags/${LATEST_TAG}"
echo "branch_existed=false" >> "$GITHUB_OUTPUT"
fi
- name: Fast-forward merge source branch into release branch
env:
SOURCE_BRANCH: ${{ inputs.branch }}
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
run: |
set -euo pipefail
# --ff-only guarantees no merge commit is created. If a fast-forward is
# not possible (i.e. the release branch has commits the source branch
# doesn't), the merge will fail and we abort. Because we already validated
# that the source branch is rooted on the latest stable tag, and the
# release branch tip equals that same tag, this fast-forward should
# always succeed for a well-formed backport branch.
if ! git merge --ff-only "refs/remotes/origin/${SOURCE_BRANCH}"; then
echo "::error::Cannot fast-forward '${RELEASE_BRANCH}' to '${SOURCE_BRANCH}'. A merge commit would be required. Aborting."
exit 1
fi
echo "Fast-forwarded '${RELEASE_BRANCH}' to tip of '${SOURCE_BRANCH}'."
- name: Bump version files
env:
NEW_VERSION_NO_V: ${{ steps.latest.outputs.new_version_no_v }}
run: |
set -euo pipefail
if [[ ! -f comfyui_version.py ]]; then
echo "::error::comfyui_version.py not found in repo root."
exit 1
fi
if [[ ! -f pyproject.toml ]]; then
echo "::error::pyproject.toml not found in repo root."
exit 1
fi
# Replace the version string in comfyui_version.py.
# Expected format: __version__ = "X.Y.Z"
python3 - "$NEW_VERSION_NO_V" <<'PY'
import re, sys, pathlib
new = sys.argv[1]
p = pathlib.Path("comfyui_version.py")
src = p.read_text()
new_src, n = re.subn(
r'(__version__\s*=\s*[\'"])[^\'"]+([\'"])',
lambda m: f'{m.group(1)}{new}{m.group(2)}',
src,
count=1,
)
if n != 1:
sys.exit("Could not find __version__ assignment in comfyui_version.py")
p.write_text(new_src)
p = pathlib.Path("pyproject.toml")
src = p.read_text()
# Replace the first `version = "..."` inside [project] or [tool.poetry].
new_src, n = re.subn(
r'(?m)^(version\s*=\s*")[^"]+(")',
lambda m: f'{m.group(1)}{new}{m.group(2)}',
src,
count=1,
)
if n != 1:
sys.exit("Could not find version assignment in pyproject.toml")
p.write_text(new_src)
PY
echo "Updated version to ${NEW_VERSION_NO_V} in comfyui_version.py and pyproject.toml."
git --no-pager diff -- comfyui_version.py pyproject.toml
- name: Commit version bump and tag release
env:
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
run: |
set -euo pipefail
git add comfyui_version.py pyproject.toml
git commit -m "ComfyUI ${NEW_VERSION}"
if git rev-parse -q --verify "refs/tags/${NEW_VERSION}" >/dev/null; then
echo "::error::Tag ${NEW_VERSION} already exists locally."
exit 1
fi
git tag "${NEW_VERSION}"
- name: Verify tag does not already exist on origin
env:
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
run: |
set -euo pipefail
if git ls-remote --exit-code --tags origin "refs/tags/${NEW_VERSION}" >/dev/null 2>&1; then
echo "::error::Tag ${NEW_VERSION} already exists on origin. Aborting."
exit 1
fi
- name: Push release branch and tag
env:
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
run: |
set -euo pipefail
# Push the branch first, then the tag. Atomic-ish: if the branch push
# fails we never publish the tag.
git push origin "refs/heads/${RELEASE_BRANCH}:refs/heads/${RELEASE_BRANCH}"
git push origin "refs/tags/${NEW_VERSION}"
echo "Released ${NEW_VERSION} on ${RELEASE_BRANCH}."
- name: Summary
if: always()
env:
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
SOURCE_BRANCH: ${{ inputs.branch }}
run: |
{
echo "## Backport release"
echo ""
echo "| Field | Value |"
echo "|---|---|"
echo "| Source branch | \`${SOURCE_BRANCH}\` |"
echo "| Previous stable | \`${LATEST_TAG}\` |"
echo "| New version | \`${NEW_VERSION}\` |"
echo "| Release branch | \`${RELEASE_BRANCH}\` |"
} >> "$GITHUB_STEP_SUMMARY"

View File

@ -39,6 +39,8 @@ from app.assets.services import (
update_asset_metadata,
upload_from_temp_path,
)
from app.assets.services.cursor import InvalidCursorError
from app.assets.services.path_utils import compute_paths_for_response
from app.assets.services.tagging import list_tag_histogram
ROUTES = web.RouteTableDef()
@ -160,10 +162,19 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
preview_url = None
else:
preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata)
asset_content_hash = result.asset.hash if result.asset else None
if result.ref.file_path:
paths = compute_paths_for_response(result.ref.file_path)
file_path, display_name = paths if paths else (None, None)
else:
file_path, display_name = None, None
return schemas_out.Asset(
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
file_path=file_path,
display_name=display_name,
hash=asset_content_hash,
asset_hash=asset_content_hash,
size=int(result.asset.size_bytes) if result.asset else None,
mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags,
@ -172,7 +183,7 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
user_metadata=result.ref.user_metadata or {},
metadata=result.ref.system_metadata,
job_id=result.ref.job_id,
prompt_id=result.ref.job_id, # deprecated: mirrors job_id for cloud compat
prompt_id=result.ref.job_id, # deprecated alias of job_id, kept for compatibility
created_at=result.ref.created_at,
updated_at=result.ref.updated_at,
last_access_time=result.ref.last_access_time,
@ -209,24 +220,38 @@ async def list_assets_route(request: web.Request) -> web.Response:
order_candidate = (q.order or "desc").lower()
order = order_candidate if order_candidate in {"asc", "desc"} else "desc"
result = list_assets_page(
owner_id=USER_MANAGER.get_request_user_id(request),
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
offset=q.offset,
sort=sort,
order=order,
)
try:
result = list_assets_page(
owner_id=USER_MANAGER.get_request_user_id(request),
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
job_ids=q.job_ids,
limit=q.limit,
offset=q.offset,
sort=sort,
order=order,
after=q.after,
)
except InvalidCursorError as e:
return _build_error_response(400, "INVALID_CURSOR", str(e))
summaries = [_build_asset_response(item) for item in result.items]
# has_more semantics differ by mode:
# - cursor mode: a non-empty next_cursor means there are more results.
# - offset mode: derived from total - (offset + page size).
if q.after is not None:
has_more = result.next_cursor is not None
else:
has_more = (q.offset + len(summaries)) < result.total
payload = schemas_out.AssetsList(
assets=summaries,
total=result.total,
has_more=(q.offset + len(summaries)) < result.total,
has_more=has_more,
next_cursor=result.next_cursor,
)
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
@ -401,12 +426,16 @@ async def upload_asset(request: web.Request) -> web.Response:
)
if spec.tags and spec.tags[0] == "models":
# tag[1] may be the standalone category ("checkpoints") or the
# slash-joined shape ("checkpoints/flux/...") that
# `get_name_and_tags_from_asset_path` and cloud both emit. Match
# `resolve_destination_from_tags` by extracting the first segment.
category = spec.tags[1].split("/", 1)[0] if len(spec.tags) >= 2 else ""
if (
len(spec.tags) < 2
or spec.tags[1] not in folder_paths.folder_names_and_paths
or category not in folder_paths.folder_names_and_paths
):
delete_temp_file_if_exists(parsed.tmp_path)
category = spec.tags[1] if len(spec.tags) >= 2 else ""
return _build_error_response(
400, "INVALID_BODY", f"unknown models category '{category}'"
)

View File

@ -1,4 +1,5 @@
import json
import uuid
from dataclasses import dataclass
from typing import Any, Literal
@ -53,12 +54,18 @@ class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
name_contains: str | None = None
job_ids: list[str] = Field(default_factory=list, max_length=500)
# Accept either a JSON string (query param) or a dict
metadata_filter: dict[str, Any] | None = None
limit: conint(ge=1, le=500) = 20
offset: conint(ge=0) = 0
# Opaque keyset cursor. When supplied, `offset` is ignored. Cursor pagination
# is supported for sort values `created_at`, `updated_at`, `name`, `size`.
# Supplying `after` together with `sort=last_access_time` returns
# 400 INVALID_CURSOR; that sort only supports offset/limit.
after: str | None = None
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = (
"created_at"
@ -81,6 +88,40 @@ class ListAssetsQuery(BaseModel):
return out
return v
@field_validator("job_ids", mode="before")
@classmethod
def _split_and_validate_job_ids(cls, v):
# Accept "uuid1,uuid2" or ["uuid1","uuid2"] or repeated query params.
# Each entry must parse as a UUID; canonicalized to lowercase hyphenated form.
if v is None:
return []
if isinstance(v, str):
raw = [t.strip() for t in v.split(",") if t.strip()]
elif isinstance(v, list):
raw = []
for item in v:
if not isinstance(item, str):
raise ValueError(
f"job_ids entries must be strings, got {type(item).__name__}"
)
raw.extend([t.strip() for t in item.split(",") if t.strip()])
else:
raise ValueError(
f"job_ids must be a string or list of strings, got {type(v).__name__}"
)
out: list[str] = []
seen: set[str] = set()
for s in raw:
try:
canonical = str(uuid.UUID(s))
except ValueError as e:
raise ValueError(f"job_ids must be UUIDs: {s!r}") from e
if canonical not in seen:
seen.add(canonical)
out.append(canonical)
return out
@field_validator("metadata_filter", mode="before")
@classmethod
def _parse_metadata_json(cls, v):

View File

@ -9,7 +9,10 @@ class Asset(BaseModel):
``id`` here is the AssetReference id, not the content-addressed Asset id."""
id: str
name: str
name: str = Field(..., deprecated=True)
file_path: str | None = None
display_name: str | None = None
hash: str | None = None
asset_hash: str | None = None
size: int | None = None
mime_type: str | None = None
@ -40,6 +43,8 @@ class AssetsList(BaseModel):
assets: list[Asset]
total: int
has_more: bool
# Opaque cursor for the next page. Omitted when there are no more results.
next_cursor: str | None = None
class TagUsage(BaseModel):

View File

@ -264,11 +264,21 @@ def list_references_page(
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
metadata_filter: dict | None = None,
job_ids: Sequence[str] | None = None,
sort: str | None = None,
order: str | None = None,
after_cursor_value: object | None = None,
after_cursor_id: str | None = None,
) -> tuple[list[AssetReference], dict[str, list[str]], int]:
"""List references with pagination, filtering, and sorting.
When ``after_cursor_value``/``after_cursor_id`` are supplied the query uses
keyset pagination — ``offset`` is ignored and a WHERE clause selects rows
strictly after the given ``(sort_col, id)`` position in the active sort
direction. The cursor value must already be typed for the column
(datetime for time sorts, int for size, str for name); the caller decodes
the opaque cursor string and resolves to the typed value.
Returns (references, tag_map, total_count).
"""
base = (
@ -284,6 +294,9 @@ def list_references_page(
escaped, esc = escape_sql_like_string(name_contains)
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
if job_ids:
base = base.where(AssetReference.job_id.in_(list(job_ids)))
base = apply_tag_filters(base, include_tags, exclude_tags)
base = apply_metadata_filter(base, metadata_filter)
@ -297,9 +310,31 @@ def list_references_page(
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetReference.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
descending = order == "desc"
base = base.order_by(sort_exp).limit(limit).offset(offset)
# Keyset WHERE: (sort_col, id) strictly less-than / greater-than the cursor.
# Equivalent to: sort_col <op> v OR (sort_col = v AND id <op> cursor_id).
if after_cursor_value is not None and after_cursor_id is not None:
if descending:
keyset = sa.or_(
sort_col < after_cursor_value,
sa.and_(sort_col == after_cursor_value, AssetReference.id < after_cursor_id),
)
else:
keyset = sa.or_(
sort_col > after_cursor_value,
sa.and_(sort_col == after_cursor_value, AssetReference.id > after_cursor_id),
)
base = base.where(keyset)
# Secondary ORDER BY id (matching the primary direction) gives the keyset
# comparison a deterministic tiebreaker on duplicate sort_col values.
id_exp = AssetReference.id.desc() if descending else AssetReference.id.asc()
sort_exp = sort_col.desc() if descending else sort_col.asc()
base = base.order_by(sort_exp, id_exp).limit(limit)
if after_cursor_id is None:
base = base.offset(offset)
count_stmt = (
select(sa.func.count())
@ -314,6 +349,8 @@ def list_references_page(
count_stmt = count_stmt.where(
AssetReference.name.ilike(f"%{escaped}%", escape=esc)
)
if job_ids:
count_stmt = count_stmt.where(AssetReference.job_id.in_(list(job_ids)))
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
@ -327,7 +364,12 @@ def list_references_page(
select(AssetReferenceTag.asset_reference_id, Tag.name)
.join(Tag, Tag.name == AssetReferenceTag.tag_name)
.where(AssetReferenceTag.asset_reference_id.in_(id_list))
.order_by(AssetReferenceTag.tag_name.asc())
# Preserve insertion order so the structural first tag (the root
# category like "models") stays in position 0 and the path-derived
# sub-path tag stays in position 1, matching cloud's behavior.
# tag_name is a deterministic tiebreaker when multiple tags share
# an added_at (same-batch insert via set_reference_tags).
.order_by(AssetReferenceTag.added_at.asc(), AssetReferenceTag.tag_name.asc())
)
for ref_id, tag_name in rows.all():
tag_map[ref_id].append(tag_name)
@ -355,7 +397,8 @@ def fetch_reference_asset_and_tags(
build_visible_owner_clause(owner_id),
)
.options(noload(AssetReference.tags))
.order_by(Tag.name.asc())
# See list_references_page for the rationale behind ordering by added_at.
.order_by(AssetReferenceTag.added_at.asc(), Tag.name.asc())
)
rows = session.execute(stmt).all()

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Iterable, Sequence
import sqlalchemy as sa
@ -20,7 +21,12 @@ from app.assets.database.queries.common import (
build_visible_owner_clause,
iter_row_chunks,
)
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
from app.assets.helpers import (
escape_sql_like_string,
expand_bucket_prefixes,
get_utc_now,
normalize_tags,
)
@dataclass(frozen=True)
@ -44,6 +50,26 @@ class SetTagsResult:
total: list[str]
def _next_added_at_base(session: Session, reference_id: str) -> datetime:
"""Return a timestamp strictly greater than any existing
`added_at` for this reference. On platforms where the wall clock
has insufficient resolution between back-to-back commits (notably
Windows), two write batches on the same reference can otherwise
share a microsecond — the `ORDER BY added_at, tag_name` retrieval
then falls back to the alphabetic tiebreaker and user-tier tags
sort ahead of path-tier tags they were meant to follow.
"""
existing_max = session.execute(
sa.select(sa.func.max(AssetReferenceTag.added_at)).where(
AssetReferenceTag.asset_reference_id == reference_id
)
).scalar()
now = get_utc_now()
if existing_max is None:
return now
return max(existing_max + timedelta(microseconds=1), now)
def validate_tags_exist(session: Session, tags: list[str]) -> None:
"""Raise ValueError if any of the given tag names do not exist."""
existing_tag_names = set(
@ -77,7 +103,13 @@ def get_reference_tags(session: Session, reference_id: str) -> list[str]:
session.execute(
select(AssetReferenceTag.tag_name)
.where(AssetReferenceTag.asset_reference_id == reference_id)
.order_by(AssetReferenceTag.tag_name.asc())
# Match the response-path ordering used by
# list_references_page / fetch_reference_asset_and_tags so
# upload responses and subsequent GETs agree on tag order.
.order_by(
AssetReferenceTag.added_at.asc(),
AssetReferenceTag.tag_name.asc(),
)
)
).all()
]
@ -89,7 +121,7 @@ def set_reference_tags(
tags: Sequence[str],
origin: str = "manual",
) -> SetTagsResult:
desired = normalize_tags(tags)
desired = expand_bucket_prefixes(normalize_tags(tags))
current = set(get_reference_tags(session, reference_id))
@ -98,15 +130,22 @@ def set_reference_tags(
if to_add:
ensure_tags_exist(session, to_add, tag_type="user")
# Stagger added_at by microsecond per tag so the retrieval ORDER BY
# added_at preserves input order. Per-tag get_utc_now() calls can
# collide at microsecond resolution on fast machines, dropping the
# query to the tag_name alphabetical tiebreaker — same fix as in
# batch_insert_seed_assets. Read max(existing) so this batch sorts
# strictly after any prior batch on the same reference.
base_ts = _next_added_at_base(session, reference_id)
session.add_all(
[
AssetReferenceTag(
asset_reference_id=reference_id,
tag_name=t,
origin=origin,
added_at=get_utc_now(),
added_at=base_ts + timedelta(microseconds=i),
)
for t in to_add
for i, t in enumerate(to_add)
]
)
session.flush()
@ -136,7 +175,7 @@ def add_tags_to_reference(
if not ref:
raise ValueError(f"AssetReference {reference_id} not found")
norm = normalize_tags(tags)
norm = expand_bucket_prefixes(normalize_tags(tags))
if not norm:
total = get_reference_tags(session, reference_id=reference_id)
return AddTagsResult(added=[], already_present=[], total_tags=total)
@ -146,10 +185,17 @@ def add_tags_to_reference(
current = set(get_reference_tags(session, reference_id))
# Preserve the caller's insertion order rather than alphabetizing —
# the retrieval ORDER BY added_at + microsecond stagger only meaningfully
# preserves insertion order if "the order we insert in" actually matches
# the caller's intent.
want = set(norm)
to_add = sorted(want - current)
to_add = [t for t in norm if t not in current]
if to_add:
# See set_reference_tags for the rationale behind the per-tag stagger
# and the max(existing) seed.
base_ts = _next_added_at_base(session, reference_id)
with session.begin_nested() as nested:
try:
session.add_all(
@ -158,9 +204,9 @@ def add_tags_to_reference(
asset_reference_id=reference_id,
tag_name=t,
origin=origin,
added_at=get_utc_now(),
added_at=base_ts + timedelta(microseconds=i),
)
for t in to_add
for i, t in enumerate(to_add)
]
)
session.flush()

View File

@ -47,6 +47,50 @@ def normalize_tags(tags: list[str] | None) -> list[str]:
return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip()))
def _known_bucket_prefixes() -> set[str]:
"""Lowercased model-category names eligible for standalone-prefix
expansion. Tags whose first slash segment matches one of these get
the bucket inserted as a separate token, so FE filters like
``include_tags=models,checkpoints`` keep matching even when the
asset lives in a nested subfolder (`models/checkpoints/flux/foo`).
Bare user labels with slashes whose first segment is not a registered
bucket (e.g. ``my-org/team-a``) pass through unchanged.
"""
try:
import folder_paths
return {
name.lower()
for name in folder_paths.folder_names_and_paths.keys()
if name != "custom_nodes"
}
except Exception:
return set()
def expand_bucket_prefixes(tags: list[str]) -> list[str]:
"""Insert standalone bucket tokens after any slash-joined tag whose
first segment is a registered model category. Preserves caller order
and is idempotent (existing bucket tokens are not duplicated).
"""
if not tags:
return list(tags)
buckets = _known_bucket_prefixes()
if not buckets:
return list(tags)
seen = set(tags)
result: list[str] = []
for t in tags:
result.append(t)
if "/" in t:
prefix = t.split("/", 1)[0]
if prefix.lower() in buckets and prefix not in seen:
result.append(prefix)
seen.add(prefix)
return result
def validate_blake3_hash(s: str) -> str:
"""Validate and normalize a blake3 hash string.

View File

@ -33,6 +33,7 @@ from app.assets.services.file_utils import (
verify_file_unchanged,
)
from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash
from app.assets.services.image_dimensions import extract_image_dimensions
from app.assets.services.metadata_extract import extract_file_metadata
from app.assets.services.path_utils import (
compute_relative_filename,
@ -506,6 +507,10 @@ def enrich_asset(
if extract_metadata and metadata:
system_metadata = metadata.to_user_metadata()
if mime_type and mime_type.startswith("image/"):
dims = extract_image_dimensions(file_path, mime_type=mime_type)
if dims:
system_metadata.update(dims)
set_reference_system_metadata(session, reference_id, system_metadata)
if full_hash:

View File

@ -1,8 +1,19 @@
import contextlib
import mimetypes
import os
from datetime import timezone
from typing import Sequence
from app.assets.services.cursor import (
CursorPayload,
InvalidCursorError,
decode_cursor,
decode_cursor_int,
decode_cursor_time,
encode_cursor,
encode_cursor_from_time,
)
from app.assets.database.models import Asset
from app.assets.database.queries import (
@ -242,17 +253,55 @@ def get_asset_by_hash(asset_hash: str) -> AssetData | None:
return extract_asset_data(asset)
# Sort fields that support cursor pagination. `last_access_time` is not
# in this list — it falls back to offset/limit.
_CURSOR_SORT_FIELDS = ("created_at", "updated_at", "name", "size")
def list_assets_page(
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
job_ids: Sequence[str] | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
after: str | None = None,
) -> ListAssetsResult:
"""List assets with optional cursor pagination.
When ``after`` is supplied it overrides ``offset``. The cursor's sort field
must match ``sort`` and be in the cursor-supported allowlist; mismatches
raise InvalidCursorError so the handler can map to 400 INVALID_CURSOR.
"""
cursor_value: object | None = None
cursor_id: str | None = None
# Mint next_cursor on every page where the sort is cursor-supported, not
# only when the request itself arrived with a cursor. Otherwise a first
# request (no `after`) returns next_cursor=None and the client can never
# enter cursor mode.
mint_cursor = sort in _CURSOR_SORT_FIELDS
if after is not None:
if sort not in _CURSOR_SORT_FIELDS:
raise InvalidCursorError(
f"cursor pagination is not supported for sort={sort!r}"
)
payload = decode_cursor(after, _CURSOR_SORT_FIELDS, expected_order=order)
if payload.sort_field != sort:
raise InvalidCursorError(
f"cursor sort field {payload.sort_field!r} does not match request sort {sort!r}"
)
cursor_value, cursor_id = _resolve_cursor_value(payload), payload.id
# Over-fetch by one row so we can distinguish "exactly `limit` rows total
# remaining" from "more rows past this page" without a second query. Drop
# the sentinel before returning.
fetch_limit = limit + 1 if mint_cursor else limit
with create_session() as session:
refs, tag_map, total = list_references_page(
session,
@ -261,12 +310,23 @@ def list_assets_page(
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
job_ids=job_ids,
limit=fetch_limit,
offset=offset,
sort=sort,
order=order,
after_cursor_value=cursor_value,
after_cursor_id=cursor_id,
)
next_cursor: str | None = None
if mint_cursor and len(refs) > limit:
# There's at least one more row past this page — mint a cursor from
# the last row of the page (i.e. index `limit - 1`, since we
# over-fetched), and drop the sentinel.
next_cursor = _encode_next_cursor(refs[limit - 1], sort, order)
refs = refs[:limit]
items: list[AssetSummaryData] = []
for ref in refs:
items.append(
@ -277,7 +337,39 @@ def list_assets_page(
)
)
return ListAssetsResult(items=items, total=total)
return ListAssetsResult(items=items, total=total, next_cursor=next_cursor)
def _resolve_cursor_value(payload: CursorPayload) -> object:
"""Map a decoded cursor payload to a column-typed Python value."""
if payload.sort_field in ("created_at", "updated_at"):
# DB stores naive UTC; strip tzinfo so the comparison binds against a
# `TIMESTAMP WITHOUT TIME ZONE` column without an offset shift.
return decode_cursor_time(payload).replace(tzinfo=None)
if payload.sort_field == "size":
return decode_cursor_int(payload)
return payload.value # name, str-typed
def _encode_next_cursor(ref, sort: str, order: str) -> str | None:
"""Mint a cursor pointing at *ref* for the given sort dimension.
Returns None when the boundary row carries a NULL sort value (e.g. an asset
record whose size_bytes hasn't been backfilled). Continuing pagination
across a NULL boundary is undefined under keyset ordering — better to
truncate cleanly here than to mint a cursor that mis-positions.
"""
if sort == "name":
return encode_cursor("name", ref.name, ref.id, order=order)
if sort == "size":
if ref.asset is None or ref.asset.size_bytes is None:
return None
return encode_cursor("size", str(ref.asset.size_bytes), ref.id, order=order)
# created_at / updated_at — DB datetimes are naive UTC; attach tz before encoding.
value = ref.created_at if sort == "created_at" else ref.updated_at
if value is None:
return None
return encode_cursor_from_time(sort, value.replace(tzinfo=timezone.utc), ref.id, order=order)
def resolve_hash_to_path(

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import os
import uuid
from dataclasses import dataclass
from datetime import datetime
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, TypedDict
from sqlalchemy.orm import Session
@ -13,13 +13,14 @@ from app.assets.database.queries import (
bulk_insert_references_ignore_conflicts,
bulk_insert_tags_and_meta,
delete_assets_by_ids,
ensure_tags_exist,
get_existing_asset_ids,
get_reference_ids_by_ids,
get_references_by_paths_and_asset_ids,
get_unreferenced_unhashed_asset_ids,
restore_references_by_paths,
)
from app.assets.helpers import get_utc_now
from app.assets.helpers import expand_bucket_prefixes, get_utc_now
if TYPE_CHECKING:
from app.assets.services.metadata_extract import ExtractedMetadata
@ -233,13 +234,20 @@ def batch_insert_seed_assets(
if ref_id not in inserted_ref_ids:
continue
for tag in ref_data["tags"]:
# Stagger added_at by microsecond per tag within a reference so
# the retrieval ORDER BY added_at preserves the input list order
# (the path-derived root category stays at position 0). Without
# this, every tag in a bulk-insert batch shares current_time and
# the tag_name tiebreaker sorts them alphabetically — putting the
# subpath tag ahead of "models" since "c"/"d"/"l" < "m".
ref_tags = expand_bucket_prefixes(ref_data["tags"])
for tag_idx, tag in enumerate(ref_tags):
tag_rows.append(
{
"asset_reference_id": ref_id,
"tag_name": tag,
"origin": "automatic",
"added_at": current_time,
"added_at": current_time + timedelta(microseconds=tag_idx),
}
)
@ -261,6 +269,16 @@ def batch_insert_seed_assets(
}
)
if tag_rows:
# Bucket-prefix expansion may have introduced tags the caller did
# not register via the upstream tag_pool (e.g. `checkpoints` for a
# nested `checkpoints/flux/foo` path). Pre-register the full set so
# the AssetReferenceTag.tag_name FK is satisfied; the underlying
# insert is ON CONFLICT DO NOTHING so re-registration is idempotent.
ensure_tags_exist(
session, {row["tag_name"] for row in tag_rows}, tag_type="user"
)
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=metadata_rows)
return BulkInsertResult(

View File

@ -0,0 +1,225 @@
"""Opaque keyset-pagination cursor for /api/assets.
Payload JSON uses short keys to keep the encoded length small:
{"s": <sort_field>, "v": <value>, "id": <id>, "o": <order>}
The `o` key binds the cursor to the sort direction it was minted under,
so replaying a `desc` cursor against an `asc` request fails with
``INVALID_CURSOR`` rather than silently walking the wrong direction.
`o` is mandatory on every payload — a cursor without it is rejected as
malformed.
Encoding is base64url with no padding. JSON serialization escapes `<`,
`>`, `&`, U+2028, and U+2029 in encoded string values so asset names
containing those characters produce a stable, byte-identical wire form
across any compatible implementation of the same payload format.
Time values are serialized as Unix microseconds (UTC) — microsecond
precision is sufficient to round-trip the timestamps stored by the
database without rounding rows in the same millisecond bucket.
"""
from __future__ import annotations
import base64
import json
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Iterable, Optional
class InvalidCursorError(ValueError):
"""Raised on a malformed, oversized, or unsupported-sort-field cursor.
Map to a 400 response with code ``INVALID_CURSOR`` at the handler.
"""
# Wire-format length caps. Cursors are user-controlled, so caps protect the
# decode path from oversized allocations and downstream SQL predicates from
# unbounded strings.
#
# MAX_CURSOR_VALUE_LENGTH is 512 to fit the `AssetReference.name` column max
# (`String(512)`) — otherwise a long-named asset would mint a cursor the same
# server then refuses on the next request.
MAX_ENCODED_CURSOR_LENGTH = 1024
MAX_CURSOR_VALUE_LENGTH = 512
MAX_CURSOR_ID_LENGTH = 128
@dataclass(frozen=True)
class CursorPayload:
sort_field: str
value: str
id: str
order: str
_VALID_ORDERS = ("asc", "desc")
def encode_cursor(sort_field: str, value: str, id: str, order: str = "desc") -> str:
"""Encode a cursor payload as a base64url (no-padding) string.
`order` binds the cursor to the sort direction it was minted under so a
later request with a flipped `order` query parameter is rejected with
``INVALID_CURSOR`` rather than silently walking the wrong direction.
"""
if order not in _VALID_ORDERS:
raise InvalidCursorError(f"order must be one of {_VALID_ORDERS}, got {order!r}")
# Symmetric input validation: the encoder must reject anything the
# decoder rejects, or the same server will mint cursors it then 400s on
# the next request.
if not id:
raise InvalidCursorError("id must be non-empty")
if len(id) > MAX_CURSOR_ID_LENGTH:
raise InvalidCursorError("id exceeds maximum length")
if len(value) > MAX_CURSOR_VALUE_LENGTH:
raise InvalidCursorError("value exceeds maximum length")
payload = {"s": sort_field, "v": value, "id": id, "o": order}
raw = json.dumps(payload, separators=(",", ":"), ensure_ascii=False)
# Match the default JSON escaping of HTML-significant characters and JS
# line/paragraph separators (U+2028 / U+2029) so an asset name carrying
# any of them encodes to identical bytes across runtimes. None of these
# characters appear in JSON structural syntax, so a global replace on the
# serialized output can only touch encoded values. Use explicit \uXXXX
# escapes for U+2028 / U+2029 so the source survives any editor / git
# tooling that normalizes invisible separators.
raw = (
raw.replace("<", "\\u003c")
.replace(">", "\\u003e")
.replace("&", "\\u0026")
.replace("\u2028", "\\u2028")
.replace("\u2029", "\\u2029")
)
encoded = base64.urlsafe_b64encode(raw.encode("utf-8")).rstrip(b"=").decode("ascii")
# Final wire-size guard: the per-field caps above are char-counted, but the
# wire cap applies to the base64url of the UTF-8-encoded, escape-expanded
# payload. A value full of multibyte or HTML-significant characters (e.g.
# 512 \u00d7 "\u00e9" or 512 \u00d7 "<") inflates well past MAX_ENCODED_CURSOR_LENGTH even
# though it passes the char-count check. Refuse to mint a cursor the decoder
# on the next request would reject.
if len(encoded) > MAX_ENCODED_CURSOR_LENGTH:
raise InvalidCursorError("encoded cursor exceeds maximum length")
return encoded
def encode_cursor_from_time(sort_field: str, t: datetime, id: str, order: str = "desc") -> str:
"""Encode a time-typed cursor at Unix microsecond precision.
Accepts an aware datetime (any timezone) and normalizes to UTC. Naive
datetimes are rejected so callers can't accidentally encode the local
wall-clock value of a UTC-stored timestamp.
"""
if t.tzinfo is None:
raise ValueError("encode_cursor_from_time requires an aware datetime")
micros = _datetime_to_unix_micros(t.astimezone(timezone.utc))
return encode_cursor(sort_field, str(micros), id, order=order)
def decode_cursor(
cursor: str,
allowed_sort_fields: Iterable[str],
expected_order: str | None = None,
) -> CursorPayload:
"""Parse an opaque cursor.
``allowed_sort_fields`` is the endpoint's accepted sort-field list — a
cursor carrying a field outside this set is rejected so a cursor minted
for one column can't be replayed against another (e.g. a ``created_at``
timestamp string compared against a ``name`` column).
``expected_order`` (``"asc"``/``"desc"``), when supplied, must match the
payload's ``o`` field. ``o`` is required on every payload; a cursor
missing it is rejected as malformed.
Passing no allowed fields rejects every cursor.
"""
if len(cursor) > MAX_ENCODED_CURSOR_LENGTH:
raise InvalidCursorError("cursor exceeds maximum length")
try:
# urlsafe_b64decode requires correct padding; we strip on encode, so
# restore the trailing '=' pad here.
padding = "=" * (-len(cursor) % 4)
raw = base64.urlsafe_b64decode(cursor + padding)
except (ValueError, base64.binascii.Error) as e:
raise InvalidCursorError(f"encoding: {e}") from e
try:
decoded = json.loads(raw)
except (json.JSONDecodeError, UnicodeDecodeError) as e:
raise InvalidCursorError(f"payload: {e}") from e
if not isinstance(decoded, dict):
raise InvalidCursorError("payload: expected object")
sort_field = decoded.get("s")
value = decoded.get("v")
id = decoded.get("id")
order = decoded.get("o")
if not isinstance(sort_field, str) or not isinstance(value, str) or not isinstance(id, str):
raise InvalidCursorError("payload: missing or non-string s/v/id")
if id == "":
raise InvalidCursorError("missing id")
if len(id) > MAX_CURSOR_ID_LENGTH:
raise InvalidCursorError("id exceeds maximum length")
if len(value) > MAX_CURSOR_VALUE_LENGTH:
raise InvalidCursorError("value exceeds maximum length")
if sort_field not in allowed_sort_fields:
raise InvalidCursorError(f"unsupported sort field {sort_field!r}")
if not isinstance(order, str):
raise InvalidCursorError("missing or non-string o")
if order not in _VALID_ORDERS:
raise InvalidCursorError(f"unsupported order {order!r}")
if expected_order is not None and order != expected_order:
raise InvalidCursorError(
f"cursor order {order!r} does not match request order {expected_order!r}"
)
return CursorPayload(sort_field=sort_field, value=value, id=id, order=order)
def decode_cursor_time(payload: Optional[CursorPayload]) -> datetime:
"""Parse a time-typed cursor value as Unix microseconds, returning UTC."""
if payload is None:
raise InvalidCursorError("nil cursor payload")
try:
micros = int(payload.value)
except ValueError as e:
raise InvalidCursorError(f"value is not a valid timestamp: {e}") from e
try:
return _unix_micros_to_datetime(micros)
except (OverflowError, OSError, ValueError) as e:
# Crafted out-of-range microseconds (e.g. > datetime.MAX_YEAR) blow up
# in fromtimestamp / datetime construction. Map to 400, not 500.
raise InvalidCursorError(f"value is out of representable range: {e}") from e
def decode_cursor_int(payload: Optional[CursorPayload]) -> int:
"""Parse a cursor value as a base-10 integer."""
if payload is None:
raise InvalidCursorError("nil cursor payload")
try:
return int(payload.value)
except ValueError as e:
raise InvalidCursorError(f"value is not a valid integer: {e}") from e
_EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)
def _datetime_to_unix_micros(t: datetime) -> int:
"""Convert an aware UTC datetime to Unix microseconds (integer math)."""
delta = t - _EPOCH
return (delta.days * 86_400 + delta.seconds) * 1_000_000 + delta.microseconds
def _unix_micros_to_datetime(micros: int) -> datetime:
"""Convert Unix microseconds to a UTC datetime, preserving precision."""
seconds, micro_remainder = divmod(micros, 1_000_000)
return datetime.fromtimestamp(seconds, tz=timezone.utc).replace(microsecond=micro_remainder)

View File

@ -0,0 +1,63 @@
"""Image dimension extraction for asset ingest.
Reads only the image header via Pillow to capture width/height cheaply,
without a full pixel decode. Returns a metadata dict suitable for merging
into ``AssetReference.system_metadata``.
"""
from __future__ import annotations
import logging
from typing import Any
logger = logging.getLogger(__name__)
def extract_image_dimensions(
file_path: str, mime_type: str | None = None
) -> dict[str, Any] | None:
"""Extract image dimensions for the file at ``file_path``.
Args:
file_path: Absolute path to a file on disk.
mime_type: Optional MIME type hint. When provided and not prefixed
with ``image/``, extraction is skipped without touching the file.
Returns:
``{"kind": "image", "width": W, "height": H}`` when the file is a
recognizable image with positive dimensions, otherwise ``None``.
The dict shape is intended to be merged into ``system_metadata`` so the
asset response surfaces ``metadata.kind`` plus dimension fields for image
assets. Forward-compatible: future media kinds (e.g. ``"video"`` with
duration/fps) can extend this shape without schema changes.
"""
if mime_type is not None and not mime_type.startswith("image/"):
return None
try:
from PIL import Image, UnidentifiedImageError
except ImportError:
logger.debug(
"Pillow not available; skipping image dimension extraction for %s",
file_path,
)
return None
try:
with Image.open(file_path) as img:
width, height = img.size
except (OSError, UnidentifiedImageError, ValueError) as exc:
logger.debug(
"Failed to read image dimensions from %s: %s", file_path, exc
)
return None
if (
not isinstance(width, int)
or not isinstance(height, int)
or width <= 0
or height <= 0
):
return None
return {"kind": "image", "width": width, "height": height}

View File

@ -17,9 +17,11 @@ from app.assets.database.queries import (
get_reference_by_file_path,
get_reference_tags,
get_or_create_reference,
list_references_by_asset_id,
reference_exists,
remove_missing_tag_for_asset_id,
set_reference_metadata,
set_reference_system_metadata,
set_reference_tags,
update_asset_hash_and_mime,
upsert_asset,
@ -29,6 +31,7 @@ from app.assets.database.queries import (
from app.assets.helpers import get_utc_now, normalize_tags
from app.assets.services.bulk_ingest import batch_insert_seed_assets
from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.image_dimensions import extract_image_dimensions
from app.assets.services.path_utils import (
compute_relative_filename,
get_name_and_tags_from_asset_path,
@ -118,6 +121,14 @@ def _ingest_file_from_path(
user_metadata=user_metadata,
)
_maybe_store_image_dimensions(
session,
reference_id=reference_id,
file_path=locator,
mime_type=mime_type,
current_system_metadata=ref.system_metadata,
)
try:
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
@ -288,6 +299,13 @@ def _register_existing_asset(
user_metadata=new_meta,
)
_backfill_image_dimensions_from_siblings(
session,
asset_id=asset.id,
new_reference_id=ref.id,
current_system_metadata=ref.system_metadata,
)
if tags is not None:
set_reference_tags(
session,
@ -334,6 +352,87 @@ def _update_metadata_with_filename(
)
_IMAGE_DIMENSION_KEYS = ("kind", "width", "height")
def _maybe_store_image_dimensions(
session: Session,
reference_id: str,
file_path: str,
mime_type: str | None,
current_system_metadata: dict | None,
) -> None:
"""Populate ``kind``/``width``/``height`` on system_metadata for image refs.
Non-image MIME types are a no-op. Pre-existing keys (e.g. enricher-written
safetensors metadata, download provenance) are preserved by merge.
"""
if not mime_type or not mime_type.startswith("image/"):
return
dims = extract_image_dimensions(file_path, mime_type=mime_type)
if not dims:
return
current = current_system_metadata or {}
merged = dict(current)
merged.update(dims)
if merged != current:
set_reference_system_metadata(
session,
reference_id=reference_id,
system_metadata=merged,
)
def _backfill_image_dimensions_from_siblings(
session: Session,
asset_id: str,
new_reference_id: str,
current_system_metadata: dict | None,
) -> None:
"""Copy image dimension keys from any sibling reference of the same asset.
The from-hash path doesn't read the file bytes, so dimensions can't be
extracted there directly. When another reference of the same asset already
carries image dimensions, copy them onto the new reference so consumers
see consistent metadata regardless of how the asset was registered.
Best-effort: missing siblings, non-image siblings, or absent dimension
keys leave the target reference unchanged.
"""
current = current_system_metadata or {}
if current.get("kind") == "image" and "width" in current and "height" in current:
return
for sibling in list_references_by_asset_id(session, asset_id):
if sibling.id == new_reference_id:
continue
meta = sibling.system_metadata or {}
if meta.get("kind") != "image":
continue
width = meta.get("width")
height = meta.get("height")
if (
type(width) is not int
or type(height) is not int
or width <= 0
or height <= 0
):
continue
merged = dict(current)
merged["kind"] = "image"
merged["width"] = width
merged["height"] = height
if merged != current:
set_reference_system_metadata(
session,
reference_id=new_reference_id,
system_metadata=merged,
)
return
def _sanitize_filename(name: str | None, fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
return n if n else fallback

View File

@ -3,11 +3,12 @@ from pathlib import Path
from typing import Literal
import folder_paths
from app.assets.helpers import normalize_tags
_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"})
RootCategory = Literal["input", "output", "temp", "models"]
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build list of (folder_name, base_paths[]) for all model locations.
@ -27,27 +28,51 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
"""Validates and maps tags -> (base_dir, subdirs_for_fs).
Accepts both the legacy one-tag-per-directory shape
(``["models", "diffusers", "Kolors", "text_encoder"]``) and the
slash-joined shape emitted by :func:`get_name_and_tags_from_asset_path`
(``["models", "diffusers/Kolors/text_encoder"]``). Hybrid shapes that
mix the two within a single call (e.g.
``["models", "diffusers", "Kolors/text_encoder"]``) are also
accepted: each entry after ``tags[0]`` is split on ``/`` and
concatenated, so the two shapes — and any mix of them — resolve to
the same destination. The same safety checks are applied to each
component after expansion.
"""
if not tags:
raise ValueError("tags must not be empty")
root = tags[0].lower()
# Expand any slash-joined entries into individual path components so
# the rest of the function can treat both tag shapes uniformly. Each
# component is also stripped, so " a / b " behaves like ["a", "b"].
expanded: list[str] = []
for t in tags[1:]:
for part in str(t).split("/"):
part = part.strip()
if part:
expanded.append(part)
if root == "models":
if len(tags) < 2:
if not expanded:
raise ValueError("at least two tags required for model asset")
category = expanded[0]
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
bases = folder_paths.folder_names_and_paths[category][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
raise ValueError(f"unknown model category '{category}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
raise ValueError(f"no base path configured for category '{category}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
raw_subdirs = expanded[1:]
elif root == "input":
base_dir = os.path.abspath(folder_paths.get_input_directory())
raw_subdirs = tags[1:]
raw_subdirs = expanded
elif root == "output":
base_dir = os.path.abspath(folder_paths.get_output_directory())
raw_subdirs = tags[1:]
raw_subdirs = expanded
else:
raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'")
_sep_chars = frozenset(("/", "\\", os.sep))
@ -65,35 +90,109 @@ def validate_path_within_base(candidate: str, base: str) -> None:
raise ValueError("destination escapes base directory")
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
def compute_paths_for_response(
file_path: str,
) -> tuple[str, str | None] | None:
"""Compute (file_path, display_name) for an Asset response.
For non-model paths, returns None.
`file_path` is a logical locator under the asset namespace: `<root>/<rel>`
for input/output/temp assets and `<root>/<bucket>/<rel>` for model assets.
`display_name` is the path below that root or model bucket, suitable for UI
labels. Returns None when the absolute path is not under a known asset root.
"""
try:
root_category, rel_path = get_asset_category_and_relative_path(file_path)
root, bucket, rel = get_asset_root_bucket_and_filepath(file_path)
except ValueError:
return None
p = Path(rel_path)
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
if not parts:
return None
display_name = rel or None
if bucket is None:
response_file_path = f"{root}/{rel}" if rel else root
else:
response_file_path = f"{root}/{bucket}/{rel}" if rel else f"{root}/{bucket}"
return response_file_path, display_name
if root_category == "models":
# parts[0] is the category ("checkpoints", "vae", etc) drop it
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def compute_display_name(file_path: str) -> str | None:
"""Return the asset's `display_name`, or None for unknown paths."""
result = compute_paths_for_response(file_path)
return result[1] if result else None
def compute_file_path(file_path: str) -> str | None:
"""Return the asset's logical `file_path`, or None for unknown paths."""
result = compute_paths_for_response(file_path)
return result[0] if result else None
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the path relative to the asset root or model category, using forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
/.../input/sub/image.png -> "sub/image.png"
For unknown paths, returns None.
"""
return compute_display_name(file_path)
def get_asset_root_bucket_and_filepath(
file_path: str,
) -> tuple[RootCategory, str | None, str]:
"""Decompose an absolute path into (root, bucket, path-under-bucket).
`bucket` is only set for model assets. The returned relative path always
uses `/` separators and is empty when the path is exactly the matched root.
Raises:
ValueError: path does not belong to any known root.
"""
fp_abs = os.path.abspath(file_path)
def _check_is_within(child: str, parent: str) -> bool:
return Path(child).is_relative_to(parent)
def _compute_relative(child: str, parent: str) -> str:
# Normalize relative path, stripping any leading ".." components
# by anchoring to root (os.sep) then computing relpath back from it.
rel = os.path.relpath(
os.path.join(os.sep, os.path.relpath(child, parent)), os.sep
)
return "" if rel == "." else rel.replace(os.sep, "/")
for root_tag, getter in (
("input", folder_paths.get_input_directory),
("output", folder_paths.get_output_directory),
("temp", folder_paths.get_temp_directory),
):
base = os.path.abspath(getter())
if _check_is_within(fp_abs, base):
return root_tag, None, _compute_relative(fp_abs, base)
# models: check deepest matching base to avoid ambiguity.
best: tuple[int, str, str] | None = None
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _check_is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _compute_relative(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
return "models", bucket, rel_inside
raise ValueError(
f"Path is not within input, output, temp, or configured model bases: {file_path}"
)
def get_asset_category_and_relative_path(
file_path: str,
) -> tuple[Literal["input", "output", "temp", "models"], str]:
) -> tuple[RootCategory, str]:
"""Determine which root category a file path belongs to.
Categories:
@ -160,7 +259,21 @@ def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return (name, tags) derived from a filesystem path.
- name: base filename with extension
- tags: [root_category] + parent folder names in order
- tags: [root_category] for paths with no parent subdirectories,
[root_category, slash_joined_subpath] otherwise. The parent subpath
(everything between the root category and the filename) is collapsed
into a single tag rather than emitted as one tag per directory, so
consumers can use ``tags[1]`` as a stable category identifier that
survives nested directory layouts (e.g. diffusers components).
The subpath is lowercased to match the canonicalization applied by
:func:`ensure_tags_exist`; without that, the
``asset_reference_tags.tag_name`` FK to the lowercased ``tags.name``
would fail for any path containing uppercase letters. The root
category is lowercase by construction in
:func:`get_asset_category_and_relative_path`, so no separate cast
is applied here. Consumers that need to look up providers keyed on
original-case paths should normalize their lookup key to lowercase.
Raises:
ValueError: path does not belong to any known root.
@ -170,4 +283,7 @@ def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
parent_parts = [
part for part in p.parent.parts if part not in (".", "..", p.anchor)
]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
tags = [root_category]
if parent_parts:
tags.append("/".join(parent_parts).lower())
return p.name, list(dict.fromkeys(t.strip() for t in tags if t.strip()))

View File

@ -71,6 +71,7 @@ class AssetSummaryData:
class ListAssetsResult:
items: list[AssetSummaryData]
total: int
next_cursor: str | None = None
@dataclass(frozen=True)

View File

@ -1613,6 +1613,16 @@ class ModelPatcherDynamic(ModelPatcher):
#use all ModelPatcherDynamic this is ignored and its all done dynamically.
return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3)
def restore_loaded_backups(self):
restored = self.model.model_loaded_weight_memory
for key in list(self.backup.keys()):
bk = self.backup.pop(key)
comfy.utils.set_attr_param(self.model, key, bk.weight)
for key in list(self.backup_buffers.keys()):
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
self.model.model_loaded_weight_memory = 0
return restored
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False):
@ -1629,7 +1639,7 @@ class ModelPatcherDynamic(ModelPatcher):
num_patches = 0
allocated_size = 0
self.model.model_loaded_weight_memory = 0
self.restore_loaded_backups()
with self.use_ejected():
self.unpatch_hooks()
@ -1716,6 +1726,9 @@ class ModelPatcherDynamic(ModelPatcher):
force_load=True
if force_load:
if hasattr(m, "_v"):
comfy_aimdo.model_vbar.vbar_unpin(m._v)
delattr(m, "_v")
force_load_param(self, "weight", device_to)
force_load_param(self, "bias", device_to)
else:
@ -1773,13 +1786,7 @@ class ModelPatcherDynamic(ModelPatcher):
freed = 0 if vbar is None else vbar.free_memory(memory_to_free)
if freed < memory_to_free:
for key in list(self.backup.keys()):
bk = self.backup.pop(key)
comfy.utils.set_attr_param(self.model, key, bk.weight)
for key in list(self.backup_buffers.keys()):
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
freed += self.model.model_loaded_weight_memory
self.model.model_loaded_weight_memory = 0
freed += self.restore_loaded_backups()
return freed

View File

@ -1019,10 +1019,11 @@ def bislerp(samples, width, height):
def lanczos(samples, width, height):
#the below API is strict and expects grayscale to be squeezed
samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1)
if samples.ndim == 4:
samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1)
images = [Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
images = [torch.from_numpy(t).movedim(-1, 0) if (t := np.array(image).astype(np.float32) / 255.0).ndim == 3 else torch.from_numpy(t) for image in images]
result = torch.stack(images)
return result.to(samples.device, samples.dtype)

View File

@ -35,6 +35,19 @@ class AnthropicMessage(BaseModel):
content: list[AnthropicTextContent | AnthropicImageContent] = Field(...)
class AnthropicThinkingConfig(BaseModel):
type: Literal["enabled", "disabled", "adaptive"] = Field(...)
budget_tokens: int | None = Field(
None, ge=1024,
description="Reasoning budget in tokens. Used when type is 'enabled'. Must be less than max_tokens.",
)
class AnthropicOutputConfig(BaseModel):
"""Used with `thinking.type='adaptive'` on models like Opus 4.7."""
effort: Literal["low", "medium", "high"] | None = Field(None)
class AnthropicMessagesRequest(BaseModel):
model: str = Field(...)
messages: list[AnthropicMessage] = Field(...)
@ -44,6 +57,8 @@ class AnthropicMessagesRequest(BaseModel):
top_p: float | None = Field(None, ge=0.0, le=1.0)
top_k: int | None = Field(None, ge=0)
stop_sequences: list[str] | None = Field(None)
thinking: AnthropicThinkingConfig | None = Field(None)
output_config: AnthropicOutputConfig | None = Field(None)
class AnthropicResponseTextBlock(BaseModel):
@ -51,6 +66,14 @@ class AnthropicResponseTextBlock(BaseModel):
text: str = Field(...)
class AnthropicResponseThinkingBlock(BaseModel):
type: Literal["thinking"] = "thinking"
thinking: str = Field(...)
AnthropicResponseBlock = AnthropicResponseTextBlock | AnthropicResponseThinkingBlock
class AnthropicCacheCreationUsage(BaseModel):
ephemeral_5m_input_tokens: int | None = Field(None)
ephemeral_1h_input_tokens: int | None = Field(None)
@ -69,7 +92,7 @@ class AnthropicMessagesResponse(BaseModel):
type: str | None = Field(None)
role: str | None = Field(None)
model: str | None = Field(None)
content: list[AnthropicResponseTextBlock] | None = Field(None)
content: list[AnthropicResponseBlock] | None = Field(None)
stop_reason: str | None = Field(None)
stop_sequence: str | None = Field(None)
usage: AnthropicMessagesUsage | None = Field(None)

View File

@ -0,0 +1,93 @@
"""Pydantic models for the OpenRouter chat completions API.
See: https://openrouter.ai/docs/api/api-reference/chat/send-chat-completion-request
"""
from typing import Literal
from pydantic import BaseModel, Field
class OpenRouterTextContent(BaseModel):
type: Literal["text"] = "text"
text: str = Field(...)
class OpenRouterImageUrl(BaseModel):
url: str = Field(...)
class OpenRouterImageContent(BaseModel):
type: Literal["image_url"] = "image_url"
image_url: OpenRouterImageUrl = Field(...)
class OpenRouterVideoUrl(BaseModel):
url: str = Field(...)
class OpenRouterVideoContent(BaseModel):
type: Literal["video_url"] = "video_url"
video_url: OpenRouterVideoUrl = Field(...)
OpenRouterContentBlock = OpenRouterTextContent | OpenRouterImageContent | OpenRouterVideoContent
class OpenRouterMessage(BaseModel):
role: Literal["system", "user", "assistant"] = Field(...)
content: str | list[OpenRouterContentBlock] = Field(...)
class OpenRouterReasoningConfig(BaseModel):
effort: str | None = Field(None)
exclude: bool | None = Field(None, description="If true, model reasons but reasoning is excluded from response.")
class OpenRouterWebSearchOptions(BaseModel):
search_context_size: str | None = Field(None)
class OpenRouterChatRequest(BaseModel):
model: str = Field(...)
messages: list[OpenRouterMessage] = Field(...)
seed: int | None = Field(None)
reasoning: OpenRouterReasoningConfig | None = Field(None)
web_search_options: OpenRouterWebSearchOptions | None = Field(None)
stream: bool = Field(False)
class OpenRouterUsage(BaseModel):
prompt_tokens: int | None = Field(None)
completion_tokens: int | None = Field(None)
total_tokens: int | None = Field(None)
cost: float | None = Field(None, description="Server-side authoritative USD cost of the call.")
class OpenRouterResponseMessage(BaseModel):
role: str | None = Field(None)
content: str | None = Field(None)
reasoning: str | None = Field(None)
refusal: str | None = Field(None)
class OpenRouterChoice(BaseModel):
index: int | None = Field(None)
message: OpenRouterResponseMessage | None = Field(None)
finish_reason: str | None = Field(None)
class OpenRouterError(BaseModel):
code: int | str | None = Field(None)
message: str | None = Field(None)
metadata: dict | None = Field(None)
class OpenRouterChatResponse(BaseModel):
id: str | None = Field(None)
model: str | None = Field(None)
object: str | None = Field(None)
provider: str | None = Field(None)
choices: list[OpenRouterChoice] | None = Field(None)
usage: OpenRouterUsage | None = Field(None)
error: OpenRouterError | None = Field(None)

View File

@ -9,8 +9,11 @@ from comfy_api_nodes.apis.anthropic import (
AnthropicMessage,
AnthropicMessagesRequest,
AnthropicMessagesResponse,
AnthropicOutputConfig,
AnthropicResponseTextBlock,
AnthropicRole,
AnthropicTextContent,
AnthropicThinkingConfig,
)
from comfy_api_nodes.util import (
ApiEndpoint,
@ -32,15 +35,29 @@ CLAUDE_MODELS: dict[str, str] = {
"Haiku 4.5": "claude-haiku-4-5-20251001",
}
_THINKING_UNSUPPORTED = {"Haiku 4.5"}
# Models that use the newer "adaptive" thinking mode (Opus 4.7 requires it; older models keep the explicit budget API).
# Anthropic decides the actual budget when adaptive is used, based on the `output_config.effort` hint.
_ADAPTIVE_THINKING_MODELS = {"Opus 4.7", "Opus 4.6", "Sonnet 4.6"}
def _claude_model_inputs():
return [
# Budget mode (Sonnet 4.5): effort -> reasoning budget in tokens. Must be < max_tokens.
# Sized so even the "high" budget fits comfortably under the default max_tokens=32768.
_REASONING_BUDGET: dict[str, int] = {
"low": 2048,
"medium": 8192,
"high": 16384,
}
_REASONING_EFFORTS = ["off", "low", "medium", "high"]
def _claude_model_inputs(model_label: str):
inputs: list = [
IO.Int.Input(
"max_tokens",
default=16000,
min=32,
max=32000,
tooltip="Maximum number of tokens to generate before stopping.",
default=32768,
min=4096,
max=64000,
tooltip="Maximum number of tokens to generate (includes reasoning tokens when enabled).",
advanced=True,
),
IO.Float.Input(
@ -49,10 +66,24 @@ def _claude_model_inputs():
min=0.0,
max=1.0,
step=0.01,
tooltip="Controls randomness. 0.0 is deterministic, 1.0 is most random. Ignored for Opus 4.7.",
tooltip=(
"Controls randomness. 0.0 is deterministic, 1.0 is most random. "
"Ignored for Opus 4.7 and any model when reasoning_effort is set."
),
advanced=True,
),
]
if model_label not in _THINKING_UNSUPPORTED:
inputs.append(
IO.Combo.Input(
"reasoning_effort",
options=_REASONING_EFFORTS,
default="off",
tooltip="Extended thinking effort. 'off' disables reasoning.",
advanced=True,
)
)
return inputs
def _model_price_per_million(model: str) -> tuple[float, float] | None:
@ -95,7 +126,11 @@ def calculate_tokens_price(response: AnthropicMessagesResponse) -> float | None:
def _get_text_from_response(response: AnthropicMessagesResponse) -> str:
if not response.content:
return ""
return "\n".join(block.text for block in response.content if block.text)
# Thinking blocks are silently dropped — we never want reasoning in the output.
return "\n".join(
block.text for block in response.content
if isinstance(block, AnthropicResponseTextBlock) and block.text
)
async def _build_image_content_blocks(
@ -133,7 +168,10 @@ class ClaudeNode(IO.ComfyNode):
),
IO.DynamicCombo.Input(
"model",
options=[IO.DynamicCombo.Option(label, _claude_model_inputs()) for label in CLAUDE_MODELS],
options=[
IO.DynamicCombo.Option(label, _claude_model_inputs(label))
for label in CLAUDE_MODELS
],
tooltip="The Claude model used to generate the response.",
),
IO.Int.Input(
@ -207,8 +245,29 @@ class ClaudeNode(IO.ComfyNode):
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
model_label = model["model"]
max_tokens = model["max_tokens"]
temperature = None if model_label == "Opus 4.7" else model["temperature"]
max_tokens = model.get("max_tokens", 32768)
reasoning_effort = model.get("reasoning_effort", "off")
thinking_enabled = reasoning_effort not in ("off", None) and model_label not in _THINKING_UNSUPPORTED
# Anthropic requires temperature to be unset (defaults to 1.0) when thinking is enabled.
# Opus 4.7 also rejects user-supplied temperature.
if thinking_enabled or model_label == "Opus 4.7":
temperature = None
else:
temperature = model.get("temperature", 1.0)
thinking_cfg: AnthropicThinkingConfig | None = None
output_cfg: AnthropicOutputConfig | None = None
if thinking_enabled:
if model_label in _ADAPTIVE_THINKING_MODELS:
# Adaptive mode - Anthropic chooses the budget based on effort hint
thinking_cfg = AnthropicThinkingConfig(type="adaptive")
output_cfg = AnthropicOutputConfig(effort=reasoning_effort)
else:
# Budget mode (Sonnet 4.5). Leave at least 1024 tokens for the actual response
budget = _REASONING_BUDGET[reasoning_effort]
budget = min(budget, max(1024, max_tokens - 1024))
thinking_cfg = AnthropicThinkingConfig(type="enabled", budget_tokens=budget)
image_tensors: list[Input.Image] = [t for t in (images or {}).values() if t is not None]
if sum(get_number_of_images(t) for t in image_tensors) > CLAUDE_MAX_IMAGES:
@ -229,6 +288,8 @@ class ClaudeNode(IO.ComfyNode):
messages=[AnthropicMessage(role=AnthropicRole.user, content=content)],
system=system_prompt or None,
temperature=temperature,
thinking=thinking_cfg,
output_config=output_cfg,
),
price_extractor=calculate_tokens_price,
)

View File

@ -43,15 +43,16 @@ from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_image_tensor,
download_url_to_video_output,
downscale_video_to_max_pixels,
get_number_of_images,
image_tensor_pair_to_batch,
poll_op,
resize_video_to_pixel_budget,
sync_op,
upload_audio_to_comfyapi,
upload_image_to_comfyapi,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
upscale_video_to_min_pixels,
validate_image_aspect_ratio,
validate_image_dimensions,
validate_string,
@ -110,12 +111,13 @@ def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: st
max_px = limits.get("max")
if min_px and pixels < min_px:
raise ValueError(
f"Reference video {index} is too small: {w}x{h} = {pixels:,}px. " f"Minimum is {min_px:,}px for this model."
f"Reference video {index} is too small: {w}x{h} = {pixels:,} total pixels. "
f"Minimum for this model is {min_px:,} total pixels."
)
if max_px and pixels > max_px:
raise ValueError(
f"Reference video {index} is too large: {w}x{h} = {pixels:,}px. "
f"Maximum is {max_px:,}px for this model. Try downscaling the video."
f"Reference video {index} is too large: {w}x{h} = {pixels:,} total pixels. "
f"Maximum for this model is {max_px:,} total pixels. Try downscaling the video."
)
@ -1676,14 +1678,14 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
"first_frame_asset_id",
default="",
tooltip="Seedance asset_id to use as the first frame. "
"Mutually exclusive with the first_frame image input.",
"Mutually exclusive with the first_frame image input.",
optional=True,
),
IO.String.Input(
"last_frame_asset_id",
default="",
tooltip="Seedance asset_id to use as the last frame. "
"Mutually exclusive with the last_frame image input.",
"Mutually exclusive with the last_frame image input.",
optional=True,
),
IO.Int.Input(
@ -1865,11 +1867,20 @@ def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16
IO.Boolean.Input(
"auto_downscale",
default=False,
advanced=True,
optional=True,
tooltip="Automatically downscale reference videos that exceed the model's pixel budget "
"for the selected resolution. Aspect ratio is preserved; videos already within limits are untouched.",
),
IO.Boolean.Input(
"auto_upscale",
default=False,
advanced=True,
optional=True,
tooltip="Automatically upscale reference videos that are below the model's minimum pixel count "
"for the selected resolution. Aspect ratio is preserved; videos already meeting the minimum are "
"untouched. Note: upscaling a low-resolution source does not add real detail and may produce "
"lower-quality generations.",
),
IO.Autogrow.Input(
"reference_assets",
template=IO.Autogrow.TemplateNames(
@ -2030,7 +2041,13 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
max_px = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {}).get(model["resolution"], {}).get("max")
if max_px:
for key in reference_videos:
reference_videos[key] = resize_video_to_pixel_budget(reference_videos[key], max_px)
reference_videos[key] = downscale_video_to_max_pixels(reference_videos[key], max_px)
if model.get("auto_upscale") and reference_videos:
min_px = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {}).get(model["resolution"], {}).get("min")
if min_px:
for key in reference_videos:
reference_videos[key] = upscale_video_to_min_pixels(reference_videos[key], min_px)
total_video_duration = 0.0
for i, key in enumerate(reference_videos, 1):

View File

@ -0,0 +1,374 @@
"""API Nodes for OpenRouter LLM chat completions."""
from dataclasses import dataclass
from typing import Literal
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.openrouter import (
OpenRouterChatRequest,
OpenRouterChatResponse,
OpenRouterContentBlock,
OpenRouterImageContent,
OpenRouterImageUrl,
OpenRouterMessage,
OpenRouterReasoningConfig,
OpenRouterTextContent,
OpenRouterVideoContent,
OpenRouterVideoUrl,
OpenRouterWebSearchOptions,
)
from comfy_api_nodes.util import (
ApiEndpoint,
get_number_of_images,
sync_op,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_string,
)
OPENROUTER_CHAT_ENDPOINT = "/proxy/openrouter/api/v1/chat/completions"
Profile = Literal["standard", "reasoning", "frontier_reasoning", "perplexity", "perplexity_reasoning"]
@dataclass(frozen=True)
class _ModelSpec:
slug: str # exact OpenRouter model id
profile: Profile
price_in: float # USD per token (prompt)
price_out: float # USD per token (completion)
max_images: int = 0 # 0 = no image input; otherwise max URL-passed images supported
max_videos: int = 0 # 0 = no video input; otherwise max URL-passed videos supported
MODELS: list[_ModelSpec] = [
_ModelSpec("anthropic/claude-opus-4.7", "frontier_reasoning", 0.000005, 0.000025, max_images=20),
_ModelSpec("openai/gpt-5.5-pro", "frontier_reasoning", 0.00003, 0.00018, max_images=20),
_ModelSpec("openai/gpt-5.5", "frontier_reasoning", 0.000005, 0.00003, max_images=20),
_ModelSpec("google/gemini-3.5-flash", "reasoning", 0.0000015, 0.000009, max_images=20, max_videos=4),
_ModelSpec("x-ai/grok-4.20", "reasoning", 0.00000125, 0.0000025, max_images=20),
_ModelSpec("x-ai/grok-4.3", "reasoning", 0.00000125, 0.0000025, max_images=20),
_ModelSpec("deepseek/deepseek-v4-pro", "reasoning", 0.000000435, 0.00000087),
_ModelSpec("deepseek/deepseek-v4-flash", "reasoning", 0.000000112, 0.000000224),
_ModelSpec("deepseek/deepseek-v3.2", "reasoning", 0.000000252, 0.000000378),
_ModelSpec("qwen/qwen3.6-max-preview", "reasoning", 0.00000104, 0.00000624),
_ModelSpec("qwen/qwen3.6-plus", "reasoning", 0.000000325, 0.00000195, max_images=10, max_videos=4),
_ModelSpec("qwen/qwen3.6-flash", "reasoning", 0.0000001875, 0.000001125, max_images=10, max_videos=4),
_ModelSpec("mistralai/mistral-large-2512", "standard", 0.0000005, 0.0000015, max_images=8),
_ModelSpec("mistralai/mistral-medium-3-5", "reasoning", 0.0000015, 0.0000075, max_images=8),
_ModelSpec("z-ai/glm-4.6", "reasoning", 0.00000043, 0.00000174),
_ModelSpec("z-ai/glm-5", "reasoning", 0.0000006, 0.00000192),
_ModelSpec("moonshotai/kimi-k2.6", "reasoning", 0.00000073, 0.00000349, max_images=10),
_ModelSpec("moonshotai/kimi-k2-thinking", "reasoning", 0.0000006, 0.0000025),
_ModelSpec("perplexity/sonar-pro", "perplexity", 0.000003, 0.000015),
_ModelSpec("perplexity/sonar-reasoning-pro", "perplexity_reasoning", 0.000002, 0.000008),
_ModelSpec("perplexity/sonar-deep-research", "perplexity_reasoning", 0.000002, 0.000008),
]
_MODELS_BY_SLUG: dict[str, _ModelSpec] = {m.slug: m for m in MODELS}
_REASONING_EFFORTS = ["off", "low", "medium", "high"]
_SEARCH_CONTEXT_SIZES = ["low", "medium", "high"]
def _reasoning_extra_inputs() -> list:
return [
IO.Combo.Input(
"reasoning_effort",
options=_REASONING_EFFORTS,
default="off",
tooltip="Reasoning effort. 'off' disables reasoning entirely.",
advanced=True,
),
]
def _perplexity_extra_inputs() -> list:
return [
IO.Combo.Input(
"search_context_size",
options=_SEARCH_CONTEXT_SIZES,
default="medium",
tooltip="How much web search context to retrieve. Larger = more grounded but slower/pricier.",
advanced=True,
),
]
def _profile_inputs(profile: Profile) -> list:
if profile == "standard":
return []
if profile in ("reasoning", "frontier_reasoning"):
return _reasoning_extra_inputs()
if profile == "perplexity":
return _perplexity_extra_inputs()
if profile == "perplexity_reasoning":
return _perplexity_extra_inputs() + _reasoning_extra_inputs()
raise ValueError(f"Unknown profile: {profile}")
def _media_inputs(spec: _ModelSpec) -> list:
extras: list = []
if spec.max_images > 0:
extras.append(
IO.Autogrow.Input(
"images",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("image"),
names=[f"image_{i}" for i in range(1, spec.max_images + 1)],
min=0,
),
tooltip=f"Optional reference image(s) — up to {spec.max_images}. Sent as URLs.",
)
)
if spec.max_videos > 0:
extras.append(
IO.Autogrow.Input(
"videos",
template=IO.Autogrow.TemplateNames(
IO.Video.Input("video"),
names=[f"video_{i}" for i in range(1, spec.max_videos + 1)],
min=0,
),
tooltip=f"Optional reference video(s) — up to {spec.max_videos}. Sent as URLs.",
)
)
return extras
def _inputs_for_model(spec: _ModelSpec) -> list:
return _profile_inputs(spec.profile) + _media_inputs(spec)
def _build_model_options() -> list[IO.DynamicCombo.Option]:
return [IO.DynamicCombo.Option(spec.slug, _inputs_for_model(spec)) for spec in MODELS]
def _calculate_price(response: OpenRouterChatResponse) -> float | None:
if response.usage and response.usage.cost is not None:
return float(response.usage.cost)
return None
def _price_badge_jsonata() -> str:
rates_pairs = []
for spec in MODELS:
prompt_per_1k = spec.price_in * 1000
completion_per_1k = spec.price_out * 1000
rates_pairs.append(f' "{spec.slug}": [{prompt_per_1k:.8g}, {completion_per_1k:.8g}]')
rates_block = ",\n".join(rates_pairs)
return (
"(\n"
" $rates := {\n"
f"{rates_block}\n"
" };\n"
" $r := $lookup($rates, widgets.model);\n"
" $r ? {\n"
' "type": "list_usd",\n'
' "usd": $r,\n'
' "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }\n'
' } : {"type": "text", "text": "Token-based"}\n'
")"
)
async def _build_image_blocks(
cls: type[IO.ComfyNode], spec: _ModelSpec, images: list[Input.Image]
) -> list[OpenRouterImageContent]:
urls = await upload_images_to_comfyapi(
cls,
images,
max_images=spec.max_images,
total_pixels=2048 * 2048,
mime_type="image/png",
wait_label="Uploading reference images",
)
return [OpenRouterImageContent(image_url=OpenRouterImageUrl(url=url)) for url in urls]
async def _build_video_blocks(cls: type[IO.ComfyNode], videos: list[Input.Video]) -> list[OpenRouterVideoContent]:
blocks: list[OpenRouterVideoContent] = []
total = len(videos)
for idx, video in enumerate(videos):
label = "Uploading reference video"
if total > 1:
label = f"{label} ({idx + 1}/{total})"
url = await upload_video_to_comfyapi(cls, video, wait_label=label)
blocks.append(OpenRouterVideoContent(video_url=OpenRouterVideoUrl(url=url)))
return blocks
def _user_message(prompt: str, media_blocks: list[OpenRouterContentBlock]) -> OpenRouterMessage:
if not media_blocks:
return OpenRouterMessage(role="user", content=prompt)
blocks: list[OpenRouterContentBlock] = list(media_blocks)
blocks.append(OpenRouterTextContent(text=prompt))
return OpenRouterMessage(role="user", content=blocks)
def _build_messages(
system_prompt: str, prompt: str, media_blocks: list[OpenRouterContentBlock]
) -> list[OpenRouterMessage]:
messages: list[OpenRouterMessage] = []
if system_prompt:
messages.append(OpenRouterMessage(role="system", content=system_prompt))
messages.append(_user_message(prompt, media_blocks))
return messages
def _build_request(
slug: str,
system_prompt: str,
prompt: str,
media_blocks: list[OpenRouterContentBlock],
*,
seed: int,
reasoning_effort: str | None,
search_context_size: str | None,
) -> OpenRouterChatRequest:
reasoning_cfg: OpenRouterReasoningConfig | None = None
if reasoning_effort and reasoning_effort != "off":
# exclude=True asks providers to reason internally but not return the trace
reasoning_cfg = OpenRouterReasoningConfig(effort=reasoning_effort, exclude=True)
web_search_cfg: OpenRouterWebSearchOptions | None = None
if search_context_size:
web_search_cfg = OpenRouterWebSearchOptions(search_context_size=search_context_size)
return OpenRouterChatRequest(
model=slug,
messages=_build_messages(system_prompt, prompt, media_blocks),
seed=seed if seed > 0 else None,
reasoning=reasoning_cfg,
web_search_options=web_search_cfg,
)
def _extract_text(response: OpenRouterChatResponse) -> str:
if response.error:
code = response.error.code if response.error.code is not None else "unknown"
raise ValueError(f"OpenRouter error ({code}): {response.error.message or 'no message'}")
if not response.choices:
raise ValueError("Empty response from OpenRouter (no choices).")
message = response.choices[0].message
if not message:
raise ValueError("Empty response from OpenRouter (no message).")
if message.refusal:
raise ValueError(f"Model refused to respond: {message.refusal}")
return message.content or ""
class OpenRouterLLMNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="OpenRouterLLMNode",
display_name="OpenRouter LLM",
category="api node/text/OpenRouter",
essentials_category="Text Generation",
description=(
"Generate text responses through OpenRouter. Routes to a curated set of popular "
"models from xAI, DeepSeek, Qwen, Mistral, Z.AI (GLM), Moonshot (Kimi), and "
"Perplexity Sonar."
),
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text input to the model.",
),
IO.DynamicCombo.Input(
"model",
options=_build_model_options(),
tooltip="The OpenRouter model used to generate the response.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for sampling. Set to 0 to omit. Most models treat this as a hint only.",
),
IO.String.Input(
"system_prompt",
multiline=True,
default="",
optional=True,
advanced=True,
tooltip="Foundational instructions that dictate the model's behavior.",
),
],
outputs=[IO.String.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr=_price_badge_jsonata(),
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
slug: str = model["model"]
spec = _MODELS_BY_SLUG.get(slug)
if spec is None:
raise ValueError(f"Unknown OpenRouter model: {slug}")
reasoning_effort: str | None = model.get("reasoning_effort")
search_context_size: str | None = model.get("search_context_size")
image_tensors: list[Input.Image] = [t for t in (model.get("images") or {}).values() if t is not None]
if image_tensors and sum(get_number_of_images(t) for t in image_tensors) > spec.max_images:
raise ValueError(f"Up to {spec.max_images} images are supported for {slug}.")
video_inputs: list[Input.Video] = [v for v in (model.get("videos") or {}).values() if v is not None]
if video_inputs and len(video_inputs) > spec.max_videos:
raise ValueError(f"Up to {spec.max_videos} videos are supported for {slug}.")
media_blocks: list[OpenRouterContentBlock] = []
if image_tensors:
media_blocks.extend(await _build_image_blocks(cls, spec, image_tensors))
if video_inputs:
media_blocks.extend(await _build_video_blocks(cls, video_inputs))
request = _build_request(
slug,
system_prompt,
prompt,
media_blocks,
seed=seed,
reasoning_effort=reasoning_effort,
search_context_size=search_context_size,
)
response = await sync_op(
cls,
ApiEndpoint(path=OPENROUTER_CHAT_ENDPOINT, method="POST"),
response_model=OpenRouterChatResponse,
data=request,
price_extractor=_calculate_price,
)
return IO.NodeOutput(_extract_text(response))
class OpenRouterExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [OpenRouterLLMNode]
async def comfy_entrypoint() -> OpenRouterExtension:
return OpenRouterExtension()

View File

@ -16,16 +16,17 @@ from .conversions import (
convert_mask_to_image,
downscale_image_tensor,
downscale_image_tensor_by_max_side,
downscale_video_to_max_pixels,
image_tensor_pair_to_batch,
pil_to_bytesio,
resize_mask_to_image,
resize_video_to_pixel_budget,
tensor_to_base64_string,
tensor_to_bytesio,
tensor_to_pil,
text_filepath_to_base64_string,
text_filepath_to_data_uri,
trim_video,
upscale_video_to_min_pixels,
video_to_base64_string,
)
from .download_helpers import (
@ -88,16 +89,17 @@ __all__ = [
"convert_mask_to_image",
"downscale_image_tensor",
"downscale_image_tensor_by_max_side",
"downscale_video_to_max_pixels",
"image_tensor_pair_to_batch",
"pil_to_bytesio",
"resize_mask_to_image",
"resize_video_to_pixel_budget",
"tensor_to_base64_string",
"tensor_to_bytesio",
"tensor_to_pil",
"text_filepath_to_base64_string",
"text_filepath_to_data_uri",
"trim_video",
"upscale_video_to_min_pixels",
"video_to_base64_string",
# Validation utilities
"get_image_dimensions",

View File

@ -415,14 +415,48 @@ def trim_video(video: Input.Video, duration_sec: float) -> Input.Video:
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
def resize_video_to_pixel_budget(video: Input.Video, total_pixels: int) -> Input.Video:
"""Downscale a video to fit within ``total_pixels`` (w * h), preserving aspect ratio.
def downscale_video_to_max_pixels(video: Input.Video, max_pixels: int) -> Input.Video:
"""Downscale a video to fit within ``max_pixels`` (w * h), preserving aspect ratio.
Returns the original video object untouched when it already fits. Preserves frame rate, duration, and audio.
Aspect ratio is preserved up to a fraction of a percent (even-dim rounding).
"""
src_w, src_h = video.get_dimensions()
scale_dims = _compute_downscale_dims(src_w, src_h, total_pixels)
scale_dims = _compute_downscale_dims(src_w, src_h, max_pixels)
if scale_dims is None:
return video
return _apply_video_scale(video, scale_dims)
def _compute_upscale_dims(src_w: int, src_h: int, total_pixels: int) -> tuple[int, int] | None:
"""Return upscaled (w, h) with even dims meeting at least ``total_pixels``, or None if already large enough.
Source aspect ratio is preserved; output may drift by a fraction of a percent because both dimensions
are rounded up to even values (many codecs require divisible-by-2). The result is guaranteed to be at
least ``total_pixels``.
"""
pixels = src_w * src_h
if pixels >= total_pixels:
return None
scale = math.sqrt(total_pixels / pixels)
new_w = math.ceil(src_w * scale)
new_h = math.ceil(src_h * scale)
if new_w % 2:
new_w += 1
if new_h % 2:
new_h += 1
return new_w, new_h
def upscale_video_to_min_pixels(video: Input.Video, min_pixels: int) -> Input.Video:
"""Upscale a video to meet at least ``min_pixels`` (w * h), preserving aspect ratio.
Returns the original video object untouched when it already meets the minimum. Preserves frame rate,
duration, and audio. Aspect ratio is preserved up to a fraction of a percent (even-dim rounding).
Note: upscaling a low-resolution source does not add real detail; downstream model quality may suffer.
"""
src_w, src_h = video.get_dimensions()
scale_dims = _compute_upscale_dims(src_w, src_h, min_pixels)
if scale_dims is None:
return video
return _apply_video_scale(video, scale_dims)

View File

@ -0,0 +1,73 @@
"""Enrich executed-node output entries with asset id."""
import logging
import os
def enrich_output_with_assets(output_ui: dict) -> dict:
"""Inject asset ``id`` into file-type output entries when --enable-assets is set.
Returns a new dict; entries without a resolvable on-disk file path are left
unchanged. Errors are caught per-entry so a failure never blocks the WS
message from sending.
"""
from comfy.cli_args import args
if not args.enable_assets:
return output_ui
import folder_paths
from app.assets.services.ingest import register_file_in_place, DependencyMissingError
from app.assets.database.queries.asset_reference import get_reference_by_file_path
from app.database.db import create_session
enriched = {}
for key, entries in output_ui.items():
if not isinstance(entries, list):
enriched[key] = entries
continue
new_entries = []
for entry in entries:
if not isinstance(entry, dict) or "filename" not in entry or "type" not in entry:
new_entries.append(entry)
continue
try:
base = folder_paths.get_directory_by_type(entry["type"])
if base is None:
new_entries.append(entry)
continue
base_abs = os.path.abspath(base)
abs_path = os.path.abspath(os.path.join(base_abs, entry.get("subfolder") or "", entry["filename"]))
try:
if os.path.commonpath([base_abs, abs_path]) != base_abs:
raise ValueError("escapes base")
except ValueError:
logging.warning("Asset enrichment skipped (path escapes base): %s", entry.get("filename"))
new_entries.append(entry)
continue
if not os.path.isfile(abs_path):
new_entries.append(entry)
continue
# Try DB lookup first (cached node re-send); fall back to registering inline.
asset_id = None
with create_session() as session:
db_ref = get_reference_by_file_path(session, abs_path)
if db_ref is not None:
asset_id = db_ref.id
if asset_id is None:
result = register_file_in_place(
abs_path=abs_path,
name=entry["filename"],
tags=[entry["type"]],
)
asset_id = result.ref.id
entry = dict(entry)
entry["id"] = asset_id
except DependencyMissingError:
logging.warning("Asset enrichment skipped (blake3 not available): %s", entry.get("filename"))
except Exception:
logging.warning("Failed to enrich output entry with asset id: %s", entry.get("filename"), exc_info=True)
new_entries.append(entry)
enriched[key] = new_entries
return enriched

View File

@ -15,7 +15,7 @@ class SwitchNode(io.ComfyNode):
return io.Schema(
node_id="ComfySwitchNode",
display_name="Switch",
category="utils/logic",
category="logic",
is_experimental=True,
inputs=[
io.Boolean.Input("switch"),
@ -46,7 +46,7 @@ class SoftSwitchNode(io.ComfyNode):
return io.Schema(
node_id="ComfySoftSwitchNode",
display_name="Soft Switch",
category="utils/logic",
category="logic",
is_experimental=True,
inputs=[
io.Boolean.Input("switch"),
@ -136,7 +136,7 @@ class DCTestNode(io.ComfyNode):
return io.Schema(
node_id="DCTestNode",
display_name="DCTest",
category="utils/logic",
category="logic",
is_output_node=True,
inputs=[io.DynamicCombo.Input("combo", options=[
io.DynamicCombo.Option("option1", [io.String.Input("string")]),
@ -174,7 +174,7 @@ class AutogrowNamesTestNode(io.ComfyNode):
return io.Schema(
node_id="AutogrowNamesTestNode",
display_name="AutogrowNamesTest",
category="utils/logic",
category="logic",
inputs=[
_io.Autogrow.Input("autogrow", template=template)
],
@ -194,7 +194,7 @@ class AutogrowPrefixTestNode(io.ComfyNode):
return io.Schema(
node_id="AutogrowPrefixTestNode",
display_name="AutogrowPrefixTest",
category="utils/logic",
category="logic",
inputs=[
_io.Autogrow.Input("autogrow", template=template)
],
@ -213,7 +213,7 @@ class ComboOutputTestNode(io.ComfyNode):
return io.Schema(
node_id="ComboOptionTestNode",
display_name="ComboOptionTest",
category="utils/logic",
category="logic",
inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]),
io.Combo.Input("combo2", options=["option4", "option5", "option6"])],
outputs=[io.Combo.Output(), io.Combo.Output()],
@ -230,7 +230,7 @@ class ConvertStringToComboNode(io.ComfyNode):
node_id="ConvertStringToComboNode",
search_aliases=["string to dropdown", "text to combo"],
display_name="Convert String to Combo",
category="utils/logic",
category="logic",
inputs=[io.String.Input("string")],
outputs=[io.Combo.Output()],
)
@ -246,7 +246,7 @@ class InvertBooleanNode(io.ComfyNode):
node_id="InvertBooleanNode",
search_aliases=["not", "toggle", "negate", "flip boolean"],
display_name="Invert Boolean",
category="utils/logic",
category="logic",
inputs=[io.Boolean.Input("boolean")],
outputs=[io.Boolean.Output()],
)

View File

@ -70,7 +70,7 @@ class MathExpressionNode(io.ComfyNode):
return io.Schema(
node_id="ComfyMathExpression",
display_name="Math Expression",
category="utils",
category="logic",
search_aliases=[
"expression", "formula", "calculate", "calculator",
"eval", "math",

View File

@ -14,7 +14,7 @@ class CreateList(io.ComfyNode):
return io.Schema(
node_id="CreateList",
display_name="Create List",
category="utils",
category="logic",
is_input_list=True,
search_aliases=["Image Iterator", "Text Iterator", "Iterator"],
inputs=[io.Autogrow.Input("inputs", template=template_autogrow)],

View File

@ -40,6 +40,7 @@ from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
from comfy_execution.utils import CurrentNodeContext
from comfy_execution.asset_enrichment import enrich_output_with_assets
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io, _io
from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
@ -418,11 +419,15 @@ def _is_intermediate_output(dynprompt, node_id):
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False)
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs):
if server.client_id is None:
return
cached_ui = cached.ui or {}
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, server.client_id)
output = cached_ui.get("output", None)
if output:
output = enrich_output_with_assets(output)
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": output, "prompt_id": prompt_id }, server.client_id)
if cached.ui is not None:
ui_outputs[node_id] = cached.ui
@ -562,7 +567,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
"output": output_ui
}
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": enrich_output_with_assets(output_ui), "prompt_id": prompt_id }, server.client_id)
if has_subgraph:
cached_outputs = []
new_node_ids = []

View File

@ -1517,6 +1517,22 @@ paths:
schema:
type: integer
default: 0
description: |
Offset-based pagination. Cursor pagination via `after` is preferred
for sequential walks (stable across concurrent inserts/deletes) but
`offset` remains fully supported for random access (jump-to-page
UIs, "showing items XY of N" displays). When both are supplied,
`after` wins and `offset` is ignored.
- name: after
in: query
schema:
type: string
description: |
Opaque cursor for keyset pagination. Pass the `next_cursor` value
from a previous response to fetch the next page. Stable across
inserts/deletes between pages. Supported with `sort` values
`created_at`, `updated_at`, `name`, and `size`. Malformed or
unsupported cursors return 400 with `INVALID_CURSOR`.
- name: include_tags
in: query
schema:
@ -1556,6 +1572,17 @@ paths:
type: string
enum: [asc, desc]
description: Sort direction
- name: job_ids
in: query
schema:
type: array
maxItems: 500
items:
type: string
format: uuid
style: form
explode: true
description: "Filter assets by associated job UUIDs. Accepts repeated query params (e.g. `?job_ids=a&job_ids=b`) or a single comma-separated value (`?job_ids=a,b`)."
- name: include_public
in: query
schema:
@ -1575,6 +1602,12 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/ListAssetsResponse"
"400":
description: Malformed query or cursor (e.g. `INVALID_CURSOR`)
content:
application/json:
schema:
$ref: "#/components/schemas/AssetsApiError"
post:
operationId: createAsset
tags: [assets]
@ -6610,7 +6643,18 @@ components:
description: Unique identifier for the asset
name:
type: string
deprecated: true
description: Name of the asset file
file_path:
type: string
nullable: true
x-runtime: [cloud, local]
description: "Logical asset locator under the namespace root. Not a unique reference key; use `id` for identity."
display_name:
type: string
nullable: true
x-runtime: [cloud, local]
description: "Human-facing display label for the asset. Not a unique reference key; use `id` for identity."
hash:
type: string
nullable: true
@ -6750,6 +6794,42 @@ components:
type: integer
has_more:
type: boolean
next_cursor:
type: string
description: |
Opaque cursor to fetch the next page. Pass back as the `after`
query parameter. Omitted when there are no more results.
AssetsApiError:
type: object
description: Error envelope returned by the assets API on 400 responses.
required:
- error
properties:
error:
type: object
required:
- code
- message
- details
properties:
code:
type: string
description: |
Machine-readable error code. `INVALID_CURSOR` is returned when the
`after` cursor is malformed, oversized, or its sort field does
not match the request's `sort`. `INVALID_QUERY` covers other
Pydantic validation failures.
enum: [INVALID_CURSOR, INVALID_QUERY]
message:
type: string
details:
type: object
description: |
Free-form, code-specific context. `INVALID_QUERY` populates this
with an `errors` array of Pydantic validation entries;
`INVALID_CURSOR` returns an empty object.
additionalProperties: true
TagInfo:
type: object
@ -6840,6 +6920,13 @@ components:
enum: [input, output, temp]
display_name:
type: string
id:
type: string
format: uuid
description: |
Asset reference UUID. Present only when the server is started with
`--enable-assets` and the file resolves to a registered asset.
Fetch the full asset via `GET /api/assets/{id}`.
NodeOutputs:
type: object
@ -8723,4 +8810,4 @@ components:
items:
$ref: "#/components/schemas/TaskEntry"
pagination:
$ref: "#/components/schemas/PaginationInfo"
$ref: "#/components/schemas/PaginationInfo"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.43.18
comfyui-workflow-templates==0.9.79
comfyui-workflow-templates==0.9.82
comfyui-embedded-docs==0.5.0
torch
torchsde

View File

@ -236,6 +236,8 @@ def seeded_asset(request: pytest.FixtureRequest, http: requests.Session, api_bas
r = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
body = r.json()
assert r.status_code == 201, body
from helpers import assert_hash_fields_consistent
assert_hash_fields_consistent(body)
return body

View File

@ -26,3 +26,26 @@ def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None:
def get_asset_filename(asset_hash: str, extension: str) -> str:
return asset_hash.removeprefix("blake3:") + extension
def assert_hash_fields_consistent(body: dict, expected_hash: str | None = None) -> None:
"""Assert hash and asset_hash invariants on an Asset response.
Both must be present or both absent (so a regression that drops only one
is caught). When present, they must equal each other and, if expected_hash
is provided, must equal that value.
"""
hash_present = "hash" in body
asset_hash_present = "asset_hash" in body
assert hash_present == asset_hash_present, (
f"hash and asset_hash must both be present or both absent: "
f"hash present={hash_present}, asset_hash present={asset_hash_present}"
)
if hash_present:
h = body["hash"]
ah = body["asset_hash"]
assert h == ah, f"hash and asset_hash must match: hash={h!r}, asset_hash={ah!r}"
if expected_hash is not None:
assert h == expected_hash, (
f"hash must equal expected: got {h!r}, expected {expected_hash!r}"
)

View File

@ -21,6 +21,7 @@ from app.assets.database.queries import (
get_reference_ids_by_ids,
ensure_tags_exist,
add_tags_to_reference,
set_reference_tags,
)
from app.assets.helpers import get_utc_now
@ -158,6 +159,203 @@ class TestListReferencesPage:
refs, _, _ = list_references_page(session, sort="name", order="asc")
assert refs[0].name == "large"
def test_job_ids_filter(self, session: Session):
asset = _make_asset(session, "hash1")
job_a = str(uuid.uuid4())
job_b = str(uuid.uuid4())
ref_a = _make_reference(session, asset, name="from_job_a")
ref_a.job_id = job_a
ref_b = _make_reference(session, asset, name="from_job_b")
ref_b.job_id = job_b
_make_reference(session, asset, name="no_job")
session.commit()
# Single job filter
refs, _, total = list_references_page(session, job_ids=[job_a])
assert total == 1
assert refs[0].name == "from_job_a"
# Multi-job filter (IN)
refs, _, total = list_references_page(session, job_ids=[job_a, job_b])
names = sorted(r.name for r in refs)
assert total == 2
assert names == ["from_job_a", "from_job_b"]
# Unknown job id matches nothing
refs, _, total = list_references_page(session, job_ids=[str(uuid.uuid4())])
assert total == 0
assert refs == []
# Empty/None means no filter -> all three references
refs, _, total = list_references_page(session, job_ids=[])
assert total == 3
refs, _, total = list_references_page(session, job_ids=None)
assert total == 3
def test_job_ids_combined_with_other_filters(self, session: Session):
asset = _make_asset(session, "hash1")
job_a = str(uuid.uuid4())
ref_match = _make_reference(session, asset, name="match.bin")
ref_match.job_id = job_a
ref_wrong_name = _make_reference(session, asset, name="other.bin")
ref_wrong_name.job_id = job_a
ref_wrong_job = _make_reference(session, asset, name="match.bin")
ref_wrong_job.job_id = str(uuid.uuid4())
session.commit()
refs, _, total = list_references_page(
session, job_ids=[job_a], name_contains="match"
)
assert total == 1
assert refs[0].id == ref_match.id
class TestTagRetrievalOrder:
"""End-to-end check: tags written through the public write paths come
back from the public read paths in insertion order rather than the
composite-PK alphabetical order SQLite would otherwise impose.
Each test deliberately picks tag names that would sort differently
under alphabetical vs insertion order, so an alphabetical regression
fails loudly.
"""
def _make_ref(self, session: Session) -> AssetReference:
asset = _make_asset(session, "h1")
return _make_reference(session, asset, name="x.bin")
def test_set_reference_tags_preserves_input_order_in_list(self, session: Session):
ref = self._make_ref(session)
# "checkpoints" < "models" alphabetically; if added_at stagger
# works, list_references_page returns insertion order.
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
session.commit()
_, tag_map, _ = list_references_page(session)
assert tag_map[ref.id] == ["models", "checkpoints"]
def test_set_reference_tags_preserves_input_order_in_fetch(self, session: Session):
ref = self._make_ref(session)
# Subpath tag sorts before "models" alphabetically.
set_reference_tags(
session,
reference_id=ref.id,
tags=["models", "diffusers/kolors/text_encoder"],
)
session.commit()
result = fetch_reference_asset_and_tags(session, ref.id)
assert result is not None
_, _, tags = result
# Bucket-prefix expansion appends the standalone `diffusers` token
# at path-tier (microsecond stagger) so FE set-membership filters
# match nested category paths.
assert tags == ["models", "diffusers/kolors/text_encoder", "diffusers"]
def test_add_tags_to_reference_lands_after_path_tags(self, session: Session):
ref = self._make_ref(session)
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
session.commit()
# "aaa-..." sorts before both path tags alphabetically. If added_at
# stagger is missing, alphabetic tiebreak would hoist it to tags[0].
add_tags_to_reference(
session, reference_id=ref.id, tags=["aaa-user-tag"], origin="manual"
)
session.commit()
_, tag_map, _ = list_references_page(session)
assert tag_map[ref.id] == ["models", "checkpoints", "aaa-user-tag"]
def test_multi_tag_batch_lands_after_path_tags(self, session: Session):
ref = self._make_ref(session)
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
session.commit()
# Three user tags inserted in non-alphabetical input order. Per-tag
# microsecond stagger should preserve at least the "user batch is
# after path tags" property; within the user batch insertion order
# is also preserved.
add_tags_to_reference(
session,
reference_id=ref.id,
tags=["zzz-z", "favorite", "experiment-q4"],
origin="manual",
)
session.commit()
_, tag_map, _ = list_references_page(session)
tags = tag_map[ref.id]
assert tags[0:2] == ["models", "checkpoints"]
assert set(tags[2:]) == {"zzz-z", "favorite", "experiment-q4"}
def test_user_batch_lands_after_path_batch_under_clock_collision(
self, session: Session, monkeypatch: pytest.MonkeyPatch
):
"""Windows-specific race: when two back-to-back commits share the
same datetime.now() microsecond, the path-tier and user-tier
added_at values used to collide and alphabetic tiebreak would
hoist user tags ahead of path tags. The fix reads
max(existing_added_at) for the reference and seeds the next batch
past it, deterministically restoring insertion order.
This test simulates the collision by pinning get_utc_now() so the
platform-dependent race becomes a platform-independent failure.
"""
ref = self._make_ref(session)
from datetime import datetime
from app.assets.database import queries as queries_pkg
from app.assets.database.queries import tags as tags_module
frozen = datetime(2026, 1, 1, 0, 0, 0)
monkeypatch.setattr(tags_module, "get_utc_now", lambda: frozen)
monkeypatch.setattr(queries_pkg, "get_utc_now", lambda: frozen, raising=False)
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
session.commit()
# Same frozen timestamp — without the max(existing) seed, the
# user batch would share added_at with the path batch and
# `aaa-user-tag` would sort to position 0 via the alphabetic
# tiebreaker.
add_tags_to_reference(
session, reference_id=ref.id, tags=["aaa-user-tag"], origin="manual"
)
session.commit()
_, tag_map, _ = list_references_page(session)
assert tag_map[ref.id] == ["models", "checkpoints", "aaa-user-tag"]
def test_remove_then_add_does_not_disrupt_path_tag_positions(
self, session: Session
):
ref = self._make_ref(session)
set_reference_tags(
session,
reference_id=ref.id,
tags=["models", "loras/my/custom/path"],
)
session.commit()
add_tags_to_reference(session, reference_id=ref.id, tags=["temp-tag"])
session.commit()
from app.assets.database.queries import remove_tags_from_reference
remove_tags_from_reference(session, reference_id=ref.id, tags=["temp-tag"])
session.commit()
add_tags_to_reference(session, reference_id=ref.id, tags=["second-tag"])
session.commit()
_, tag_map, _ = list_references_page(session)
# `loras` is expanded from the nested category path; user-added
# tags trail behind it via the microsecond stagger.
assert tag_map[ref.id] == [
"models",
"loras/my/custom/path",
"loras",
"second-tag",
]
class TestFetchReferenceAssetAndTags:
def test_returns_none_for_nonexistent(self, session: Session):

View File

@ -0,0 +1,112 @@
"""Keyset-pagination tiebreaker tests for list_references_page.
When multiple rows share the same primary sort value (e.g. four assets
created in the same microsecond), the secondary `ORDER BY id` is what keeps
keyset pagination from losing or repeating rows. This file exercises that
branch directly against an in-memory SQLite session — engineering identical
timestamps via HTTP is unreliable enough that we work at the query layer.
"""
import uuid
from datetime import datetime
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference
from app.assets.database.queries.asset_reference import list_references_page
def _make_ref(session: Session, created_at: datetime, name: str, owner: str = "") -> AssetReference:
asset = Asset(hash=f"blake3:{uuid.uuid4().hex}", size_bytes=1024)
session.add(asset)
session.flush()
ref = AssetReference(
id=str(uuid.uuid4()),
asset_id=asset.id,
owner_id=owner,
name=name,
file_path=f"/tmp/{name}",
created_at=created_at,
updated_at=created_at,
last_access_time=created_at,
is_missing=False,
)
session.add(ref)
return ref
@pytest.mark.parametrize("order", ["desc", "asc"])
def test_tiebreaker_walks_duplicate_sort_values(session: Session, order: str):
"""Four rows with the SAME created_at must paginate cleanly under cursor
mode — no row dropped, no row repeated, despite the primary sort column
being non-discriminating.
"""
shared_ts = datetime(2024, 5, 20, 12, 0, 0) # naive UTC, like the DB stores
refs = [_make_ref(session, shared_ts, f"tie_{i}.png") for i in range(4)]
session.commit()
expected_ids = sorted([r.id for r in refs], reverse=(order == "desc"))
# Walk the cursor by hand: page size 2, take 3 pages (2 + 2 + 0).
seen: list[str] = []
after_value = None
after_id = None
for _ in range(4): # generous loop bound; ought to be 2 iterations
page, _tag_map, _total = list_references_page(
session,
limit=2,
sort="created_at",
order=order,
after_cursor_value=after_value,
after_cursor_id=after_id,
)
if not page:
break
seen.extend(p.id for p in page)
# Use the last row's (created_at, id) as the next cursor input.
last = page[-1]
after_value, after_id = last.created_at, last.id
if len(page) < 2:
break
assert seen == expected_ids, (
f"keyset tiebreaker failed for order={order}: expected {expected_ids}, got {seen}"
)
def test_tiebreaker_no_duplicates_under_mixed_collisions(session: Session):
"""Some rows share a timestamp, some don't. The cursor must still walk
every row exactly once regardless of where ties sit relative to a
page boundary."""
t1 = datetime(2024, 5, 20, 12, 0, 0)
t2 = datetime(2024, 5, 20, 12, 0, 1)
layout = [t1, t1, t1, t2, t2] # three rows at t1, two at t2
refs = [_make_ref(session, ts, f"mix_{i}.png") for i, ts in enumerate(layout)]
session.commit()
all_ids = {r.id for r in refs}
seen_set: set[str] = set()
seen_list: list[str] = []
after_value = None
after_id = None
for _ in range(6):
page, _, _ = list_references_page(
session,
limit=2,
sort="created_at",
order="desc",
after_cursor_value=after_value,
after_cursor_id=after_id,
)
if not page:
break
for p in page:
assert p.id not in seen_set, f"duplicate row {p.id} appeared in cursor walk"
seen_set.add(p.id)
seen_list.append(p.id)
last = page[-1]
after_value, after_id = last.created_at, last.id
if len(page) < 2:
break
assert seen_set == all_ids, f"missing rows: expected {all_ids}, got {seen_set}"

View File

@ -0,0 +1,60 @@
"""Schema-level unit tests for ListAssetsQuery (no DB required)."""
import uuid
import pytest
from pydantic import ValidationError
from app.assets.api.schemas_in import ListAssetsQuery
class TestJobIdsValidator:
def test_csv_string_parses_and_canonicalizes(self):
a = "AAAAAAAA-BBBB-CCCC-DDDD-EEEEEEEEEEEE"
b = "11111111-2222-3333-4444-555555555555"
q = ListAssetsQuery.model_validate({"job_ids": f"{a},{b}"})
# Canonicalized to lowercase
assert q.job_ids == [a.lower(), b]
def test_repeated_query_params_as_list(self):
a = "11111111-1111-1111-1111-111111111111"
b = "22222222-2222-2222-2222-222222222222"
q = ListAssetsQuery.model_validate({"job_ids": [a, b]})
assert q.job_ids == [a, b]
def test_dedup_preserves_first_seen_order(self):
a = "11111111-1111-1111-1111-111111111111"
b = "22222222-2222-2222-2222-222222222222"
q = ListAssetsQuery.model_validate({"job_ids": [a, b, a]})
assert q.job_ids == [a, b]
def test_default_empty(self):
q = ListAssetsQuery.model_validate({})
assert q.job_ids == []
def test_invalid_uuid_rejected(self):
with pytest.raises(ValidationError) as exc:
ListAssetsQuery.model_validate({"job_ids": "not-a-uuid"})
assert "must be UUIDs" in str(exc.value)
def test_non_string_list_item_rejected(self):
with pytest.raises(ValidationError) as exc:
ListAssetsQuery.model_validate(
{"job_ids": ["11111111-1111-1111-1111-111111111111", 42]}
)
assert "must be strings" in str(exc.value)
def test_non_string_non_list_value_rejected(self):
with pytest.raises(ValidationError) as exc:
ListAssetsQuery.model_validate({"job_ids": {"bad": "shape"}})
assert "must be a string or list of strings" in str(exc.value)
def test_max_length_enforced(self):
too_many = [str(uuid.uuid4()) for _ in range(501)]
with pytest.raises(ValidationError) as exc:
ListAssetsQuery.model_validate({"job_ids": too_many})
assert exc.value.errors()[0]["type"] == "too_long"
def test_max_length_boundary_accepted(self):
at_cap = [str(uuid.uuid4()) for _ in range(500)]
q = ListAssetsQuery.model_validate({"job_ids": at_cap})
assert len(q.job_ids) == 500

View File

@ -160,6 +160,120 @@ class TestAddTagsToReference:
add_tags_to_reference(session, reference_id="nonexistent", tags=["x"])
class TestBucketPrefixExpansion:
"""The standalone bucket token must appear in the asset's tag set for
nested category paths so FE filters like
`include_tags=models,checkpoints` continue to match.
"""
def test_set_reference_tags_inserts_bucket_for_nested_path(
self, session: Session
):
asset = _make_asset(session, "hash-nested")
ref = _make_reference(session, asset)
result = set_reference_tags(
session,
reference_id=ref.id,
tags=["models", "checkpoints/flux"],
)
session.commit()
assert set(result.total) == {"models", "checkpoints/flux", "checkpoints"}
stored = get_reference_tags(session, reference_id=ref.id)
# tag[1] keeps the slash-joined positional contract; the standalone
# bucket lands after it via path-tier microsecond stagger so user
# tags remain at the tail.
assert stored[:3] == ["models", "checkpoints/flux", "checkpoints"]
def test_set_reference_tags_idempotent_on_replay(self, session: Session):
asset = _make_asset(session, "hash-replay")
ref = _make_reference(session, asset)
set_reference_tags(
session,
reference_id=ref.id,
tags=["models", "checkpoints/flux"],
)
# Replay with the same caller-supplied set; expansion is already
# baked in, so nothing should be added or removed.
result = set_reference_tags(
session,
reference_id=ref.id,
tags=["models", "checkpoints/flux"],
)
session.commit()
assert result.added == []
assert result.removed == []
assert set(result.total) == {"models", "checkpoints/flux", "checkpoints"}
def test_add_tags_to_reference_expands_bucket(self, session: Session):
asset = _make_asset(session, "hash-add")
ref = _make_reference(session, asset)
result = add_tags_to_reference(
session,
reference_id=ref.id,
tags=["loras/style/v2"],
)
session.commit()
assert set(result.added) == {"loras/style/v2", "loras"}
stored = get_reference_tags(session, reference_id=ref.id)
assert "loras" in stored
assert "loras/style/v2" in stored
def test_add_tags_does_not_duplicate_existing_bucket(self, session: Session):
asset = _make_asset(session, "hash-dedupe")
ref = _make_reference(session, asset)
add_tags_to_reference(
session, reference_id=ref.id, tags=["models", "checkpoints"]
)
result = add_tags_to_reference(
session, reference_id=ref.id, tags=["checkpoints/flux"]
)
session.commit()
# `checkpoints` was already there from the first add; only the
# slash-joined token is genuinely new.
assert result.added == ["checkpoints/flux"]
assert "checkpoints" in result.already_present
def test_flat_category_is_unaffected(self, session: Session):
asset = _make_asset(session, "hash-flat")
ref = _make_reference(session, asset)
result = set_reference_tags(
session,
reference_id=ref.id,
tags=["models", "checkpoints"],
)
session.commit()
assert set(result.total) == {"models", "checkpoints"}
assert get_reference_tags(session, reference_id=ref.id) == [
"models",
"checkpoints",
]
def test_unknown_prefix_passes_through(self, session: Session):
asset = _make_asset(session, "hash-user")
ref = _make_reference(session, asset)
# `my-org` isn't a registered bucket — the slash-joined user tag
# should not trigger bucket expansion.
result = set_reference_tags(
session,
reference_id=ref.id,
tags=["my-org/team-a"],
)
session.commit()
assert result.total == ["my-org/team-a"]
class TestRemoveTagsFromReference:
def test_removes_tags(self, session: Session):
asset = _make_asset(session, "hash1")

View File

@ -4,7 +4,7 @@ from pathlib import Path
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag
from app.assets.services.bulk_ingest import SeedAssetSpec, batch_insert_seed_assets
@ -102,6 +102,82 @@ class TestBatchInsertSeedAssets:
assert asset.mime_type == expected_mime, f"Expected {expected_mime} for {filename}, got {asset.mime_type}"
class TestBucketPrefixExpansionOnIngest:
"""Path-scanning ingest must persist the standalone bucket token for
nested category paths so the FE set-membership filter
(`include_tags=models,checkpoints`) matches assets organized into
subfolders (`models/checkpoints/flux/foo.safetensors`).
"""
def test_nested_path_inserts_standalone_bucket(
self, session: Session, temp_dir: Path
):
file_path = temp_dir / "flux.safetensors"
file_path.write_bytes(b"content")
specs: list[SeedAssetSpec] = [
{
"abs_path": str(file_path),
"size_bytes": 7,
"mtime_ns": 1234567890000000000,
"info_name": "flux",
# Shape emitted by get_name_and_tags_from_asset_path for a
# nested model path.
"tags": ["models", "checkpoints/flux"],
"fname": "flux.safetensors",
"metadata": None,
"hash": None,
"mime_type": "application/safetensors",
}
]
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
assert result.inserted_refs == 1
ref = session.query(AssetReference).filter_by(name="flux").one()
stored = [
row.tag_name
for row in session.query(AssetReferenceTag)
.filter_by(asset_reference_id=ref.id)
.order_by(AssetReferenceTag.added_at.asc())
.all()
]
assert stored == ["models", "checkpoints/flux", "checkpoints"]
def test_flat_path_remains_two_tags(
self, session: Session, temp_dir: Path
):
file_path = temp_dir / "vanilla.safetensors"
file_path.write_bytes(b"content")
specs: list[SeedAssetSpec] = [
{
"abs_path": str(file_path),
"size_bytes": 7,
"mtime_ns": 1234567890000000000,
"info_name": "vanilla",
"tags": ["models", "checkpoints"],
"fname": "vanilla.safetensors",
"metadata": None,
"hash": None,
"mime_type": "application/safetensors",
}
]
batch_insert_seed_assets(session, specs=specs, owner_id="")
ref = session.query(AssetReference).filter_by(name="vanilla").one()
stored = {
row.tag_name
for row in session.query(AssetReferenceTag)
.filter_by(asset_reference_id=ref.id)
.all()
}
# Dedupe means flat layouts don't pick up a redundant `checkpoints`
# row — tag[1] already serves both positional and set-membership.
assert stored == {"models", "checkpoints"}
class TestMetadataExtraction:
def test_extracts_mime_type_for_model_files(self, temp_dir: Path):
"""Verify metadata extraction returns correct mime_type for model files."""

View File

@ -0,0 +1,354 @@
"""Tests for app.assets.services.cursor.
The byte-identity fixtures below pin the wire format so a parallel
implementation in another runtime can mint exchange-compatible cursors
for the same payload. Drift here would break frontend pagination against
any compatible backend.
"""
from __future__ import annotations
import base64
from datetime import datetime, timedelta, timezone
import pytest
from app.assets.services.cursor import (
MAX_CURSOR_ID_LENGTH,
MAX_CURSOR_VALUE_LENGTH,
MAX_ENCODED_CURSOR_LENGTH,
CursorPayload,
InvalidCursorError,
decode_cursor,
decode_cursor_int,
decode_cursor_time,
encode_cursor,
encode_cursor_from_time,
)
ALLOWED = ("created_at", "updated_at", "name", "size")
class TestRoundTrip:
@pytest.mark.parametrize(
"sort_field, value, id",
[
("created_at", "1716200000000000", "a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7"),
("size", "1024", "asset-123"),
("name", "my-asset.png", "asset-abc"),
("name", "résumé.txt", "asset-uni"),
],
)
def test_encode_decode(self, sort_field, value, id):
encoded = encode_cursor(sort_field, value, id)
assert encoded != ""
payload = decode_cursor(encoded, ALLOWED)
assert payload.sort_field == sort_field
assert payload.value == value
assert payload.id == id
class TestTimeCursor:
def test_microsecond_precision_preserved(self):
# Pick a time with non-zero microseconds — encoding at ms would lose the µs.
ts = datetime(2024, 5, 20, 12, 53, 20, 123456, tzinfo=timezone.utc)
encoded = encode_cursor_from_time("created_at", ts, "id-1")
payload = decode_cursor(encoded, ALLOWED)
# Value must be a microsecond integer string, not a millisecond one.
assert payload.value == "1716209600123456"
decoded = decode_cursor_time(payload)
assert decoded == ts
def test_decode_returns_utc(self):
payload = CursorPayload(sort_field="created_at", value="1716200000123456", id="id-1", order="desc")
decoded = decode_cursor_time(payload)
assert decoded.tzinfo == timezone.utc
def test_naive_datetime_rejected_on_encode(self):
naive = datetime(2024, 5, 20, 12, 0, 0)
with pytest.raises(ValueError):
encode_cursor_from_time("created_at", naive, "id-1")
def test_non_integer_value_rejected_on_decode(self):
with pytest.raises(InvalidCursorError):
decode_cursor_time(CursorPayload("created_at", "not-a-number", "id-1", "desc"))
def test_none_payload_rejected(self):
with pytest.raises(InvalidCursorError):
decode_cursor_time(None)
def test_non_utc_aware_normalized(self):
# Same instant, different timezone — must encode to the same micros.
utc_ts = datetime(2024, 5, 20, 12, 0, 0, tzinfo=timezone.utc)
offset_ts = utc_ts.astimezone(timezone(timedelta(hours=-5)))
assert encode_cursor_from_time("created_at", utc_ts, "x") == encode_cursor_from_time(
"created_at", offset_ts, "x"
)
class TestIntCursor:
def test_decode_int(self):
assert decode_cursor_int(CursorPayload("size", "1024", "id-1", "desc")) == 1024
def test_decode_int_rejects_non_int(self):
with pytest.raises(InvalidCursorError):
decode_cursor_int(CursorPayload("size", "abc", "id-1", "desc"))
def test_decode_int_rejects_none(self):
with pytest.raises(InvalidCursorError):
decode_cursor_int(None)
class TestInvalidInputs:
def test_oversized_cursor(self):
oversized = "a" * (MAX_ENCODED_CURSOR_LENGTH + 1)
with pytest.raises(InvalidCursorError, match="maximum length"):
decode_cursor(oversized, ALLOWED)
def test_not_base64(self):
with pytest.raises(InvalidCursorError):
decode_cursor("not base64!!!", ALLOWED)
def test_not_json(self):
encoded = base64.urlsafe_b64encode(b"definitely not json").rstrip(b"=").decode("ascii")
with pytest.raises(InvalidCursorError):
decode_cursor(encoded, ALLOWED)
def test_empty_id(self):
# Encoder rejects empty id symmetrically with the decoder, so build the
# payload manually to exercise the decoder's missing-id branch.
raw = b'{"s":"created_at","v":"1","id":"","o":"desc"}'
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
with pytest.raises(InvalidCursorError, match="missing id"):
decode_cursor(encoded, ALLOWED)
def test_oversized_id(self):
# Encoder enforces the cap symmetrically; hand-build to exercise decode.
big_id = "a" * (MAX_CURSOR_ID_LENGTH + 1)
raw = ('{"s":"created_at","v":"1","id":"' + big_id + '","o":"desc"}').encode("ascii")
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
with pytest.raises(InvalidCursorError, match="id exceeds maximum length"):
decode_cursor(encoded, ALLOWED)
def test_oversized_value(self):
# Encoder enforces the cap symmetrically; hand-build to exercise decode.
big_v = "v" * (MAX_CURSOR_VALUE_LENGTH + 1)
raw = ('{"s":"created_at","v":"' + big_v + '","id":"id-1","o":"desc"}').encode("ascii")
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
with pytest.raises(InvalidCursorError, match="value exceeds maximum length"):
decode_cursor(encoded, ALLOWED)
def test_unsupported_sort_field(self):
encoded = encode_cursor("execution_time", "1", "id-1")
with pytest.raises(InvalidCursorError, match="unsupported sort field"):
decode_cursor(encoded, ALLOWED)
def test_no_allowed_fields_rejects_everything(self):
encoded = encode_cursor("created_at", "1", "id-1")
with pytest.raises(InvalidCursorError):
decode_cursor(encoded, ())
def test_non_dict_payload_rejected(self):
encoded = base64.urlsafe_b64encode(b'["array","not","dict"]').rstrip(b"=").decode("ascii")
with pytest.raises(InvalidCursorError, match="expected object"):
decode_cursor(encoded, ALLOWED)
class TestEncodeAtCapsFits:
def test_max_field_lengths_fit_wire_cap(self):
# Worst-case payload: value and id at their per-field caps, with a long
# sort field name. The encoded cursor must fit within MAX_ENCODED_CURSOR_LENGTH
# so the wire cap cannot reject a cursor the encoder mints at the per-field caps.
value = "v" * MAX_CURSOR_VALUE_LENGTH
id = "i" * MAX_CURSOR_ID_LENGTH
sort_field = "very_long_sort_field_name"
encoded = encode_cursor(sort_field, value, id)
assert len(encoded) <= MAX_ENCODED_CURSOR_LENGTH
payload = decode_cursor(encoded, (sort_field,))
assert payload.value == value
assert payload.id == id
class TestDatetimeOverflow:
"""Crafted cursors with extreme micros must map to InvalidCursorError,
not OverflowError/OSError leaking as 500.
"""
@pytest.mark.parametrize(
"micros_str",
[
"999999999999999999999", # 10^21 µs — past datetime.MAX_YEAR by ~14 orders
"-999999999999999999999", # symmetric negative — pre-epoch overflow
],
)
def test_out_of_range_micros_rejected(self, micros_str):
encoded = encode_cursor("created_at", micros_str, "asset-x")
payload = decode_cursor(encoded, ALLOWED)
with pytest.raises(InvalidCursorError):
decode_cursor_time(payload)
class TestEncoderDecoderSymmetry:
"""The encoder must reject inputs the decoder rejects, or the same server
will mint a cursor it then 400s on the next request.
"""
def test_long_name_within_cap_round_trips(self):
"""Assets allow names up to 512 chars (`String(512)`); the cursor
encoder must round-trip a value at that cap so a freshly minted
cursor never fails decode on the next request."""
long_name = "n" * MAX_CURSOR_VALUE_LENGTH
encoded = encode_cursor("name", long_name, "asset-x")
payload = decode_cursor(encoded, ALLOWED)
assert payload.value == long_name
def test_encoder_rejects_empty_id(self):
with pytest.raises(InvalidCursorError, match="id must be non-empty"):
encode_cursor("created_at", "1", "")
def test_encoder_rejects_oversized_id(self):
with pytest.raises(InvalidCursorError, match="id exceeds maximum length"):
encode_cursor("created_at", "1", "a" * (MAX_CURSOR_ID_LENGTH + 1))
def test_encoder_rejects_oversized_value(self):
with pytest.raises(InvalidCursorError, match="value exceeds maximum length"):
encode_cursor("name", "v" * (MAX_CURSOR_VALUE_LENGTH + 1), "id-1")
def test_encoder_rejects_multibyte_value_over_wire_cap(self):
"""A value that passes the char-count cap can still inflate past the
wire cap once UTF-8-encoded. Asset name made of 512 × multibyte
characters (e.g. 'é' = 2 bytes) must be rejected at encode time, not
minted into a cursor the next request will 400."""
with pytest.raises(InvalidCursorError, match="encoded cursor exceeds maximum length"):
encode_cursor("name", "é" * MAX_CURSOR_VALUE_LENGTH, "asset-multibyte")
def test_encoder_rejects_escape_heavy_value_over_wire_cap(self):
"""Same wire-cap concern via escape expansion: each `<` serializes to
the six-byte sequence `\\u003c`, so 512 of them blow past the encoded
cap even though the raw char count is within the per-field limit."""
with pytest.raises(InvalidCursorError, match="encoded cursor exceeds maximum length"):
encode_cursor("name", "<" * MAX_CURSOR_VALUE_LENGTH, "asset-escape")
class TestOrderBinding:
def test_order_baked_into_payload(self):
encoded = encode_cursor("created_at", "1", "id-1", order="asc")
payload = decode_cursor(encoded, ALLOWED)
assert payload.order == "asc"
def test_mismatched_order_rejected(self):
encoded = encode_cursor("created_at", "1", "id-1", order="desc")
with pytest.raises(InvalidCursorError, match="does not match request order"):
decode_cursor(encoded, ALLOWED, expected_order="asc")
def test_matching_order_accepted(self):
encoded = encode_cursor("created_at", "1", "id-1", order="desc")
payload = decode_cursor(encoded, ALLOWED, expected_order="desc")
assert payload.order == "desc"
def test_invalid_order_token_rejected_on_encode(self):
with pytest.raises(ValueError):
encode_cursor("created_at", "1", "id-1", order="sideways")
def test_invalid_order_token_rejected_on_decode(self):
# Hand-craft a payload with an illegal `o` value.
raw = b'{"s":"name","v":"x","id":"id-1","o":"sideways"}'
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
with pytest.raises(InvalidCursorError, match="unsupported order"):
decode_cursor(encoded, ALLOWED)
def test_cursor_without_order_rejected(self):
"""`o` is mandatory. A cursor minted without it is rejected as
malformed rather than silently walking the keyset in whatever
direction the request happens to ask for."""
raw = b'{"s":"name","v":"x","id":"id-1"}'
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
with pytest.raises(InvalidCursorError, match="missing or non-string o"):
decode_cursor(encoded, ALLOWED, expected_order="desc")
class TestHtmlSignificantCharEscaping:
"""An asset name containing `<`, `>`, `&`, U+2028, or U+2029 must encode
to the same escaped wire bytes as any compatible implementation of the
same payload format. Drift here breaks cross-runtime byte-identity for
those characters.
"""
@pytest.mark.parametrize(
"value, escaped_substring",
[
("foo<bar>.png", "\\u003c"), # `<` escaped
("foo<bar>.png", "\\u003e"), # `>` escaped
("foo&bar.png", "\\u0026"),
("foobar.png", "\\u2028"), # JS line separator
("foobar.png", "\\u2029"), # JS paragraph separator
],
)
def test_html_significant_chars_escaped(self, value, escaped_substring):
encoded = encode_cursor("name", value, "id-1")
decoded_bytes = base64.urlsafe_b64decode(encoded + "=" * (-len(encoded) % 4))
assert escaped_substring in decoded_bytes.decode("ascii"), (
f"Expected {escaped_substring!r} in serialized payload, got: {decoded_bytes!r}"
)
def test_value_round_trips_through_escape(self):
"""Encoding then decoding a value with `<>&` should yield the original
string — the escape only affects the wire form, not the decoded value."""
original = "foo<&>bar.png"
encoded = encode_cursor("name", original, "id-1")
payload = decode_cursor(encoded, ALLOWED)
assert payload.value == original
class TestByteIdentityFixtures:
"""Pin the wire format so it doesn't drift silently.
These fixtures assert exact byte equality of the encoded JSON payload —
a change in key order, escape choice, separator whitespace, or anything
else that shifts a byte fails the test loudly rather than diverging
silently from any external consumer of the same payload format.
"""
@pytest.mark.parametrize(
"sort_field, value, id, order, expected_payload",
[
(
"created_at",
"1716200000000000",
"a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7",
"desc",
'{"s":"created_at","v":"1716200000000000","id":"a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7","o":"desc"}',
),
(
"size",
"1024",
"asset-123",
"asc",
'{"s":"size","v":"1024","id":"asset-123","o":"asc"}',
),
(
"name",
"my-asset.png",
"asset-abc",
"desc",
'{"s":"name","v":"my-asset.png","id":"asset-abc","o":"desc"}',
),
(
"name",
"foo<bar>&baz.png",
"asset-html",
"desc",
# `<`, `>`, `&` escape to <, >, & in the value.
'{"s":"name","v":"foo\\u003cbar\\u003e\\u0026baz.png","id":"asset-html","o":"desc"}',
),
],
)
def test_encoded_payload_shape_pinned(self, sort_field, value, id, order, expected_payload):
encoded = encode_cursor(sort_field, value, id, order=order)
decoded_bytes = base64.urlsafe_b64decode(encoded + "=" * (-len(encoded) % 4))
assert decoded_bytes.decode("utf-8") == expected_payload, (
f"wire format drifted for sort={sort_field!r}, value={value!r}:\n"
f" expected: {expected_payload!r}\n"
f" actual: {decoded_bytes.decode('utf-8')!r}"
)

View File

@ -0,0 +1,86 @@
"""Tests for the image_dimensions service."""
from __future__ import annotations
from pathlib import Path
import pytest
from PIL import Image
from app.assets.services.image_dimensions import extract_image_dimensions
def _make_png(path: Path, size: tuple[int, int]) -> Path:
img = Image.new("RGB", size, color=(123, 45, 67))
img.save(path, format="PNG")
return path
def _make_jpeg(path: Path, size: tuple[int, int]) -> Path:
img = Image.new("RGB", size, color=(10, 20, 30))
img.save(path, format="JPEG", quality=80)
return path
class TestExtractImageDimensions:
def test_extracts_png_dimensions(self, tmp_path: Path):
f = _make_png(tmp_path / "rect.png", (320, 240))
result = extract_image_dimensions(str(f), mime_type="image/png")
assert result == {"kind": "image", "width": 320, "height": 240}
def test_extracts_jpeg_dimensions(self, tmp_path: Path):
f = _make_jpeg(tmp_path / "shot.jpg", (1920, 1080))
result = extract_image_dimensions(str(f), mime_type="image/jpeg")
assert result == {"kind": "image", "width": 1920, "height": 1080}
def test_works_when_mime_type_is_none(self, tmp_path: Path):
f = _make_png(tmp_path / "no_mime.png", (50, 100))
result = extract_image_dimensions(str(f), mime_type=None)
assert result == {"kind": "image", "width": 50, "height": 100}
def test_skips_non_image_mime_without_touching_file(self, tmp_path: Path):
# Path doesn't need to exist — non-image MIME short-circuits.
result = extract_image_dimensions(
str(tmp_path / "model.safetensors"),
mime_type="application/octet-stream",
)
assert result is None
@pytest.mark.parametrize(
"mime",
["application/json", "text/plain", "video/mp4", "audio/mpeg"],
)
def test_skips_all_non_image_mime_types(self, tmp_path: Path, mime: str):
f = tmp_path / "file.bin"
f.write_bytes(b"\x00\x01\x02")
assert extract_image_dimensions(str(f), mime_type=mime) is None
def test_returns_none_for_missing_file(self, tmp_path: Path):
result = extract_image_dimensions(
str(tmp_path / "does_not_exist.png"), mime_type="image/png"
)
assert result is None
def test_returns_none_for_corrupt_image(self, tmp_path: Path):
f = tmp_path / "corrupt.png"
f.write_bytes(b"not actually a png file")
result = extract_image_dimensions(str(f), mime_type="image/png")
assert result is None
def test_returns_none_for_empty_file(self, tmp_path: Path):
f = tmp_path / "empty.png"
f.write_bytes(b"")
result = extract_image_dimensions(str(f), mime_type="image/png")
assert result is None

View File

@ -4,10 +4,12 @@ from pathlib import Path
from unittest.mock import patch
import pytest
from PIL import Image
from sqlalchemy.orm import Session as SASession, Session
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, Tag
from app.assets.database.queries import get_reference_tags
from app.assets.helpers import get_utc_now
from app.assets.services.ingest import (
_ingest_file_from_path,
_register_existing_asset,
@ -15,6 +17,11 @@ from app.assets.services.ingest import (
)
def _make_png(path: Path, size: tuple[int, int]) -> Path:
Image.new("RGB", size, color=(80, 120, 200)).save(path, format="PNG")
return path
class TestIngestFileFromPath:
def test_creates_asset_and_reference(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "test_file.bin"
@ -279,4 +286,203 @@ class TestIngestExistingFileTagFK:
ref_tags = sess.query(AssetReferenceTag).all()
ref_tag_names = {rt.tag_name for rt in ref_tags}
assert "output" in ref_tag_names
assert "my-job" in ref_tag_names
class TestIngestImageDimensions:
"""system_metadata should carry {kind, width, height} for image assets."""
def test_image_asset_emits_dimensions(
self, mock_create_session, temp_dir: Path, session: Session
):
f = _make_png(temp_dir / "shot.png", (640, 480))
result = _ingest_file_from_path(
abs_path=str(f),
asset_hash="blake3:img1",
size_bytes=f.stat().st_size,
mtime_ns=1234567890000000000,
mime_type="image/png",
)
ref = session.query(AssetReference).filter_by(id=result.reference_id).first()
assert ref.system_metadata == {
"kind": "image",
"width": 640,
"height": 480,
}
def test_non_image_asset_leaves_system_metadata_empty(
self, mock_create_session, temp_dir: Path, session: Session
):
f = temp_dir / "model.safetensors"
f.write_bytes(b"not an image")
result = _ingest_file_from_path(
abs_path=str(f),
asset_hash="blake3:safetensors1",
size_bytes=f.stat().st_size,
mtime_ns=1234567890000000000,
mime_type="application/octet-stream",
)
ref = session.query(AssetReference).filter_by(id=result.reference_id).first()
assert ref.system_metadata in (None, {})
def test_preserves_existing_system_metadata_keys(
self, mock_create_session, temp_dir: Path, session: Session
):
f = _make_png(temp_dir / "annotated.png", (100, 200))
# First pass populates a sentinel system_metadata key (simulating prior
# enricher write).
result = _ingest_file_from_path(
abs_path=str(f),
asset_hash="blake3:img-merge",
size_bytes=f.stat().st_size,
mtime_ns=1234567890000000000,
mime_type="image/png",
)
ref = session.query(AssetReference).filter_by(id=result.reference_id).first()
ref.system_metadata = {**(ref.system_metadata or {}), "source_url": "https://example/x.png"}
session.commit()
# Second pass with the same path triggers the merge code path again.
_ingest_file_from_path(
abs_path=str(f),
asset_hash="blake3:img-merge",
size_bytes=f.stat().st_size,
mtime_ns=1234567890000000001,
mime_type="image/png",
)
session.refresh(ref)
assert ref.system_metadata["kind"] == "image"
assert ref.system_metadata["width"] == 100
assert ref.system_metadata["height"] == 200
assert ref.system_metadata["source_url"] == "https://example/x.png"
class TestRegisterExistingAssetBackfill:
"""The from-hash path back-fills dimensions from a sibling reference."""
def _add_reference(
self,
session: Session,
asset: Asset,
name: str,
system_metadata: dict | None = None,
) -> AssetReference:
now = get_utc_now()
ref = AssetReference(
asset_id=asset.id,
name=name,
owner_id="",
created_at=now,
updated_at=now,
last_access_time=now,
system_metadata=system_metadata or {},
)
session.add(ref)
session.flush()
return ref
def test_backfills_dimensions_from_sibling_image_reference(
self, mock_create_session, session: Session
):
asset = Asset(hash="blake3:shared", size_bytes=2048, mime_type="image/png")
session.add(asset)
session.flush()
self._add_reference(
session,
asset,
name="original.png",
system_metadata={"kind": "image", "width": 800, "height": 600},
)
session.commit()
result = _register_existing_asset(
asset_hash="blake3:shared",
name="from_hash.png",
owner_id="user-x",
)
ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
assert ref.system_metadata.get("kind") == "image"
assert ref.system_metadata.get("width") == 800
assert ref.system_metadata.get("height") == 600
def test_no_backfill_when_sibling_has_no_image_metadata(
self, mock_create_session, session: Session
):
asset = Asset(hash="blake3:nodims", size_bytes=2048, mime_type="image/png")
session.add(asset)
session.flush()
self._add_reference(
session,
asset,
name="original.png",
system_metadata={"base_model": "flux"}, # no kind=image
)
session.commit()
result = _register_existing_asset(
asset_hash="blake3:nodims",
name="from_hash.png",
owner_id="user-x",
)
ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
meta = ref.system_metadata or {}
assert "kind" not in meta
assert "width" not in meta
assert "height" not in meta
def test_no_backfill_when_no_sibling_exists(
self, mock_create_session, session: Session
):
asset = Asset(hash="blake3:lonely", size_bytes=1024, mime_type="image/png")
session.add(asset)
session.commit()
result = _register_existing_asset(
asset_hash="blake3:lonely",
name="solo.png",
owner_id="user-x",
)
ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
assert ref.system_metadata in (None, {})
def test_backfill_preserves_caller_supplied_keys(
self, mock_create_session, session: Session
):
asset = Asset(hash="blake3:preserve", size_bytes=2048, mime_type="image/png")
session.add(asset)
session.flush()
self._add_reference(
session,
asset,
name="original.png",
system_metadata={"kind": "image", "width": 1024, "height": 768},
)
session.commit()
# Simulate a from-hash path where the new reference already carries
# some system_metadata (e.g. a download-provenance source_url written
# by an earlier step). The back-fill must merge dim keys without
# clobbering existing keys.
result = _register_existing_asset(
asset_hash="blake3:preserve",
name="from_hash.png",
owner_id="user-x",
)
ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
# Seed a sentinel key and re-run back-fill via a second register call
# to exercise the merge path with pre-existing data.
ref.system_metadata = {**(ref.system_metadata or {}), "source_url": "https://example/p"}
session.commit()
assert ref.system_metadata.get("source_url") == "https://example/p"
assert ref.system_metadata.get("kind") == "image"
assert ref.system_metadata.get("width") == 1024
assert ref.system_metadata.get("height") == 768

View File

@ -6,7 +6,13 @@ from unittest.mock import patch
import pytest
from app.assets.services.path_utils import get_asset_category_and_relative_path
from app.assets.services.path_utils import (
compute_display_name,
compute_file_path,
get_asset_category_and_relative_path,
get_name_and_tags_from_asset_path,
resolve_destination_from_tags,
)
@pytest.fixture
@ -38,6 +44,50 @@ def fake_dirs():
}
@pytest.fixture
def fake_dirs_multi_bucket():
"""Variant fixture with multiple model buckets (checkpoints + diffusers + loras)."""
with tempfile.TemporaryDirectory() as root:
root_path = Path(root)
input_dir = root_path / "input"
output_dir = root_path / "output"
temp_dir = root_path / "temp"
checkpoints_dir = root_path / "models" / "checkpoints"
diffusers_dir = root_path / "models" / "diffusers"
loras_dir = root_path / "models" / "loras"
for d in (
input_dir,
output_dir,
temp_dir,
checkpoints_dir,
diffusers_dir,
loras_dir,
):
d.mkdir(parents=True)
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
mock_fp.get_input_directory.return_value = str(input_dir)
mock_fp.get_output_directory.return_value = str(output_dir)
mock_fp.get_temp_directory.return_value = str(temp_dir)
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[
("checkpoints", [str(checkpoints_dir)]),
("diffusers", [str(diffusers_dir)]),
("loras", [str(loras_dir)]),
],
):
yield {
"input": input_dir,
"output": output_dir,
"temp": temp_dir,
"checkpoints": checkpoints_dir,
"diffusers": diffusers_dir,
"loras": loras_dir,
}
class TestGetAssetCategoryAndRelativePath:
def test_input_file(self, fake_dirs):
f = fake_dirs["input"] / "photo.png"
@ -79,3 +129,185 @@ class TestGetAssetCategoryAndRelativePath:
def test_unknown_path_raises(self, fake_dirs):
with pytest.raises(ValueError, match="not within"):
get_asset_category_and_relative_path("/some/random/path.png")
class TestGetNameAndTagsFromAssetPath:
"""tags collapse the parent subpath into a single slash-joined tag.
Consumers should be able to read ``tags[1]`` as a stable category
identifier regardless of how deep the file lives in the bucket.
"""
def test_flat_input(self, fake_dirs_multi_bucket):
f = fake_dirs_multi_bucket["input"] / "photo.png"
f.touch()
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "photo.png"
assert tags == ["input"]
def test_flat_output(self, fake_dirs_multi_bucket):
f = fake_dirs_multi_bucket["output"] / "result_00001.png"
f.touch()
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "result_00001.png"
assert tags == ["output"]
def test_flat_models_checkpoint(self, fake_dirs_multi_bucket):
f = fake_dirs_multi_bucket["checkpoints"] / "flux.safetensors"
f.touch()
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "flux.safetensors"
assert tags == ["models", "checkpoints"]
def test_diffusers_nested_subpath_slash_joined(self, fake_dirs_multi_bucket):
"""Diffusers components live in nested directories — the full subpath
must collapse into one tag so consumers can look up the model category
via tags[1] regardless of nesting depth.
The subpath is lowercased to match the canonicalization
:func:`ensure_tags_exist` applies on the write side; without that,
the asset_reference_tags.tag_name FK to tags.name would fail for
any path containing uppercase letters.
"""
nested = (
fake_dirs_multi_bucket["diffusers"]
/ "Kolors"
/ "text_encoder"
)
nested.mkdir(parents=True)
f = nested / "model.safetensors"
f.touch()
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "model.safetensors"
assert tags == ["models", "diffusers/kolors/text_encoder"]
def test_deep_lora_user_subpath_slash_joined(self, fake_dirs_multi_bucket):
"""User-created subdirectories under a model bucket also collapse to a
single tag rather than one tag per directory."""
nested = (
fake_dirs_multi_bucket["loras"]
/ "my"
/ "custom"
/ "path"
)
nested.mkdir(parents=True)
f = nested / "v0001.safetensors"
f.touch()
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "v0001.safetensors"
assert tags == ["models", "loras/my/custom/path"]
class TestResolveDestinationFromTags:
"""resolve_destination_from_tags must accept both the legacy
one-tag-per-directory shape and the new slash-joined shape so that an
upload using the tags it just read back from /api/assets round-trips
to the right on-disk destination.
"""
@pytest.fixture
def resolve_dirs(self):
with tempfile.TemporaryDirectory() as root:
root_path = Path(root)
input_dir = root_path / "input"
output_dir = root_path / "output"
checkpoints_dir = root_path / "models" / "checkpoints"
diffusers_dir = root_path / "models" / "diffusers"
loras_dir = root_path / "models" / "loras"
for d in (input_dir, output_dir, checkpoints_dir, diffusers_dir, loras_dir):
d.mkdir(parents=True)
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
mock_fp.get_input_directory.return_value = str(input_dir)
mock_fp.get_output_directory.return_value = str(output_dir)
mock_fp.folder_names_and_paths = {
"checkpoints": ([str(checkpoints_dir)], None),
"diffusers": ([str(diffusers_dir)], None),
"loras": ([str(loras_dir)], None),
}
yield {
"input": input_dir,
"output": output_dir,
"checkpoints": checkpoints_dir,
"diffusers": diffusers_dir,
"loras": loras_dir,
}
def test_models_flat_category(self, resolve_dirs):
base, subdirs = resolve_destination_from_tags(["models", "checkpoints"])
assert base == str(resolve_dirs["checkpoints"])
assert subdirs == []
def test_models_slash_joined_new_shape(self, resolve_dirs):
# The shape get_name_and_tags_from_asset_path now emits.
base, subdirs = resolve_destination_from_tags(
["models", "diffusers/kolors/text_encoder"]
)
assert base == str(resolve_dirs["diffusers"])
assert subdirs == ["kolors", "text_encoder"]
def test_models_legacy_one_tag_per_dir(self, resolve_dirs):
# The legacy shape must still resolve identically.
base, subdirs = resolve_destination_from_tags(
["models", "diffusers", "kolors", "text_encoder"]
)
assert base == str(resolve_dirs["diffusers"])
assert subdirs == ["kolors", "text_encoder"]
def test_models_loras_slash_joined(self, resolve_dirs):
base, subdirs = resolve_destination_from_tags(
["models", "loras/my/custom/path"]
)
assert base == str(resolve_dirs["loras"])
assert subdirs == ["my", "custom", "path"]
def test_input_no_subdir(self, resolve_dirs):
base, subdirs = resolve_destination_from_tags(["input"])
assert base == str(resolve_dirs["input"])
assert subdirs == []
def test_input_slash_joined_subdir(self, resolve_dirs):
base, subdirs = resolve_destination_from_tags(["input", "portraits/2026"])
assert base == str(resolve_dirs["input"])
assert subdirs == ["portraits", "2026"]
def test_output_slash_joined_subdir(self, resolve_dirs):
base, subdirs = resolve_destination_from_tags(["output", "runs/abc"])
assert base == str(resolve_dirs["output"])
assert subdirs == ["runs", "abc"]
def test_unknown_category_rejected(self, resolve_dirs):
with pytest.raises(ValueError, match="unknown model category"):
resolve_destination_from_tags(["models", "not_a_real_category"])
def test_unknown_category_via_slash_joined(self, resolve_dirs):
# First segment of a slash-joined tag must still match a registered category.
with pytest.raises(ValueError, match="unknown model category 'bogus'"):
resolve_destination_from_tags(["models", "bogus/sub/path"])
def test_traversal_in_subdir_rejected(self, resolve_dirs):
with pytest.raises(ValueError, match="invalid path component"):
resolve_destination_from_tags(["models", "checkpoints/..", "evil"])
class TestResponsePaths:
def test_input_file_path_and_display_name_include_subfolder(self, fake_dirs):
sub = fake_dirs["input"] / "some" / "folder"
sub.mkdir(parents=True)
f = sub / "image.png"
f.touch()
assert compute_file_path(str(f)) == "input/some/folder/image.png"
assert compute_display_name(str(f)) == "some/folder/image.png"
def test_model_file_path_includes_bucket_display_name_drops_it(self, fake_dirs):
sub = fake_dirs["models"] / "flux"
sub.mkdir()
f = sub / "model.safetensors"
f.touch()
assert compute_file_path(str(f)) == "models/checkpoints/flux/model.safetensors"
assert compute_display_name(str(f)) == "flux/model.safetensors"
def test_unknown_path_returns_none(self, fake_dirs):
assert compute_file_path("/some/random/path.png") is None
assert compute_display_name("/some/random/path.png") is None

View File

@ -32,7 +32,7 @@ def test_seed_asset_removed_when_file_is_deleted(
# Verify it is visible via API and carries no hash (seed)
r1 = http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
params={"include_tags": "unit-tests/syncseed", "name_contains": name},
timeout=120,
)
body1 = r1.json()
@ -40,7 +40,9 @@ def test_seed_asset_removed_when_file_is_deleted(
# there should be exactly one with that name
matches = [a for a in body1.get("assets", []) if a.get("name") == name]
assert matches
assert matches[0].get("asset_hash") is None
# Seed assets have no hash; exclude_none drops both keys from the response
assert "asset_hash" not in matches[0]
assert "hash" not in matches[0]
asset_info_id = matches[0]["id"]
# Remove the underlying file and sync again
@ -52,7 +54,7 @@ def test_seed_asset_removed_when_file_is_deleted(
# It should disappear (AssetInfo and seed Asset gone)
r2 = http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
params={"include_tags": "unit-tests/syncseed", "name_contains": name},
timeout=120,
)
body2 = r2.json()
@ -332,7 +334,7 @@ def test_fastpass_removes_stale_state_row_no_missing(
rl = http.get(
api_base + "/api/assets",
params={"include_tags": f"unit-tests,{scope}"},
params={"include_tags": f"unit-tests/{scope}"},
timeout=120,
)
bl = rl.json()

View File

@ -21,6 +21,8 @@ def test_create_from_hash_success(
b1 = r1.json()
assert r1.status_code == 201, b1
assert b1["asset_hash"] == h
assert b1["hash"] == h
assert b1["hash"] == b1["asset_hash"]
assert b1["created_new"] is False
aid = b1["id"]
@ -39,6 +41,7 @@ def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asse
detail = rg.json()
assert rg.status_code == 200, detail
assert detail["id"] == aid
assert detail["hash"] == detail["asset_hash"]
assert "user_metadata" in detail
assert "filename" in detail["user_metadata"]
@ -97,6 +100,7 @@ def test_delete_upon_reference_count(
copy = r2.json()
assert r2.status_code == 201, copy
assert copy["asset_hash"] == src_hash
assert copy["hash"] == src_hash
assert copy["created_new"] is False
# Soft-delete original reference (default) -> asset identity must remain
@ -139,6 +143,7 @@ def test_update_asset_fields(http: requests.Session, api_base: str, seeded_asset
body = ru.json()
assert ru.status_code == 200, body
assert body["name"] == payload["name"]
assert body["hash"] == body["asset_hash"]
assert body["tags"] == original_tags # tags unchanged
assert body["user_metadata"]["purpose"] == "updated"
# filename should still be present and normalized by server
@ -280,16 +285,24 @@ def test_metadata_filename_is_set_for_seed_asset_without_hash(
trigger_sync_seed_assets(http, api_base)
# Scanner emits tags as ``[root, "<dir1>/<dir2>/..."]`` — the second tag
# is the slash-joined parent subpath. For ``<root>/unit-tests/<scope>/a/b/<name>``
# the second tag is ``"unit-tests/<scope>/a/b"``.
r1 = http.get(
api_base + "/api/assets",
params={"include_tags": f"unit-tests,{scope}", "name_contains": name},
params={
"include_tags": f"unit-tests/{scope}/a/b",
"name_contains": name,
},
timeout=120,
)
body = r1.json()
assert r1.status_code == 200, body
matches = [a for a in body.get("assets", []) if a.get("name") == name]
assert matches, "Seed asset should be visible after sync"
assert matches[0].get("asset_hash") is None # still a seed
# Seed assets have no hash; exclude_none drops both keys from the response
assert "asset_hash" not in matches[0]
assert "hash" not in matches[0]
aid = matches[0]["id"]
r2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)

View File

@ -0,0 +1,69 @@
"""Unit tests for app.assets.helpers."""
from app.assets.helpers import expand_bucket_prefixes
class TestExpandBucketPrefixes:
def test_flat_category_unchanged(self):
# `checkpoints` is already a standalone token, no expansion needed.
assert expand_bucket_prefixes(["models", "checkpoints"]) == [
"models",
"checkpoints",
]
def test_nested_category_inserts_bucket(self):
# Path-derived shape for `models/checkpoints/flux/foo.safetensors` —
# the standalone bucket has to be present so the FE set-membership
# filter (`include_tags=models,checkpoints`) matches the asset.
assert expand_bucket_prefixes(["models", "checkpoints/flux"]) == [
"models",
"checkpoints/flux",
"checkpoints",
]
def test_deeply_nested_only_first_segment_expands(self):
# Only the FIRST slash segment ever gets emitted as a standalone —
# intermediate path segments don't have routing significance.
assert expand_bucket_prefixes(
["models", "diffusers/kolors/text_encoder"]
) == ["models", "diffusers/kolors/text_encoder", "diffusers"]
def test_unknown_prefix_does_not_expand(self):
# Free-form user labels with slashes whose first segment is not a
# registered bucket pass through opaquely.
assert expand_bucket_prefixes(["models", "my-org/team-a"]) == [
"models",
"my-org/team-a",
]
def test_idempotent(self):
# Re-applying the helper is a no-op once the bucket is in the set.
expanded = expand_bucket_prefixes(["models", "checkpoints/flux"])
assert expand_bucket_prefixes(expanded) == expanded
def test_does_not_duplicate_existing_bucket(self):
# If the caller already supplied the standalone bucket, don't add a
# second copy.
assert expand_bucket_prefixes(
["models", "checkpoints/flux", "checkpoints"]
) == ["models", "checkpoints/flux", "checkpoints"]
def test_preserves_caller_order(self):
# User tags after path tags must stay after; the inserted bucket
# token slots in immediately after its slash-joined parent so the
# microsecond stagger lands it at path-tier before user-tier.
assert expand_bucket_prefixes(
["models", "loras/style", "favorite", "v2"]
) == ["models", "loras/style", "loras", "favorite", "v2"]
def test_empty_input(self):
assert expand_bucket_prefixes([]) == []
def test_input_root_with_subpath_no_expansion(self):
# `portraits` isn't a registered model category, so the input
# subpath stays opaque (FE filter doesn't have a checkpoint-loader
# analogue for input subfolders).
assert expand_bucket_prefixes(["input", "portraits/2026"]) == [
"input",
"portraits/2026",
]

View File

@ -0,0 +1,349 @@
"""Integration tests for cursor-based pagination on GET /api/assets.
These tests exercise the handler/service/query path end-to-end;
cursor-encoding-level tests live in
tests-unit/assets_test/services/test_cursor.py.
"""
import pytest
import requests
def _seed(asset_factory, make_asset_bytes, count: int, tag: str) -> list[str]:
names = [f"cursor_{i:02d}.safetensors" for i in range(count)]
for n in names:
asset_factory(
n,
["models", "checkpoints", "unit-tests", tag],
{},
make_asset_bytes(n, size=2048),
)
return sorted(names)
def test_cursor_pages_all_items_in_order(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
names = _seed(asset_factory, make_asset_bytes, count=5, tag="cursor-walk")
params = {
"include_tags": "unit-tests,cursor-walk",
"sort": "name",
"order": "asc",
"limit": "2",
}
seen: list[str] = []
after: str | None = None
pages = 0
while True:
page_params = dict(params)
if after is not None:
page_params["after"] = after
r = http.get(api_base + "/api/assets", params=page_params, timeout=120)
assert r.status_code == 200, r.text
body = r.json()
seen.extend(a["name"] for a in body["assets"])
pages += 1
after = body.get("next_cursor")
if after is None:
break
assert body["has_more"] is True
assert pages < 10, "guard against runaway cursor loop"
assert seen == names, f"expected {names}, got {seen}"
# Last page should have has_more False
assert body["has_more"] is False
assert "next_cursor" not in body
def test_cursor_invalid_returns_400(http: requests.Session, api_base: str):
r = http.get(
api_base + "/api/assets",
params={"after": "not-a-real-cursor", "sort": "created_at"},
timeout=120,
)
assert r.status_code == 400, r.text
body = r.json()
assert body["error"]["code"] == "INVALID_CURSOR"
def test_cursor_sort_mismatch_returns_400(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
_seed(asset_factory, make_asset_bytes, count=2, tag="cursor-mismatch")
# Take a real cursor minted for sort=name.
r = http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,cursor-mismatch",
"sort": "name",
"order": "asc",
"limit": "1",
},
timeout=120,
)
assert r.status_code == 200
cursor = r.json()["next_cursor"]
assert cursor is not None
# Replay against sort=created_at — should fail with INVALID_CURSOR.
r2 = http.get(
api_base + "/api/assets",
params={"after": cursor, "sort": "created_at"},
timeout=120,
)
assert r2.status_code == 400, r2.text
assert r2.json()["error"]["code"] == "INVALID_CURSOR"
def test_cursor_wins_over_offset(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
names = _seed(asset_factory, make_asset_bytes, count=4, tag="cursor-vs-offset")
# Take a cursor that points past the first item.
r = http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,cursor-vs-offset",
"sort": "name",
"order": "asc",
"limit": "1",
},
timeout=120,
)
assert r.status_code == 200, r.text
cursor = r.json()["next_cursor"]
assert cursor is not None
# Pass both 'after' and a large offset. Cursor must win; offset is ignored.
r2 = http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,cursor-vs-offset",
"sort": "name",
"order": "asc",
"limit": "1",
"after": cursor,
"offset": "999",
},
timeout=120,
)
assert r2.status_code == 200
body = r2.json()
# Should land on the second name in sorted order — not skip ahead by 999.
assert [a["name"] for a in body["assets"]] == [names[1]]
def test_next_cursor_absent_when_no_more_results(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
_seed(asset_factory, make_asset_bytes, count=2, tag="cursor-exhaust")
r = http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,cursor-exhaust",
"sort": "name",
"order": "asc",
"limit": "50",
},
timeout=120,
)
assert r.status_code == 200, r.text
body = r.json()
assert body["has_more"] is False
assert "next_cursor" not in body
def test_cursor_pagination_first_page_mints_cursor(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
"""First-page request (no `after`) must still return `next_cursor` when
more rows exist, or pagination is unreachable from a cold start.
"""
_seed(asset_factory, make_asset_bytes, count=3, tag="cursor-first-page")
r = http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,cursor-first-page", "sort": "name", "order": "asc", "limit": "2"},
timeout=120,
)
assert r.status_code == 200, r.text
body = r.json()
assert body["has_more"] is True
assert body.get("next_cursor"), "first page must mint a cursor when more rows exist"
def test_cursor_no_spurious_cursor_when_page_size_equals_remainder(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
"""When `total` is an exact multiple of `limit`, the final page must
NOT carry a next_cursor — there is nothing past it.
"""
_seed(asset_factory, make_asset_bytes, count=4, tag="cursor-exact-multiple")
# Page 1
r = http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,cursor-exact-multiple", "sort": "name", "order": "asc", "limit": "2"},
timeout=120,
)
assert r.status_code == 200, r.text
cursor = r.json()["next_cursor"]
assert cursor is not None
# Page 2 — should exhaust the set with no cursor for a phantom page 3
r2 = http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,cursor-exact-multiple", "sort": "name", "order": "asc", "limit": "2", "after": cursor},
timeout=120,
)
assert r2.status_code == 200, r2.text
body = r2.json()
assert len(body["assets"]) == 2
assert body["has_more"] is False
assert "next_cursor" not in body
@pytest.mark.parametrize("sort_field", ["created_at", "updated_at", "size"])
def test_cursor_walks_for_non_name_sorts(sort_field, http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
"""Cursor pagination must work for every sort field the contract claims.
Without this, the `created_at` / `updated_at` (time-encoded micros) and
`size` (int-encoded) cursor paths go entirely unexercised end-to-end.
"""
# Sizes increase strictly by index, so `size desc` has a deterministic
# expected order. Time-based sorts (created_at / updated_at) can tie when
# rows are inserted faster than the DB's timestamp resolution; for those
# we check coverage and no-duplicates and let the keyset tiebreaker do
# the rest, instead of sleeping between inserts and asserting an order
# that depends on clock granularity.
names = []
for i in range(4):
n = f"cursor_{sort_field}_{i:02d}.safetensors"
asset_factory(n, ["models", "checkpoints", "unit-tests", f"cursor-{sort_field}"], {}, make_asset_bytes(n, size=2048 + i))
names.append(n)
params = {
"include_tags": f"unit-tests,cursor-{sort_field}",
"sort": sort_field,
"order": "desc",
"limit": "2",
}
seen: list[str] = []
after: str | None = None
pages = 0
while True:
page_params = dict(params)
if after is not None:
page_params["after"] = after
r = http.get(api_base + "/api/assets", params=page_params, timeout=120)
assert r.status_code == 200, r.text
body = r.json()
seen.extend(a["name"] for a in body["assets"])
after = body.get("next_cursor")
pages += 1
if after is None:
break
assert pages < 10, "guard against runaway cursor loop"
# No duplicates: a faulty keyset boundary that returns the same row across
# two pages must fail this check.
assert len(seen) == len(set(seen)), (
f"cursor walk repeated rows for sort={sort_field}: {seen}"
)
# Full coverage: every seeded asset reached exactly once.
assert set(seen) == set(names), (
f"missing items for sort={sort_field}: expected {set(names)}, got {set(seen)}"
)
# Strict order check for the only field with a clock-independent ordering.
if sort_field == "size":
assert seen == list(reversed(names)), (
f"size cursor walked out of order: got {seen}, expected {list(reversed(names))}"
)
def test_cursor_order_mismatch_returns_400(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
"""A cursor minted under desc order replayed against asc must 400, not
silently walk the wrong direction."""
_seed(asset_factory, make_asset_bytes, count=3, tag="cursor-order-flip")
r = http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,cursor-order-flip",
"sort": "name",
"order": "desc",
"limit": "1",
},
timeout=120,
)
assert r.status_code == 200, r.text
cursor = r.json()["next_cursor"]
assert cursor is not None
# Replay with order flipped to asc — server must reject the cursor.
r2 = http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,cursor-order-flip",
"sort": "name",
"order": "asc",
"limit": "1",
"after": cursor,
},
timeout=120,
)
assert r2.status_code == 400, r2.text
assert r2.json()["error"]["code"] == "INVALID_CURSOR"
def test_cursor_invalid_cursor_at_microsecond_boundary(http: requests.Session, api_base: str):
"""A cursor carrying an out-of-range microsecond timestamp must map to
400 INVALID_CURSOR, not 500."""
import base64
import json
# 10^18 microseconds ≈ year 33658, well past datetime.MAX_YEAR.
# `o` and `order=` must be set; otherwise decode fails earlier on the
# missing-order branch and the µs-overflow path is never exercised.
payload = {"s": "created_at", "o": "desc", "v": "999999999999999999999", "id": "asset-x"}
raw = json.dumps(payload, separators=(",", ":")).encode("utf-8")
cursor = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
r = http.get(
api_base + "/api/assets",
params={"after": cursor, "sort": "created_at", "order": "desc"},
timeout=120,
)
assert r.status_code == 400, r.text
assert r.json()["error"]["code"] == "INVALID_CURSOR"
def test_cursor_pagination_stable_after_delete(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
names = _seed(asset_factory, make_asset_bytes, count=4, tag="cursor-delete")
# Page 1.
r = http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,cursor-delete",
"sort": "name",
"order": "asc",
"limit": "2",
},
timeout=120,
)
assert r.status_code == 200
body = r.json()
page1_names = [a["name"] for a in body["assets"]]
cursor = body["next_cursor"]
assert cursor is not None
assert page1_names == names[:2]
# Delete an item from page 1 (already returned) — cursor should still
# locate the next page from where it was minted, not re-index.
target_id = body["assets"][0]["id"]
d = http.delete(api_base + f"/api/assets/{target_id}", timeout=120)
assert d.status_code in (200, 204), d.text
# Page 2 via cursor.
r2 = http.get(
api_base + "/api/assets",
params={
"include_tags": "unit-tests,cursor-delete",
"sort": "name",
"order": "asc",
"limit": "2",
"after": cursor,
},
timeout=120,
)
assert r2.status_code == 200, r2.text
body2 = r2.json()
assert [a["name"] for a in body2["assets"]] == names[2:]

View File

@ -3,6 +3,7 @@ import uuid
import pytest
import requests
from helpers import assert_hash_fields_consistent
def test_list_assets_paging_and_sort(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
@ -26,6 +27,10 @@ def test_list_assets_paging_and_sort(http: requests.Session, api_base: str, asse
got1 = [a["name"] for a in b1["assets"]]
assert got1 == sorted(names)[:2]
assert b1["has_more"] is True
# Populated assets in list responses must carry both `hash` and `asset_hash` consistently
for asset in b1["assets"]:
assert_hash_fields_consistent(asset)
assert "hash" in asset, "populated asset must emit hash on list endpoint"
r2 = http.get(
api_base + "/api/assets",

View File

@ -29,7 +29,10 @@ def create_seed_file(comfy_tmp_base_dir: Path):
def find_asset(http: requests.Session, api_base: str):
"""Query API for assets matching scope and optional name."""
def _find(scope: str, name: str | None = None) -> list[dict]:
params = {"include_tags": f"unit-tests,{scope}"}
# Scanner now emits tags as ``[root, "<dir1>/<dir2>/..."]`` rather than
# one tag per directory. For files at ``<root>/unit-tests/<scope>/...``
# the second tag is exactly ``"unit-tests/<scope>"``.
params = {"include_tags": f"unit-tests/{scope}"}
if name:
params["name_contains"] = name
r = http.get(f"{api_base}/api/assets", params=params, timeout=120)
@ -138,4 +141,7 @@ def test_special_chars_in_path_escaped_correctly(
trigger_sync_seed_assets(http, api_base)
trigger_sync_seed_assets(http, api_base)
assert find_asset(scope.split("/")[0], fp.name), "Asset with special chars should survive"
# Scanner emits the full parent subpath as a single slash-joined tag, so
# the lookup tag is ``unit-tests/<scope>`` even when <scope> itself
# contains a slash (parent + special-char dirname).
assert find_asset(scope, fp.name), "Asset with special chars should survive"

View File

@ -5,6 +5,21 @@ from concurrent.futures import ThreadPoolExecutor
import requests
import pytest
from app.assets.api.schemas_out import Asset, AssetCreated
from helpers import get_asset_filename
def test_asset_created_inherits_hash_field():
"""AssetCreated must inherit `hash` from Asset so POST /api/assets responses emit it.
Schema-level guard: integration tests cover the wire shape, but inheritance
drift (e.g. AssetCreated ever being redefined to no longer extend Asset)
would silently drop `hash` from a major endpoint without this check.
"""
assert "hash" in Asset.model_fields
assert "hash" in AssetCreated.model_fields
assert AssetCreated.model_fields["hash"].annotation == Asset.model_fields["hash"].annotation
def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, make_asset_bytes):
name = "dup_a.safetensors"
@ -17,6 +32,7 @@ def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, ma
a1 = r1.json()
assert r1.status_code == 201, a1
assert a1["created_new"] is True
assert a1["hash"] == a1["asset_hash"]
# Second upload with the same data and name creates a new AssetReference (duplicates allowed)
# Returns 200 because Asset already exists, but a new AssetReference is created
@ -26,6 +42,7 @@ def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, ma
a2 = r2.json()
assert r2.status_code in (200, 201), a2
assert a2["asset_hash"] == a1["asset_hash"]
assert a2["hash"] == a1["hash"]
assert a2["id"] != a1["id"] # new reference with same content
# Third upload with the same data but different name also creates new AssetReference
@ -50,6 +67,7 @@ def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_
b1 = r1.json()
assert r1.status_code == 201, b1
h = b1["asset_hash"]
assert b1["hash"] == h
# Now POST /api/assets with only hash and no file
files = [
@ -63,6 +81,15 @@ def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_
assert r2.status_code == 200, b2 # fast path returns 200 with created_new == False
assert b2["created_new"] is False
assert b2["asset_hash"] == h
assert b2["hash"] == h
assert b2.get("file_path") is None
assert b2.get("display_name") is None
rg = http.get(f"{api_base}/api/assets/{b2['id']}", timeout=120)
detail = rg.json()
assert rg.status_code == 200, detail
assert detail.get("file_path") is None
assert detail.get("display_name") is None
def test_upload_fastpath_with_known_hash_and_file(
@ -75,6 +102,7 @@ def test_upload_fastpath_with_known_hash_and_file(
b1 = r1.json()
assert r1.status_code == 201, b1
h = b1["asset_hash"]
assert b1["hash"] == h
# Send both file and hash of existing content -> server must drain file and create from hash (200)
files = {"file": ("ignored.bin", b"ignored" * 10, "application/octet-stream")}
@ -84,6 +112,7 @@ def test_upload_fastpath_with_known_hash_and_file(
assert r2.status_code == 200, b2
assert b2["created_new"] is False
assert b2["asset_hash"] == h
assert b2["hash"] == h
def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base: str):
@ -107,6 +136,54 @@ def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base
assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags)
@pytest.mark.parametrize(
("tags", "extension", "expected_prefix", "expected_display_prefix"),
[
(["input", "unit-tests"], ".png", "input", ""),
(["models", "checkpoints", "unit-tests"], ".safetensors", "models/checkpoints", ""),
],
)
def test_upload_response_includes_file_path_and_display_name(
tags: list[str],
extension: str,
expected_prefix: str,
expected_display_prefix: str,
http: requests.Session,
api_base: str,
asset_factory,
make_asset_bytes,
):
scope = f"response-paths-{uuid.uuid4().hex[:6]}"
scoped_tags = [*tags, scope]
name = f"asset_response_path{extension}"
created = asset_factory(name, scoped_tags, {}, make_asset_bytes(name, 1024))
stored_filename = get_asset_filename(created["asset_hash"], extension)
expected_suffix = f"unit-tests/{scope}/{stored_filename}"
expected_file_path = f"{expected_prefix}/{expected_suffix}"
expected_display_name = f"{expected_display_prefix}{expected_suffix}"
assert created["file_path"] == expected_file_path
assert created["display_name"] == expected_display_name
detail_r = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120)
detail = detail_r.json()
assert detail_r.status_code == 200, detail
assert detail["file_path"] == expected_file_path
assert detail["display_name"] == expected_display_name
list_r = http.get(
api_base + "/api/assets",
params={"include_tags": f"unit-tests,{scope}", "limit": "50"},
timeout=120,
)
listed = list_r.json()
assert list_r.status_code == 200, listed
match = next(a for a in listed["assets"] if a["id"] == created["id"])
assert match["file_path"] == expected_file_path
assert match["display_name"] == expected_display_name
@pytest.mark.parametrize("root", ["input", "output"])
def test_concurrent_upload_identical_bytes_different_names(
root: str,
@ -142,6 +219,8 @@ def test_concurrent_upload_identical_bytes_different_names(
assert r1.status_code in (200, 201), b1
assert r2.status_code in (200, 201), b2
assert b1["asset_hash"] == b2["asset_hash"]
assert b1["hash"] == b2["hash"]
assert b1["hash"] == b1["asset_hash"]
assert b1["id"] != b2["id"]
created_flags = sorted([bool(b1.get("created_new")), bool(b2.get("created_new"))])

View File

@ -0,0 +1,135 @@
"""HTTP-layer smoke test: user-added tags via POST /api/assets/{id}/tags
land after path tags when read back via GET /api/assets.
Exercises the full route handler -> service -> query path that the unit
tests at tests-unit/assets_test/queries/test_asset_info.py only cover at
the service layer.
"""
import json
import pytest
import requests
@pytest.fixture
def smoke_asset(http: requests.Session, api_base: str):
"""Upload a single asset into models/checkpoints/unit-tests/smoke
and delete it on teardown."""
name = "smoke_user_tag.safetensors"
tags = ["models", "checkpoints", "unit-tests", "smoke"]
files = {"file": (name, b"S" * 4096, "application/octet-stream")}
form_data = {
"tags": json.dumps(tags),
"name": name,
"user_metadata": json.dumps({}),
}
r = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
assert r.status_code == 201, r.text
body = r.json()
yield body
http.delete(
f"{api_base}/api/assets/{body['id']}?delete_content=true", timeout=30
)
def _fetch_asset_tags(http, api_base, ref_id):
r = http.get(f"{api_base}/api/assets/{ref_id}", timeout=30)
assert r.status_code == 200, r.text
return r.json()["tags"]
def test_user_tag_lands_after_path_tags_via_http(
http: requests.Session, api_base: str, smoke_asset: dict
):
ref_id = smoke_asset["id"]
initial_tags = _fetch_asset_tags(http, api_base, ref_id)
# Path tags should already be at the front in upload order.
assert initial_tags[:2] == ["models", "checkpoints"]
# Add a user tag that would jump to position 0 under alphabetical sort.
r = http.post(
f"{api_base}/api/assets/{ref_id}/tags",
json={"tags": ["aaa-user-tag"]},
timeout=30,
)
assert r.status_code in (200, 201), r.text
tags_after = _fetch_asset_tags(http, api_base, ref_id)
# Path tags must still be at the front; user tag goes to the end.
assert tags_after[0] == "models"
assert tags_after[1] == "checkpoints"
assert "aaa-user-tag" in tags_after
assert tags_after[-1] == "aaa-user-tag"
def test_user_tag_batch_lands_after_path_tags_via_http(
http: requests.Session, api_base: str, smoke_asset: dict
):
ref_id = smoke_asset["id"]
# Add three user tags in a single request, in non-alphabetical input
# order. They should all land after the path tags (microsecond stagger
# in set_reference_tags / add_tags_to_reference is what makes this
# work — without it, "aaa" would jump to position 0).
r = http.post(
f"{api_base}/api/assets/{ref_id}/tags",
json={"tags": ["zzz-z", "favorite", "aaa-experiment"]},
timeout=30,
)
assert r.status_code in (200, 201), r.text
tags_after = _fetch_asset_tags(http, api_base, ref_id)
assert tags_after[0] == "models"
assert tags_after[1] == "checkpoints"
user_tail = tags_after[len({"models", "checkpoints", "unit-tests", "smoke"}):]
assert set(user_tail) >= {"zzz-z", "favorite", "aaa-experiment"}
# Critically: alphabetical sort would put 'aaa-experiment' at position 0.
assert tags_after.index("aaa-experiment") > tags_after.index("models")
assert tags_after.index("aaa-experiment") > tags_after.index("checkpoints")
@pytest.fixture
def nested_checkpoint_asset(http: requests.Session, api_base: str):
"""Upload a checkpoint at the slash-joined path shape cloud emits
(`models/checkpoints/flux/...`), then delete it on teardown.
"""
name = "nested_checkpoint.safetensors"
tags = ["models", "checkpoints/flux"]
files = {"file": (name, b"S" * 4096, "application/octet-stream")}
form_data = {
"tags": json.dumps(tags),
"name": name,
"user_metadata": json.dumps({}),
}
r = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
assert r.status_code == 201, r.text
body = r.json()
yield body
http.delete(
f"{api_base}/api/assets/{body['id']}?delete_content=true", timeout=30
)
def test_nested_checkpoint_satisfies_fe_set_filter(
http: requests.Session, api_base: str, nested_checkpoint_asset: dict
):
"""The case Simon flagged: a nested-path checkpoint must still match
`include_tags=models,checkpoints` — the FE combo-widget filter.
"""
ref_id = nested_checkpoint_asset["id"]
stored = _fetch_asset_tags(http, api_base, ref_id)
# tag[1] keeps cloud's slash-joined positional contract; tag[2] holds
# the standalone bucket the FE filter looks for.
assert stored[:3] == ["models", "checkpoints/flux", "checkpoints"]
# The actual FE query — exact set-membership across both tokens.
r = http.get(
f"{api_base}/api/assets",
params=[("include_tags", "models"), ("include_tags", "checkpoints")],
timeout=30,
)
assert r.status_code == 200, r.text
returned_ids = {a["id"] for a in r.json()["assets"]}
assert ref_id in returned_ids

View File

@ -0,0 +1,245 @@
"""Tests for enrich_output_with_assets in comfy_execution/asset_enrichment.py."""
import os
import types
import unittest
from unittest.mock import MagicMock, patch
def _make_args(enable_assets: bool):
a = types.SimpleNamespace()
a.enable_assets = enable_assets
return a
def _make_db_ref(ref_id="ref-id-1"):
ref = MagicMock()
ref.id = ref_id
return ref
def _make_register_result(ref_id="ref-id-2"):
result = MagicMock()
result.ref.id = ref_id
return result
# Platform-appropriate absolute base. tempfile.gettempdir() returns C:\... on
# Windows and /tmp on POSIX, so containment via commonpath behaves naturally.
_DEFAULT_BASE = os.path.join(__import__("tempfile").gettempdir(), "asset-enrichment-test-base")
def _call(output_ui, *, enable_assets=True, file_exists=True, db_ref=None, register_result=None, directory=_DEFAULT_BASE):
fake_session_cm = MagicMock()
fake_session_cm.__enter__ = MagicMock(return_value=MagicMock())
fake_session_cm.__exit__ = MagicMock(return_value=False)
mocked_modules = {
"comfy.cli_args": MagicMock(args=_make_args(enable_assets)),
"folder_paths": MagicMock(get_directory_by_type=MagicMock(return_value=directory)),
"app.assets.services.ingest": MagicMock(
register_file_in_place=MagicMock(return_value=register_result or _make_register_result()),
DependencyMissingError=type("DependencyMissingError", (Exception,), {}),
),
"app.assets.database.queries.asset_reference": MagicMock(
get_reference_by_file_path=MagicMock(return_value=db_ref),
),
"app.database.db": MagicMock(create_session=MagicMock(return_value=fake_session_cm)),
}
# Only os.path.isfile is patched — abspath/join must run natively so the
# containment check sees real platform paths.
with patch.dict("sys.modules", mocked_modules), \
patch("os.path.isfile", return_value=file_exists):
import importlib
import comfy_execution.asset_enrichment as mod
importlib.reload(mod)
return mod.enrich_output_with_assets(output_ui)
class TestEnrichOutputWithAssets(unittest.TestCase):
def test_disabled_returns_unchanged(self):
output = {"images": [{"filename": "a.png", "subfolder": "", "type": "output"}]}
result = _call(output, enable_assets=False)
self.assertNotIn("id", result["images"][0])
def test_non_list_value_passed_through(self):
output = {"text": "hello"}
result = _call(output)
self.assertEqual(result["text"], "hello")
def test_entry_without_filename_unchanged(self):
output = {"latent": [{"subfolder": "", "type": "output"}]}
result = _call(output)
self.assertNotIn("id", result["latent"][0])
def test_entry_without_type_unchanged(self):
output = {"data": [{"filename": "a.png", "subfolder": ""}]}
result = _call(output)
self.assertNotIn("id", result["data"][0])
def test_file_not_on_disk_unchanged(self):
output = {"images": [{"filename": "missing.png", "subfolder": "", "type": "output"}]}
result = _call(output, file_exists=False)
self.assertNotIn("id", result["images"][0])
def test_unknown_type_returns_none_directory_unchanged(self):
output = {"images": [{"filename": "a.png", "subfolder": "", "type": "unknown"}]}
result = _call(output, directory=None)
self.assertNotIn("id", result["images"][0])
def test_db_hit_injects_id(self):
db_ref = _make_db_ref(ref_id="db-ref")
output = {"images": [{"filename": "a.png", "subfolder": "", "type": "output"}]}
result = _call(output, db_ref=db_ref)
img = result["images"][0]
self.assertEqual(img["id"], "db-ref")
# Only id is injected — no asset_hash, name, preview_url, size
self.assertNotIn("asset_hash", img)
self.assertNotIn("name", img)
self.assertNotIn("preview_url", img)
self.assertNotIn("size", img)
def test_db_miss_falls_back_to_register(self):
reg = _make_register_result(ref_id="inline-ref")
output = {"images": [{"filename": "new.png", "subfolder": "", "type": "output"}]}
result = _call(output, db_ref=None, register_result=reg)
img = result["images"][0]
self.assertEqual(img["id"], "inline-ref")
self.assertNotIn("asset_hash", img)
self.assertNotIn("name", img)
def test_original_entry_not_mutated(self):
orig = {"filename": "a.png", "subfolder": "", "type": "output"}
output = {"images": [orig]}
_call(output)
self.assertNotIn("id", orig)
def test_enrichment_error_does_not_block_sibling_entries(self):
call_count = [0]
good_reg = _make_register_result(ref_id="good-ref")
def register_side_effect(abs_path, name, tags):
call_count[0] += 1
if call_count[0] == 1:
raise RuntimeError("boom")
return good_reg
fake_session_cm = MagicMock()
fake_session_cm.__enter__ = MagicMock(return_value=MagicMock())
fake_session_cm.__exit__ = MagicMock(return_value=False)
mocked_modules = {
"comfy.cli_args": MagicMock(args=_make_args(True)),
"folder_paths": MagicMock(get_directory_by_type=MagicMock(return_value=_DEFAULT_BASE)),
"app.assets.services.ingest": MagicMock(
register_file_in_place=register_side_effect,
DependencyMissingError=type("DependencyMissingError", (Exception,), {}),
),
"app.assets.database.queries.asset_reference": MagicMock(
get_reference_by_file_path=MagicMock(return_value=None),
),
"app.database.db": MagicMock(create_session=MagicMock(return_value=fake_session_cm)),
}
output = {
"images": [
{"filename": "bad.png", "subfolder": "", "type": "output"},
{"filename": "good.png", "subfolder": "", "type": "output"},
]
}
with patch.dict("sys.modules", mocked_modules), \
patch("os.path.isfile", return_value=True):
import importlib
import comfy_execution.asset_enrichment as mod
importlib.reload(mod)
result = mod.enrich_output_with_assets(output)
imgs = result["images"]
self.assertNotIn("id", imgs[0])
self.assertEqual(imgs[1]["id"], "good-ref")
def test_multiple_output_keys_all_enriched(self):
output = {
"images": [{"filename": "a.png", "subfolder": "", "type": "output"}],
"videos": [{"filename": "b.mp4", "subfolder": "", "type": "output"}],
}
result = _call(output)
self.assertIn("id", result["images"][0])
self.assertIn("id", result["videos"][0])
def test_none_entry_in_list_unchanged(self):
output = {"images": [None, {"filename": "a.png", "subfolder": "", "type": "output"}]}
result = _call(output)
self.assertIsNone(result["images"][0])
self.assertIn("id", result["images"][1])
def test_path_traversal_subfolder_skipped(self):
fake_session_cm = MagicMock()
fake_session_cm.__enter__ = MagicMock(return_value=MagicMock())
fake_session_cm.__exit__ = MagicMock(return_value=False)
register_mock = MagicMock(return_value=_make_register_result())
mocked_modules = {
"comfy.cli_args": MagicMock(args=_make_args(True)),
"folder_paths": MagicMock(get_directory_by_type=MagicMock(return_value=_DEFAULT_BASE)),
"app.assets.services.ingest": MagicMock(
register_file_in_place=register_mock,
DependencyMissingError=type("DependencyMissingError", (Exception,), {}),
),
"app.assets.database.queries.asset_reference": MagicMock(
get_reference_by_file_path=MagicMock(return_value=None),
),
"app.database.db": MagicMock(create_session=MagicMock(return_value=fake_session_cm)),
}
output = {"images": [{"filename": "passwd", "subfolder": "../../etc", "type": "output"}]}
# Do NOT patch os.path.abspath — real resolution is required for the containment check.
with patch.dict("sys.modules", mocked_modules), \
patch("os.path.isfile", return_value=True):
import importlib
import comfy_execution.asset_enrichment as mod
importlib.reload(mod)
result = mod.enrich_output_with_assets(output)
self.assertNotIn("id", result["images"][0])
register_mock.assert_not_called()
def test_absolute_filename_skipped(self):
fake_session_cm = MagicMock()
fake_session_cm.__enter__ = MagicMock(return_value=MagicMock())
fake_session_cm.__exit__ = MagicMock(return_value=False)
register_mock = MagicMock(return_value=_make_register_result())
mocked_modules = {
"comfy.cli_args": MagicMock(args=_make_args(True)),
"folder_paths": MagicMock(get_directory_by_type=MagicMock(return_value=_DEFAULT_BASE)),
"app.assets.services.ingest": MagicMock(
register_file_in_place=register_mock,
DependencyMissingError=type("DependencyMissingError", (Exception,), {}),
),
"app.assets.database.queries.asset_reference": MagicMock(
get_reference_by_file_path=MagicMock(return_value=None),
),
"app.database.db": MagicMock(create_session=MagicMock(return_value=fake_session_cm)),
}
# Absolute filename — os.path.join discards earlier components when a later one is absolute.
absolute_filename = os.path.abspath(os.sep + "etc" + os.sep + "passwd")
output = {"images": [{"filename": absolute_filename, "subfolder": "", "type": "output"}]}
with patch.dict("sys.modules", mocked_modules), \
patch("os.path.isfile", return_value=True):
import importlib
import comfy_execution.asset_enrichment as mod
importlib.reload(mod)
result = mod.enrich_output_with_assets(output)
self.assertNotIn("id", result["images"][0])
register_mock.assert_not_called()
if __name__ == "__main__":
unittest.main()