Compare commits

..

12 Commits

Author SHA1 Message Date
dc6190e8ba fix(assets): seed added_at past max(existing) to survive Windows clock collisions
The per-tag microsecond stagger preserves intra-batch order, but two
back-to-back write batches on the same reference (e.g.
set_reference_tags for path tags, then add_tags_to_reference for user
tags) call get_utc_now() independently. On Windows the system clock can
return the same datetime for both calls if no OS tick elapsed between
the commits — both batches end up sharing microseconds and
ORDER BY added_at, tag_name falls back to the alphabetic tiebreaker,
sorting user tags ahead of path tags they were meant to follow.

Add _next_added_at_base(reference_id) that reads max(existing added_at)
and returns max(existing + 1us, get_utc_now()), guaranteeing the new
batch sorts strictly after anything previously written for that
reference. Used by set_reference_tags and add_tags_to_reference;
batch_insert_seed_assets stays on raw get_utc_now() since seed inserts
are always the first writes for a new reference.

The accompanying regression test pins get_utc_now() to a frozen value
so the previously-Windows-only race becomes a platform-independent
failure mode under test.
2026-05-20 20:33:39 -07:00
2d21956ac7 fix(assets): expand standalone bucket tag for nested category paths
Path-derived tags for nested model layouts (e.g.
models/checkpoints/flux/foo.safetensors) emitted only the slash-joined
shape `["models", "checkpoints/flux"]`, which broke the frontend
combo-widget set-membership filter `include_tags=models,checkpoints` —
the literal `checkpoints` token was no longer present in the asset's
tag set.

Add `expand_bucket_prefixes` at the tag-write layer. When a tag's first
slash segment is a registered model category (or input/output/temp
root), the bucket is inserted as a standalone token immediately after
the slash-joined form. This preserves tag[1] as the slash-joined
positional contract cloud emits while restoring the set-membership
token the frontend filter requires.

The expansion is bounded to known buckets so free-form user labels
with slashes (`my-org/team-a`) pass through unchanged. The helper is
applied uniformly in `set_reference_tags`, `add_tags_to_reference`,
and `batch_insert_seed_assets` so HTTP uploads, user-tag mutations,
and path-scanning ingest all converge on the same canonical shape.

Also align the upload-route category validator with
`resolve_destination_from_tags` by extracting the first slash segment
of tag[1], so HTTP uploads matching cloud's slash-joined emission
shape are no longer rejected as `unknown models category`.
2026-05-20 20:33:39 -07:00
396bfe4056 Merge branch 'master' into matt/asset-tags-cloud-shape 2026-05-20 19:20:33 -07:00
00940fb24e fix(assets): preserve caller order in add_tags_to_reference + align response helper
Smoke test through the real HTTP upload + tag-add path exposed two
ordering bugs the unit-layer tests missed:

1. add_tags_to_reference did `to_add = sorted(want - current)` — an
   alphabetical pre-sort defeating the microsecond-stagger fix from the
   previous commit. The stagger was encoding alphabetical positions,
   not the caller's insertion order. Fix: build to_add by walking the
   already-normalized caller list and filtering against the current
   set, so the staggered added_at timestamps reflect what the caller
   actually requested.

2. get_reference_tags used .order_by(tag_name.asc()) — alphabetical.
   It's called by the upload response path; meanwhile
   list_references_page and fetch_reference_asset_and_tags were already
   updated to order by added_at. The mismatch meant POST /api/assets
   returned tags in alphabetical order but a subsequent GET returned
   them in insertion order. Fix: order get_reference_tags by added_at
   too, so all three response-path helpers agree.

New tests-unit/assets_test/test_user_tag_http_smoke.py exercises the
full HTTP layer: POST /api/assets to upload, POST /api/assets/{id}/tags
to add a user tag (using tag names like "aaa-user-tag" that would jump
to position 0 under alphabetical), GET /api/assets/{id} to verify
ordering. Catches the bugs above in CI going forward.

Full assets suite: 340 passed, 10 pre-existing skipped.
2026-05-19 21:10:53 -07:00
7ff001d7c8 fix(assets): stagger added_at in set_reference_tags + add ordering tests
Cursor-reviews follow-up on PR #13994:

1. set_reference_tags / add_tags_to_reference now apply the same
   microsecond stagger as batch_insert_seed_assets. Per-tag get_utc_now()
   calls can collide at microsecond resolution on fast machines, dropping
   retrieval to the tag_name alphabetical tiebreaker. Using a single
   base_ts + timedelta(microseconds=i) preserves insertion order for any
   batch.

2. Docstring on get_name_and_tags_from_asset_path corrected: only the
   subpath is lowercased in code; the root category is lowercase by
   construction in get_asset_category_and_relative_path.

3. resolve_destination_from_tags docstring now states explicitly that
   hybrid shapes (mix of legacy multi-tag + new slash-joined within a
   single call) are accepted and resolve to the same destination.

4. New TestTagRetrievalOrder class in test_asset_info.py exercises the
   public write paths (set_reference_tags, add_tags_to_reference,
   remove_tags_from_reference) and asserts the public read paths
   (list_references_page, fetch_reference_asset_and_tags) return tags
   in insertion order rather than alphabetical. Tag names are chosen
   to fail loudly under alphabetical regression — "checkpoints" sorts
   before "models", "aaa-user-tag" sorts before every path tag, etc.

Full assets suite: 338 passed, 10 pre-existing skipped.
2026-05-19 21:05:54 -07:00
19ba85bb2e Merge branch 'master' into matt/asset-tags-cloud-shape 2026-05-19 20:48:47 -07:00
3ffc49aa0e fix(assets): lowercase subpath, parse slash-joined upload tags, stagger added_at
Three bugs surfaced by an end-to-end smoke test of the read+write
round-trip; all in this PR's scope.

1. FK violation on uppercase paths
   get_name_and_tags_from_asset_path was preserving case on the
   subpath (e.g. "diffusers/Kolors/text_encoder"). ensure_tags_exist
   lowercases via normalize_tags before inserting into the tags
   table, so the asset_reference_tags.tag_name FK to tags.name
   failed for any path containing uppercase letters — including
   the diffusers case the PR was designed to support.

   Fix: lowercase the slash-joined subpath in
   get_name_and_tags_from_asset_path to match the canonicalization
   ensure_tags_exist applies. Providers keyed on original-case
   subpaths need to normalize their lookup key to lowercase.

2. resolve_destination_from_tags rejected the new tag shape
   The inverse function only accepted the legacy one-tag-per-dir
   shape (["models", "diffusers", "Kolors", "text_encoder"]).
   An upload using the slash-joined shape returned by /api/assets
   raised "unknown model category" or "invalid path component".

   Fix: pre-split every entry after tags[0] on "/" so both shapes
   resolve identically. For models, the first expanded segment is
   the category and the rest are subdirs; for input/output the
   full expansion becomes the subdirs.

3. Within-batch tag order was lost
   bulk_ingest wrote every tag in a single batch with the same
   added_at = current_time. The retrieval ORDER BY added_at, tag_name
   then fell back to the tag_name tiebreaker, sorting the path-derived
   pair alphabetically — putting "checkpoints/..." ahead of "models"
   since "c" < "m". The tags[0] = root contract was lost on bulk-
   ingested rows.

   Fix: stagger added_at by microseconds per tag index within a
   reference so the retrieval order matches the input list order.
   Path-derived tags now consistently land in position-0 = root,
   position-1 = subpath.

Tests
- TestGetNameAndTagsFromAssetPath updated: subpath is now lowercase.
- New TestResolveDestinationFromTags covers both tag shapes, the
  unknown-category case for slash-joined input, traversal rejection,
  and input/output paths.
- Full suite: 333 passed, 10 pre-existing skipped.
2026-05-19 20:30:04 -07:00
36f9a6fdef feat(assets): preserve insertion order on tag retrieval
The /api/assets response previously sorted tags alphabetically via
.order_by(Tag.name.asc()). That breaks the structurally meaningful
"root category first, then subpath" invariant the path-collapsing
change relies on: alphabetical sort puts a custom user tag (or even
the bare "models" root) at unpredictable positions, so positional
access like tags[1] is not reliable on local.

Cloud already preserves insertion order — its Ent WithTags() eager-
load has no explicit ORDER BY, so Postgres returns rows in physical
insertion order. Local's composite primary key on
(asset_reference_id, tag_name) means SQLite walks the index in
tag_name order even without an explicit ORDER BY, so just dropping
the clause isn't enough.

Switching to ORDER BY added_at ASC, tag_name ASC keeps the path
tags inserted via set_reference_tags in their original order
(microsecond-resolution timestamps disambiguate same-batch inserts;
tag_name is a deterministic tiebreaker for the rare collision case).
Custom tags added later via add_tags_to_reference land after the
path tags in their own added_at bucket.

Applies to both response-shaping queries:
- list_references_page (GET /api/assets, tag_map join)
- fetch_reference_asset_and_tags (GET /api/assets/{id})

Catalog/histogram queries in app/assets/database/queries/tags.py
keep their alphabetical sort — those endpoints are listing all tags,
not per-asset tags, and alphabetical is the right shape there.
2026-05-19 20:14:01 -07:00
a0d1238829 Merge branch 'master' into matt/asset-tags-cloud-shape 2026-05-19 20:06:12 -07:00
1688a5e262 Merge branch 'master' into matt/asset-tags-cloud-shape 2026-05-19 15:00:22 -07:00
7ab346fc7b chore(assets): drop unused normalize_tags import after subpath-collapse refactor
normalize_tags lowercased every tag, which would have stripped case from
the slash-joined subpath (e.g. "diffusers/Kolors/text_encoder" ->
"diffusers/kolors/text_encoder") and broken consumer lookups keyed on
the original-case path. The refactored implementation inlines a strip +
dedup so the import is no longer needed.
2026-05-19 14:51:00 -07:00
5b7288d700 feat(assets): collapse nested asset path into a single slash-joined tag
The /api/assets response previously emitted one tag per parent directory
between the root category and the filename. For nested categories like
diffusers, this produced ["models", "diffusers", "Kolors", "text_encoder"]
where consumers that look up a category via tags[1] would only see the
top-level bucket name and miss the model-specific sub-path that uniquely
identifies the component.

This collapses the parent subpath into a single slash-joined tag so the
result is ["models", "diffusers/Kolors/text_encoder"]. Consumers can now
read tags[1] as a stable category identifier regardless of how deep the
file lives in the bucket. Case is preserved on the subpath so providers
keyed on the original-case path (e.g. "diffusers/Kolors/text_encoder")
resolve correctly.

Same shape applies uniformly:

- input/foo.png                              -> ["input"]
- output/00001.png                           -> ["output"]
- models/checkpoints/flux.safetensors        -> ["models", "checkpoints"]
- models/diffusers/Kolors/text_encoder/m.sft -> ["models", "diffusers/Kolors/text_encoder"]
- models/loras/my/custom/path/v1.safetensors -> ["models", "loras/my/custom/path"]

Integration tests that filtered by individual subdirectory tags
(`include_tags=unit-tests,scope`) updated to use the new slash-joined
shape (`include_tags=unit-tests/scope`). Unit tests cover flat input,
flat output, flat models, diffusers-style nested, and deep user-subpath
cases.
2026-05-19 14:48:49 -07:00
45 changed files with 1327 additions and 5885 deletions

View File

@ -1,519 +0,0 @@
name: Backport Release
on:
workflow_dispatch:
inputs:
commit:
description: 'Full 40-char SHA of the tip commit of the backport source branch (the PR head commit that passed tests). The branch is resolved from this SHA and must be unique.'
required: true
type: string
permissions:
contents: read
pull-requests: read
checks: read
jobs:
backport-release:
name: Create backport release
runs-on: ubuntu-latest
environment: backport release
steps:
- name: Generate GitHub App token
id: app-token
uses: actions/create-github-app-token@bcd2ba49218906704ab6c1aa796996da409d3eb1
with:
app-id: ${{ secrets.FEN_RELEASE_APP_ID }}
private-key: ${{ secrets.FEN_RELEASE_PRIVATE_KEY }}
- name: Checkout repository
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
with:
token: ${{ steps.app-token.outputs.token }}
fetch-depth: 0
fetch-tags: true
- name: Configure git
run: |
git config user.name "fen-release[bot]"
git config user.email "fen-release[bot]@users.noreply.github.com"
- name: Resolve source branch from commit SHA
id: resolve
env:
SOURCE_COMMIT: ${{ inputs.commit }}
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
run: |
set -euo pipefail
# Require a full 40-char lowercase-hex SHA. Short SHAs are ambiguous
# and we will be comparing this value against API responses (PR head
# SHA, ref tips) that always return the full form.
if [[ ! "${SOURCE_COMMIT}" =~ ^[0-9a-f]{40}$ ]]; then
echo "::error::Input commit '${SOURCE_COMMIT}' is not a full 40-char lowercase hex SHA."
exit 1
fi
# Fetch all remote branches so we can search for which one(s) point
# at this SHA. `actions/checkout` with fetch-depth: 0 fetches full
# history of the checked-out ref but does not necessarily populate
# every refs/remotes/origin/*, so do it explicitly.
git fetch --prune origin '+refs/heads/*:refs/remotes/origin/*'
# Verify the commit actually exists in this repo's object DB.
if ! git cat-file -e "${SOURCE_COMMIT}^{commit}" 2>/dev/null; then
echo "::error::Commit ${SOURCE_COMMIT} was not found in the repository."
exit 1
fi
# Find every remote branch whose tip == SOURCE_COMMIT. Exactly one
# branch must point at it. If zero, the commit isn't anyone's tip
# (likely stale, force-pushed past, or never the PR head). If more
# than one, the (branch -> SHA) mapping is ambiguous and we refuse
# to guess — the operator must give us a unique branch to release.
mapfile -t matching_branches < <(
git for-each-ref \
--format='%(refname:strip=3)' \
--points-at="${SOURCE_COMMIT}" \
refs/remotes/origin/ \
| grep -vx 'HEAD' || true
)
if [[ "${#matching_branches[@]}" -eq 0 ]]; then
echo "::error::No branch on origin has ${SOURCE_COMMIT} as its tip."
echo "::error::Either the branch was updated after you copied this SHA, or this commit was never the head of a branch."
exit 1
fi
if [[ "${#matching_branches[@]}" -gt 1 ]]; then
echo "::error::More than one branch on origin has ${SOURCE_COMMIT} as its tip; cannot pick one:"
for b in "${matching_branches[@]}"; do
echo "::error:: - ${b}"
done
echo "::error::Refusing to proceed with an ambiguous source branch."
exit 1
fi
source_branch="${matching_branches[0]}"
if [[ "${source_branch}" == "${DEFAULT_BRANCH}" ]]; then
echo "::error::Source branch must not be the default branch ('${DEFAULT_BRANCH}')."
exit 1
fi
echo "Resolved commit ${SOURCE_COMMIT} to branch '${source_branch}'."
echo "source_branch=${source_branch}" >> "$GITHUB_OUTPUT"
- name: Determine latest stable release
id: latest
env:
GH_TOKEN: ${{ steps.app-token.outputs.token }}
run: |
set -euo pipefail
# List all tags matching vMAJOR.MINOR.PATCH and pick the highest by numeric
# comparison of each component. We DO NOT use `sort -V` because it treats
# v0.19.99 as higher than v0.20.1.
latest_tag="$(
git tag --list 'v[0-9]*.[0-9]*.[0-9]*' \
| grep -E '^v[0-9]+\.[0-9]+\.[0-9]+$' \
| awk -F'[v.]' '{ printf "%010d %010d %010d %s\n", $2, $3, $4, $0 }' \
| sort -k1,1n -k2,2n -k3,3n \
| tail -n1 \
| awk '{print $4}'
)"
if [[ -z "${latest_tag}" ]]; then
echo "::error::No stable release tags (vMAJOR.MINOR.PATCH) were found."
exit 1
fi
# Parse components
ver="${latest_tag#v}"
major="${ver%%.*}"
rest="${ver#*.}"
minor="${rest%%.*}"
patch="${rest#*.}"
new_patch=$((patch + 1))
new_version="v${major}.${minor}.${new_patch}"
release_branch="release/v${major}.${minor}"
latest_sha="$(git rev-list -n 1 "refs/tags/${latest_tag}")"
echo "latest_tag=${latest_tag}" >> "$GITHUB_OUTPUT"
echo "latest_sha=${latest_sha}" >> "$GITHUB_OUTPUT"
echo "major=${major}" >> "$GITHUB_OUTPUT"
echo "minor=${minor}" >> "$GITHUB_OUTPUT"
echo "patch=${patch}" >> "$GITHUB_OUTPUT"
echo "new_version=${new_version}" >> "$GITHUB_OUTPUT"
echo "new_version_no_v=${major}.${minor}.${new_patch}" >> "$GITHUB_OUTPUT"
echo "release_branch=${release_branch}" >> "$GITHUB_OUTPUT"
echo "Latest stable release: ${latest_tag} (${latest_sha})"
echo "New version will be: ${new_version}"
echo "Release branch: ${release_branch}"
- name: Validate source branch is cut directly from the latest stable release
env:
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
LATEST_TAG_SHA: ${{ steps.latest.outputs.latest_sha }}
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
run: |
set -euo pipefail
# Use the user-provided SHA directly rather than re-resolving the branch
# tip — the resolve step already proved the branch tip equals SOURCE_COMMIT,
# and pinning to the SHA here makes the rest of the job TOCTOU-safe against
# someone pushing to the branch mid-run.
source_sha="${SOURCE_COMMIT}"
# Walking first-parent from the source tip must reach LATEST_TAG_SHA.
# We capture rev-list into a variable and grep against a here-string
# rather than piping `rev-list | grep -q`: under `set -o pipefail`,
# `grep -q` would exit on first match and SIGPIPE the still-streaming
# `rev-list`, propagating exit 141 as a spurious "not found".
first_parent_chain="$(git rev-list --first-parent "${source_sha}")"
if ! grep -Fxq "${LATEST_TAG_SHA}" <<< "${first_parent_chain}"; then
echo "::error::Source branch '${SOURCE_BRANCH}' is not cut from '${LATEST_TAG}'."
echo "::error::Its first-parent history does not include ${LATEST_TAG_SHA}."
exit 1
fi
# Additionally, every commit added on top of the tag (the set we are
# about to publish) must itself be a descendant of the tag along
# first-parent — i.e. no sibling commits from master sneak in via a
# non-first-parent path. Enforce by requiring that the symmetric
# difference is empty in one direction: commits in source that are
# NOT first-parent-reachable from source starting at the tag.
# We do this by intersecting:
# A = commits reachable from source but not from tag (full DAG)
# B = commits on the first-parent chain from source down to tag
# and requiring A == B.
all_added="$(git rev-list "${LATEST_TAG_SHA}..${source_sha}" | sort)"
first_parent_added="$(
git rev-list --first-parent "${LATEST_TAG_SHA}..${source_sha}" | sort
)"
if [[ "${all_added}" != "${first_parent_added}" ]]; then
echo "::error::Source branch '${SOURCE_BRANCH}' contains commits not on its first-parent chain from '${LATEST_TAG}'."
echo "::error::This usually means the branch was cut from master (not from the tag) or contains a merge from master."
echo "Commits reachable but not on first-parent chain:"
comm -23 <(printf '%s\n' "${all_added}") <(printf '%s\n' "${first_parent_added}") \
| while read -r sha; do
echo " $(git log -1 --format='%h %s' "${sha}")"
done
exit 1
fi
added_count="$(printf '%s\n' "${all_added}" | grep -c . || true)"
echo "Source branch is cut directly from ${LATEST_TAG} with ${added_count} commit(s) on top."
- name: Validate PR exists, is open, named correctly, has latest commit, and checks pass
env:
GH_TOKEN: ${{ steps.app-token.outputs.token }}
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
REPO: ${{ github.repository }}
run: |
set -euo pipefail
expected_title="ComfyUI backport release ${NEW_VERSION}"
# Find open PRs from this branch into master. The --state open filter
# is load-bearing: a closed/merged PR with passing checks must not be
# accepted as authorization for a new release.
pr_json="$(
gh pr list \
--repo "${REPO}" \
--state open \
--head "${SOURCE_BRANCH}" \
--base master \
--json number,title,headRefOid,state \
--limit 10
)"
pr_count="$(echo "${pr_json}" | jq 'length')"
if [[ "${pr_count}" -eq 0 ]]; then
echo "::error::No open PR found from '${SOURCE_BRANCH}' into 'master'. The PR must exist and be open."
exit 1
fi
# Pick the PR matching the expected title
pr_number="$(echo "${pr_json}" | jq -r --arg t "${expected_title}" '
map(select(.title == $t)) | .[0].number // empty
')"
pr_head_sha="$(echo "${pr_json}" | jq -r --arg t "${expected_title}" '
map(select(.title == $t)) | .[0].headRefOid // empty
')"
if [[ -z "${pr_number}" ]]; then
echo "::error::No open PR from '${SOURCE_BRANCH}' into 'master' is titled '${expected_title}'."
echo "Found PRs:"
echo "${pr_json}" | jq -r '.[] | " #\(.number): \(.title)"'
exit 1
fi
# The PR's current head commit must equal the SHA the operator gave us.
# This is what closes the door on releasing stale code: if anyone has
# pushed to the branch since the operator validated tests passed, the
# PR head will have advanced past SOURCE_COMMIT and we abort. (The
# resolve step already proved the branch tip == SOURCE_COMMIT; this
# ties that same SHA to the PR that authorizes the release.)
if [[ "${pr_head_sha}" != "${SOURCE_COMMIT}" ]]; then
echo "::error::PR #${pr_number} head commit is ${pr_head_sha}, but the operator-provided commit is ${SOURCE_COMMIT}."
echo "::error::The PR has new commits since this release was authorized. Re-run with the new head SHA after verifying its checks."
exit 1
fi
echo "Found open PR #${pr_number} titled '${expected_title}' at head ${pr_head_sha} (matches operator-provided commit)."
# Verify all check runs on the head commit have completed successfully.
# A check is considered passing if conclusion is success, neutral, or skipped.
checks_json="$(
gh api \
--paginate \
"repos/${REPO}/commits/${pr_head_sha}/check-runs" \
--jq '.check_runs[] | {name: .name, status: .status, conclusion: .conclusion}'
)"
if [[ -z "${checks_json}" ]]; then
echo "::error::No check runs found on PR head commit ${pr_head_sha}."
exit 1
fi
echo "Check runs on ${pr_head_sha}:"
echo "${checks_json}" | jq -s '.'
failing="$(echo "${checks_json}" | jq -s '
map(select(
.status != "completed"
or (.conclusion as $c
| ["success","neutral","skipped"]
| index($c) | not)
))
')"
failing_count="$(echo "${failing}" | jq 'length')"
if [[ "${failing_count}" -gt 0 ]]; then
echo "::error::One or more checks have not passed on PR head commit ${pr_head_sha}:"
echo "${failing}" | jq -r '.[] | " - \(.name): status=\(.status) conclusion=\(.conclusion)"'
exit 1
fi
echo "All checks have passed on ${pr_head_sha}."
- name: Prepare release branch
id: prepare
env:
GH_TOKEN: ${{ steps.app-token.outputs.token }}
REPO: ${{ github.repository }}
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
LATEST_TAG_SHA: ${{ steps.latest.outputs.latest_sha }}
PATCH: ${{ steps.latest.outputs.patch }}
run: |
set -euo pipefail
# Try to fetch the release branch. If patch == 0, it shouldn't exist yet
# and we'll create it from the latest stable tag. If patch > 0, it must
# already exist and its tip must equal the latest stable tag commit (i.e.
# the previous patch release).
if git ls-remote --exit-code --heads origin "${RELEASE_BRANCH}" >/dev/null 2>&1; then
echo "Release branch '${RELEASE_BRANCH}' already exists on origin."
git fetch origin "refs/heads/${RELEASE_BRANCH}:refs/remotes/origin/${RELEASE_BRANCH}"
git checkout -B "${RELEASE_BRANCH}" "refs/remotes/origin/${RELEASE_BRANCH}"
current_tip="$(git rev-parse HEAD)"
if [[ "${current_tip}" != "${LATEST_TAG_SHA}" ]]; then
echo "::error::Release branch '${RELEASE_BRANCH}' tip (${current_tip}) is not at the latest stable release '${LATEST_TAG}' (${LATEST_TAG_SHA})."
echo "::error::Refusing to release on top of a divergent branch."
exit 1
fi
echo "branch_existed=true" >> "$GITHUB_OUTPUT"
else
if [[ "${PATCH}" != "0" ]]; then
echo "::error::Release branch '${RELEASE_BRANCH}' does not exist on origin, but the latest stable release '${LATEST_TAG}' has patch=${PATCH} (>0). This is inconsistent."
exit 1
fi
echo "Release branch '${RELEASE_BRANCH}' does not exist. Creating from ${LATEST_TAG}."
git checkout -B "${RELEASE_BRANCH}" "refs/tags/${LATEST_TAG}"
echo "branch_existed=false" >> "$GITHUB_OUTPUT"
fi
- name: Fast-forward merge source branch into release branch
env:
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
run: |
set -euo pipefail
# --ff-only guarantees no merge commit is created. If a fast-forward is
# not possible (i.e. the release branch has commits the source branch
# doesn't), the merge will fail and we abort. Because we already validated
# that the source branch is rooted on the latest stable tag, and the
# release branch tip equals that same tag, this fast-forward should
# always succeed for a well-formed backport branch.
#
# We merge the operator-provided SHA, not the branch ref, so a push to
# the branch in the window between resolve and now cannot smuggle new
# commits into the release.
if ! git merge --ff-only "${SOURCE_COMMIT}"; then
echo "::error::Cannot fast-forward '${RELEASE_BRANCH}' to ${SOURCE_COMMIT} (tip of '${SOURCE_BRANCH}'). A merge commit would be required. Aborting."
exit 1
fi
echo "Fast-forwarded '${RELEASE_BRANCH}' to ${SOURCE_COMMIT} (tip of '${SOURCE_BRANCH}')."
- name: Bump version files
env:
NEW_VERSION_NO_V: ${{ steps.latest.outputs.new_version_no_v }}
run: |
set -euo pipefail
if [[ ! -f comfyui_version.py ]]; then
echo "::error::comfyui_version.py not found in repo root."
exit 1
fi
if [[ ! -f pyproject.toml ]]; then
echo "::error::pyproject.toml not found in repo root."
exit 1
fi
# Replace the version string in comfyui_version.py.
# Expected format: __version__ = "X.Y.Z"
python3 - "$NEW_VERSION_NO_V" <<'PY'
import re, sys, pathlib
new = sys.argv[1]
p = pathlib.Path("comfyui_version.py")
src = p.read_text()
new_src, n = re.subn(
r'(__version__\s*=\s*[\'"])[^\'"]+([\'"])',
lambda m: f'{m.group(1)}{new}{m.group(2)}',
src,
count=1,
)
if n != 1:
sys.exit("Could not find __version__ assignment in comfyui_version.py")
p.write_text(new_src)
p = pathlib.Path("pyproject.toml")
src = p.read_text()
# Replace the first `version = "..."` inside [project] or [tool.poetry].
new_src, n = re.subn(
r'(?m)^(version\s*=\s*")[^"]+(")',
lambda m: f'{m.group(1)}{new}{m.group(2)}',
src,
count=1,
)
if n != 1:
sys.exit("Could not find version assignment in pyproject.toml")
p.write_text(new_src)
PY
echo "Updated version to ${NEW_VERSION_NO_V} in comfyui_version.py and pyproject.toml."
git --no-pager diff -- comfyui_version.py pyproject.toml
- name: Commit version bump and tag release
env:
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
run: |
set -euo pipefail
git add comfyui_version.py pyproject.toml
git commit -m "ComfyUI ${NEW_VERSION}"
if git rev-parse -q --verify "refs/tags/${NEW_VERSION}" >/dev/null; then
echo "::error::Tag ${NEW_VERSION} already exists locally."
exit 1
fi
git tag "${NEW_VERSION}"
- name: Verify tag does not already exist on origin
env:
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
run: |
set -euo pipefail
if git ls-remote --exit-code --tags origin "refs/tags/${NEW_VERSION}" >/dev/null 2>&1; then
echo "::error::Tag ${NEW_VERSION} already exists on origin. Aborting."
exit 1
fi
- name: Push release branch and tag
env:
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
run: |
set -euo pipefail
# Push the branch first, then the tag. Atomic-ish: if the branch push
# fails we never publish the tag.
git push origin "refs/heads/${RELEASE_BRANCH}:refs/heads/${RELEASE_BRANCH}"
git push origin "refs/tags/${NEW_VERSION}"
echo "Released ${NEW_VERSION} on ${RELEASE_BRANCH}."
- name: Delete remote source branch
env:
GH_TOKEN: ${{ steps.app-token.outputs.token }}
REPO: ${{ github.repository }}
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
run: |
set -euo pipefail
# Belt-and-braces: the resolve step already refuses the default branch,
# but never delete the default or the release branch under any
# circumstances.
if [[ "${SOURCE_BRANCH}" == "${DEFAULT_BRANCH}" || "${SOURCE_BRANCH}" == "${RELEASE_BRANCH}" ]]; then
echo "::error::Refusing to delete '${SOURCE_BRANCH}' (matches default or release branch)."
exit 1
fi
# Delete the source branch on origin, but only if its tip is still the
# SHA we released from. If someone pushed new commits to it after we
# resolved it, leave it alone — those commits would be silently lost.
current_tip="$(git ls-remote origin "refs/heads/${SOURCE_BRANCH}" | awk '{print $1}')"
if [[ -z "${current_tip}" ]]; then
echo "Source branch '${SOURCE_BRANCH}' no longer exists on origin; nothing to delete."
exit 0
fi
if [[ "${current_tip}" != "${SOURCE_COMMIT}" ]]; then
echo "::warning::Source branch '${SOURCE_BRANCH}' tip (${current_tip}) no longer matches released commit (${SOURCE_COMMIT}). Leaving it in place."
exit 0
fi
git push origin --delete "refs/heads/${SOURCE_BRANCH}"
echo "Deleted remote branch '${SOURCE_BRANCH}'."
- name: Summary
if: always()
env:
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
run: |
# SOURCE_BRANCH is empty if the resolve step never produced an output
# (e.g. the workflow failed in or before that step). Show a placeholder
# in that case so the summary table still renders cleanly.
source_branch_display="${SOURCE_BRANCH:-(unresolved)}"
{
echo "## Backport release"
echo ""
echo "| Field | Value |"
echo "|---|---|"
echo "| Source commit | \`${SOURCE_COMMIT}\` |"
echo "| Source branch | \`${source_branch_display}\` |"
echo "| Previous stable | \`${LATEST_TAG}\` |"
echo "| New version | \`${NEW_VERSION}\` |"
echo "| Release branch | \`${RELEASE_BRANCH}\` |"
} >> "$GITHUB_STEP_SUMMARY"

View File

@ -20,7 +20,7 @@
[website-url]: https://www.comfy.org/
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
[discord-url]: https://discord.com/invite/comfyorg
[discord-url]: https://www.comfy.org/discord
[twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI
[twitter-url]: https://x.com/ComfyUI

View File

@ -401,12 +401,16 @@ async def upload_asset(request: web.Request) -> web.Response:
)
if spec.tags and spec.tags[0] == "models":
# tag[1] may be the standalone category ("checkpoints") or the
# slash-joined shape ("checkpoints/flux/...") that
# `get_name_and_tags_from_asset_path` and cloud both emit. Match
# `resolve_destination_from_tags` by extracting the first segment.
category = spec.tags[1].split("/", 1)[0] if len(spec.tags) >= 2 else ""
if (
len(spec.tags) < 2
or spec.tags[1] not in folder_paths.folder_names_and_paths
or category not in folder_paths.folder_names_and_paths
):
delete_temp_file_if_exists(parsed.tmp_path)
category = spec.tags[1] if len(spec.tags) >= 2 else ""
return _build_error_response(
400, "INVALID_BODY", f"unknown models category '{category}'"
)

View File

@ -327,7 +327,12 @@ def list_references_page(
select(AssetReferenceTag.asset_reference_id, Tag.name)
.join(Tag, Tag.name == AssetReferenceTag.tag_name)
.where(AssetReferenceTag.asset_reference_id.in_(id_list))
.order_by(AssetReferenceTag.tag_name.asc())
# Preserve insertion order so the structural first tag (the root
# category like "models") stays in position 0 and the path-derived
# sub-path tag stays in position 1, matching cloud's behavior.
# tag_name is a deterministic tiebreaker when multiple tags share
# an added_at (same-batch insert via set_reference_tags).
.order_by(AssetReferenceTag.added_at.asc(), AssetReferenceTag.tag_name.asc())
)
for ref_id, tag_name in rows.all():
tag_map[ref_id].append(tag_name)
@ -355,7 +360,8 @@ def fetch_reference_asset_and_tags(
build_visible_owner_clause(owner_id),
)
.options(noload(AssetReference.tags))
.order_by(Tag.name.asc())
# See list_references_page for the rationale behind ordering by added_at.
.order_by(AssetReferenceTag.added_at.asc(), Tag.name.asc())
)
rows = session.execute(stmt).all()

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Iterable, Sequence
import sqlalchemy as sa
@ -20,7 +21,12 @@ from app.assets.database.queries.common import (
build_visible_owner_clause,
iter_row_chunks,
)
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
from app.assets.helpers import (
escape_sql_like_string,
expand_bucket_prefixes,
get_utc_now,
normalize_tags,
)
@dataclass(frozen=True)
@ -44,6 +50,26 @@ class SetTagsResult:
total: list[str]
def _next_added_at_base(session: Session, reference_id: str) -> datetime:
"""Return a timestamp strictly greater than any existing
`added_at` for this reference. On platforms where the wall clock
has insufficient resolution between back-to-back commits (notably
Windows), two write batches on the same reference can otherwise
share a microsecond — the `ORDER BY added_at, tag_name` retrieval
then falls back to the alphabetic tiebreaker and user-tier tags
sort ahead of path-tier tags they were meant to follow.
"""
existing_max = session.execute(
sa.select(sa.func.max(AssetReferenceTag.added_at)).where(
AssetReferenceTag.asset_reference_id == reference_id
)
).scalar()
now = get_utc_now()
if existing_max is None:
return now
return max(existing_max + timedelta(microseconds=1), now)
def validate_tags_exist(session: Session, tags: list[str]) -> None:
"""Raise ValueError if any of the given tag names do not exist."""
existing_tag_names = set(
@ -77,7 +103,13 @@ def get_reference_tags(session: Session, reference_id: str) -> list[str]:
session.execute(
select(AssetReferenceTag.tag_name)
.where(AssetReferenceTag.asset_reference_id == reference_id)
.order_by(AssetReferenceTag.tag_name.asc())
# Match the response-path ordering used by
# list_references_page / fetch_reference_asset_and_tags so
# upload responses and subsequent GETs agree on tag order.
.order_by(
AssetReferenceTag.added_at.asc(),
AssetReferenceTag.tag_name.asc(),
)
)
).all()
]
@ -89,7 +121,7 @@ def set_reference_tags(
tags: Sequence[str],
origin: str = "manual",
) -> SetTagsResult:
desired = normalize_tags(tags)
desired = expand_bucket_prefixes(normalize_tags(tags))
current = set(get_reference_tags(session, reference_id))
@ -98,15 +130,22 @@ def set_reference_tags(
if to_add:
ensure_tags_exist(session, to_add, tag_type="user")
# Stagger added_at by microsecond per tag so the retrieval ORDER BY
# added_at preserves input order. Per-tag get_utc_now() calls can
# collide at microsecond resolution on fast machines, dropping the
# query to the tag_name alphabetical tiebreaker — same fix as in
# batch_insert_seed_assets. Read max(existing) so this batch sorts
# strictly after any prior batch on the same reference.
base_ts = _next_added_at_base(session, reference_id)
session.add_all(
[
AssetReferenceTag(
asset_reference_id=reference_id,
tag_name=t,
origin=origin,
added_at=get_utc_now(),
added_at=base_ts + timedelta(microseconds=i),
)
for t in to_add
for i, t in enumerate(to_add)
]
)
session.flush()
@ -136,7 +175,7 @@ def add_tags_to_reference(
if not ref:
raise ValueError(f"AssetReference {reference_id} not found")
norm = normalize_tags(tags)
norm = expand_bucket_prefixes(normalize_tags(tags))
if not norm:
total = get_reference_tags(session, reference_id=reference_id)
return AddTagsResult(added=[], already_present=[], total_tags=total)
@ -146,10 +185,17 @@ def add_tags_to_reference(
current = set(get_reference_tags(session, reference_id))
# Preserve the caller's insertion order rather than alphabetizing —
# the retrieval ORDER BY added_at + microsecond stagger only meaningfully
# preserves insertion order if "the order we insert in" actually matches
# the caller's intent.
want = set(norm)
to_add = sorted(want - current)
to_add = [t for t in norm if t not in current]
if to_add:
# See set_reference_tags for the rationale behind the per-tag stagger
# and the max(existing) seed.
base_ts = _next_added_at_base(session, reference_id)
with session.begin_nested() as nested:
try:
session.add_all(
@ -158,9 +204,9 @@ def add_tags_to_reference(
asset_reference_id=reference_id,
tag_name=t,
origin=origin,
added_at=get_utc_now(),
added_at=base_ts + timedelta(microseconds=i),
)
for t in to_add
for i, t in enumerate(to_add)
]
)
session.flush()

View File

@ -47,6 +47,50 @@ def normalize_tags(tags: list[str] | None) -> list[str]:
return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip()))
def _known_bucket_prefixes() -> set[str]:
"""Lowercased model-category names eligible for standalone-prefix
expansion. Tags whose first slash segment matches one of these get
the bucket inserted as a separate token, so FE filters like
``include_tags=models,checkpoints`` keep matching even when the
asset lives in a nested subfolder (`models/checkpoints/flux/foo`).
Bare user labels with slashes whose first segment is not a registered
bucket (e.g. ``my-org/team-a``) pass through unchanged.
"""
try:
import folder_paths
return {
name.lower()
for name in folder_paths.folder_names_and_paths.keys()
if name != "custom_nodes"
}
except Exception:
return set()
def expand_bucket_prefixes(tags: list[str]) -> list[str]:
"""Insert standalone bucket tokens after any slash-joined tag whose
first segment is a registered model category. Preserves caller order
and is idempotent (existing bucket tokens are not duplicated).
"""
if not tags:
return list(tags)
buckets = _known_bucket_prefixes()
if not buckets:
return list(tags)
seen = set(tags)
result: list[str] = []
for t in tags:
result.append(t)
if "/" in t:
prefix = t.split("/", 1)[0]
if prefix.lower() in buckets and prefix not in seen:
result.append(prefix)
seen.add(prefix)
return result
def validate_blake3_hash(s: str) -> str:
"""Validate and normalize a blake3 hash string.

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import os
import uuid
from dataclasses import dataclass
from datetime import datetime
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, TypedDict
from sqlalchemy.orm import Session
@ -13,13 +13,14 @@ from app.assets.database.queries import (
bulk_insert_references_ignore_conflicts,
bulk_insert_tags_and_meta,
delete_assets_by_ids,
ensure_tags_exist,
get_existing_asset_ids,
get_reference_ids_by_ids,
get_references_by_paths_and_asset_ids,
get_unreferenced_unhashed_asset_ids,
restore_references_by_paths,
)
from app.assets.helpers import get_utc_now
from app.assets.helpers import expand_bucket_prefixes, get_utc_now
if TYPE_CHECKING:
from app.assets.services.metadata_extract import ExtractedMetadata
@ -233,13 +234,20 @@ def batch_insert_seed_assets(
if ref_id not in inserted_ref_ids:
continue
for tag in ref_data["tags"]:
# Stagger added_at by microsecond per tag within a reference so
# the retrieval ORDER BY added_at preserves the input list order
# (the path-derived root category stays at position 0). Without
# this, every tag in a bulk-insert batch shares current_time and
# the tag_name tiebreaker sorts them alphabetically — putting the
# subpath tag ahead of "models" since "c"/"d"/"l" < "m".
ref_tags = expand_bucket_prefixes(ref_data["tags"])
for tag_idx, tag in enumerate(ref_tags):
tag_rows.append(
{
"asset_reference_id": ref_id,
"tag_name": tag,
"origin": "automatic",
"added_at": current_time,
"added_at": current_time + timedelta(microseconds=tag_idx),
}
)
@ -261,6 +269,16 @@ def batch_insert_seed_assets(
}
)
if tag_rows:
# Bucket-prefix expansion may have introduced tags the caller did
# not register via the upstream tag_pool (e.g. `checkpoints` for a
# nested `checkpoints/flux/foo` path). Pre-register the full set so
# the AssetReferenceTag.tag_name FK is satisfied; the underlying
# insert is ON CONFLICT DO NOTHING so re-registration is idempotent.
ensure_tags_exist(
session, {row["tag_name"] for row in tag_rows}, tag_type="user"
)
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=metadata_rows)
return BulkInsertResult(

View File

@ -3,7 +3,6 @@ from pathlib import Path
from typing import Literal
import folder_paths
from app.assets.helpers import normalize_tags
_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"})
@ -27,27 +26,51 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
"""Validates and maps tags -> (base_dir, subdirs_for_fs).
Accepts both the legacy one-tag-per-directory shape
(``["models", "diffusers", "Kolors", "text_encoder"]``) and the
slash-joined shape emitted by :func:`get_name_and_tags_from_asset_path`
(``["models", "diffusers/Kolors/text_encoder"]``). Hybrid shapes that
mix the two within a single call (e.g.
``["models", "diffusers", "Kolors/text_encoder"]``) are also
accepted: each entry after ``tags[0]`` is split on ``/`` and
concatenated, so the two shapes — and any mix of them — resolve to
the same destination. The same safety checks are applied to each
component after expansion.
"""
if not tags:
raise ValueError("tags must not be empty")
root = tags[0].lower()
# Expand any slash-joined entries into individual path components so
# the rest of the function can treat both tag shapes uniformly. Each
# component is also stripped, so " a / b " behaves like ["a", "b"].
expanded: list[str] = []
for t in tags[1:]:
for part in str(t).split("/"):
part = part.strip()
if part:
expanded.append(part)
if root == "models":
if len(tags) < 2:
if not expanded:
raise ValueError("at least two tags required for model asset")
category = expanded[0]
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
bases = folder_paths.folder_names_and_paths[category][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
raise ValueError(f"unknown model category '{category}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
raise ValueError(f"no base path configured for category '{category}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
raw_subdirs = expanded[1:]
elif root == "input":
base_dir = os.path.abspath(folder_paths.get_input_directory())
raw_subdirs = tags[1:]
raw_subdirs = expanded
elif root == "output":
base_dir = os.path.abspath(folder_paths.get_output_directory())
raw_subdirs = tags[1:]
raw_subdirs = expanded
else:
raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'")
_sep_chars = frozenset(("/", "\\", os.sep))
@ -160,7 +183,21 @@ def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return (name, tags) derived from a filesystem path.
- name: base filename with extension
- tags: [root_category] + parent folder names in order
- tags: [root_category] for paths with no parent subdirectories,
[root_category, slash_joined_subpath] otherwise. The parent subpath
(everything between the root category and the filename) is collapsed
into a single tag rather than emitted as one tag per directory, so
consumers can use ``tags[1]`` as a stable category identifier that
survives nested directory layouts (e.g. diffusers components).
The subpath is lowercased to match the canonicalization applied by
:func:`ensure_tags_exist`; without that, the
``asset_reference_tags.tag_name`` FK to the lowercased ``tags.name``
would fail for any path containing uppercase letters. The root
category is lowercase by construction in
:func:`get_asset_category_and_relative_path`, so no separate cast
is applied here. Consumers that need to look up providers keyed on
original-case paths should normalize their lookup key to lowercase.
Raises:
ValueError: path does not belong to any known root.
@ -170,4 +207,7 @@ def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
parent_parts = [
part for part in p.parent.parts if part not in (".", "..", p.anchor)
]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
tags = [root_category]
if parent_parts:
tags.append("/".join(parent_parts).lower())
return p.name, list(dict.fromkeys(t.strip() for t in tags if t.strip()))

View File

@ -62,8 +62,6 @@ def get_comfy_package_versions():
def check_comfy_packages_versions():
"""Warn for every comfy* package whose installed version is below requirements.txt."""
from packaging.version import InvalidVersion, parse as parse_pep440
outdated_packages = []
for pkg in get_comfy_package_versions():
installed_str = pkg["installed"]
required_str = pkg["required"]
@ -75,26 +73,19 @@ def check_comfy_packages_versions():
logging.error(f"Failed to check {pkg['name']} version: {e}")
continue
if outdated:
outdated_packages.append((pkg["name"], installed_str, required_str))
else:
logging.info("{} version: {}".format(pkg["name"], installed_str))
if outdated_packages:
package_warnings = "\n".join(
f"Installed {name} version {installed} is lower than the recommended version {required}."
for name, installed, required in outdated_packages
)
app.logger.log_startup_warning(
f"""
app.logger.log_startup_warning(
f"""
________________________________________________________________________
WARNING WARNING WARNING WARNING WARNING
{package_warnings}
Installed {pkg["name"]} version {installed_str} is lower than the recommended version {required_str}.
{get_missing_requirements_message()}
________________________________________________________________________
""".strip()
)
)
else:
logging.info("{} version: {}".format(pkg["name"], installed_str))
REQUEST_TIMEOUT = 10 # seconds

View File

@ -1217,7 +1217,7 @@ def get_aimdo_cast_buffer(offload_stream, device):
def get_pin_buffer(offload_stream):
pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None)
if pin_buffer is None:
pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0, 0, pinned_hostbuf_size(8 * 1024**3), mark_cold=False)
pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0, 0, pinned_hostbuf_size(8 * 1024**3))
STREAM_PIN_BUFFERS[offload_stream] = pin_buffer
elif offload_stream is not None:
event = getattr(pin_buffer, "_comfy_event", None)

View File

@ -1613,16 +1613,6 @@ class ModelPatcherDynamic(ModelPatcher):
#use all ModelPatcherDynamic this is ignored and its all done dynamically.
return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3)
def restore_loaded_backups(self):
restored = self.model.model_loaded_weight_memory
for key in list(self.backup.keys()):
bk = self.backup.pop(key)
comfy.utils.set_attr_param(self.model, key, bk.weight)
for key in list(self.backup_buffers.keys()):
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
self.model.model_loaded_weight_memory = 0
return restored
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False):
@ -1639,7 +1629,7 @@ class ModelPatcherDynamic(ModelPatcher):
num_patches = 0
allocated_size = 0
self.restore_loaded_backups()
self.model.model_loaded_weight_memory = 0
with self.use_ejected():
self.unpatch_hooks()
@ -1726,9 +1716,6 @@ class ModelPatcherDynamic(ModelPatcher):
force_load=True
if force_load:
if hasattr(m, "_v"):
comfy_aimdo.model_vbar.vbar_unpin(m._v)
delattr(m, "_v")
force_load_param(self, "weight", device_to)
force_load_param(self, "bias", device_to)
else:
@ -1786,7 +1773,13 @@ class ModelPatcherDynamic(ModelPatcher):
freed = 0 if vbar is None else vbar.free_memory(memory_to_free)
if freed < memory_to_free:
freed += self.restore_loaded_backups()
for key in list(self.backup.keys()):
bk = self.backup.pop(key)
comfy.utils.set_attr_param(self.model, key, bk.weight)
for key in list(self.backup_buffers.keys()):
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
freed += self.model.model_loaded_weight_memory
self.model.model_loaded_weight_memory = 0
return freed

View File

@ -265,6 +265,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
cond_shapes = collections.defaultdict(list)
for tt in batch_amount:
cond = {k: v.size() for k, v in to_run[tt][0].conditioning.items()}
for k, v in to_run[tt][0].conditioning.items():
cond_shapes[k].append(v.size())

View File

@ -1019,11 +1019,10 @@ def bislerp(samples, width, height):
def lanczos(samples, width, height):
#the below API is strict and expects grayscale to be squeezed
if samples.ndim == 4:
samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1)
samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1)
images = [Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
images = [torch.from_numpy(t).movedim(-1, 0) if (t := np.array(image).astype(np.float32) / 255.0).ndim == 3 else torch.from_numpy(t) for image in images]
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
result = torch.stack(images)
return result.to(samples.device, samples.dtype)

View File

@ -35,19 +35,6 @@ class AnthropicMessage(BaseModel):
content: list[AnthropicTextContent | AnthropicImageContent] = Field(...)
class AnthropicThinkingConfig(BaseModel):
type: Literal["enabled", "disabled", "adaptive"] = Field(...)
budget_tokens: int | None = Field(
None, ge=1024,
description="Reasoning budget in tokens. Used when type is 'enabled'. Must be less than max_tokens.",
)
class AnthropicOutputConfig(BaseModel):
"""Used with `thinking.type='adaptive'` on models like Opus 4.7."""
effort: Literal["low", "medium", "high"] | None = Field(None)
class AnthropicMessagesRequest(BaseModel):
model: str = Field(...)
messages: list[AnthropicMessage] = Field(...)
@ -57,8 +44,6 @@ class AnthropicMessagesRequest(BaseModel):
top_p: float | None = Field(None, ge=0.0, le=1.0)
top_k: int | None = Field(None, ge=0)
stop_sequences: list[str] | None = Field(None)
thinking: AnthropicThinkingConfig | None = Field(None)
output_config: AnthropicOutputConfig | None = Field(None)
class AnthropicResponseTextBlock(BaseModel):
@ -66,14 +51,6 @@ class AnthropicResponseTextBlock(BaseModel):
text: str = Field(...)
class AnthropicResponseThinkingBlock(BaseModel):
type: Literal["thinking"] = "thinking"
thinking: str = Field(...)
AnthropicResponseBlock = AnthropicResponseTextBlock | AnthropicResponseThinkingBlock
class AnthropicCacheCreationUsage(BaseModel):
ephemeral_5m_input_tokens: int | None = Field(None)
ephemeral_1h_input_tokens: int | None = Field(None)
@ -92,7 +69,7 @@ class AnthropicMessagesResponse(BaseModel):
type: str | None = Field(None)
role: str | None = Field(None)
model: str | None = Field(None)
content: list[AnthropicResponseBlock] | None = Field(None)
content: list[AnthropicResponseTextBlock] | None = Field(None)
stop_reason: str | None = Field(None)
stop_sequence: str | None = Field(None)
usage: AnthropicMessagesUsage | None = Field(None)

View File

@ -1,93 +0,0 @@
"""Pydantic models for the OpenRouter chat completions API.
See: https://openrouter.ai/docs/api/api-reference/chat/send-chat-completion-request
"""
from typing import Literal
from pydantic import BaseModel, Field
class OpenRouterTextContent(BaseModel):
type: Literal["text"] = "text"
text: str = Field(...)
class OpenRouterImageUrl(BaseModel):
url: str = Field(...)
class OpenRouterImageContent(BaseModel):
type: Literal["image_url"] = "image_url"
image_url: OpenRouterImageUrl = Field(...)
class OpenRouterVideoUrl(BaseModel):
url: str = Field(...)
class OpenRouterVideoContent(BaseModel):
type: Literal["video_url"] = "video_url"
video_url: OpenRouterVideoUrl = Field(...)
OpenRouterContentBlock = OpenRouterTextContent | OpenRouterImageContent | OpenRouterVideoContent
class OpenRouterMessage(BaseModel):
role: Literal["system", "user", "assistant"] = Field(...)
content: str | list[OpenRouterContentBlock] = Field(...)
class OpenRouterReasoningConfig(BaseModel):
effort: str | None = Field(None)
exclude: bool | None = Field(None, description="If true, model reasons but reasoning is excluded from response.")
class OpenRouterWebSearchOptions(BaseModel):
search_context_size: str | None = Field(None)
class OpenRouterChatRequest(BaseModel):
model: str = Field(...)
messages: list[OpenRouterMessage] = Field(...)
seed: int | None = Field(None)
reasoning: OpenRouterReasoningConfig | None = Field(None)
web_search_options: OpenRouterWebSearchOptions | None = Field(None)
stream: bool = Field(False)
class OpenRouterUsage(BaseModel):
prompt_tokens: int | None = Field(None)
completion_tokens: int | None = Field(None)
total_tokens: int | None = Field(None)
cost: float | None = Field(None, description="Server-side authoritative USD cost of the call.")
class OpenRouterResponseMessage(BaseModel):
role: str | None = Field(None)
content: str | None = Field(None)
reasoning: str | None = Field(None)
refusal: str | None = Field(None)
class OpenRouterChoice(BaseModel):
index: int | None = Field(None)
message: OpenRouterResponseMessage | None = Field(None)
finish_reason: str | None = Field(None)
class OpenRouterError(BaseModel):
code: int | str | None = Field(None)
message: str | None = Field(None)
metadata: dict | None = Field(None)
class OpenRouterChatResponse(BaseModel):
id: str | None = Field(None)
model: str | None = Field(None)
object: str | None = Field(None)
provider: str | None = Field(None)
choices: list[OpenRouterChoice] | None = Field(None)
usage: OpenRouterUsage | None = Field(None)
error: OpenRouterError | None = Field(None)

View File

@ -1,5 +1,7 @@
from enum import Enum
from __future__ import annotations
from enum import Enum
from typing import Optional, List
from pydantic import BaseModel, Field
@ -9,76 +11,44 @@ class Rodin3DGenerateRequest(BaseModel):
material: str = Field(..., description="The material type.")
quality_override: int = Field(..., description="The poly count of the mesh.")
mesh_mode: str = Field(..., description="It controls the type of faces of generated models.")
TAPose: bool | None = Field(None, description="")
class Rodin3DGen25Request(BaseModel):
tier: str = Field(..., description="Gen-2.5 tier (e.g. Gen-2.5-High).")
prompt: str | None = Field(None, description="Required for Text-to-3D; ignored otherwise.")
seed: int | None = Field(None, description="0-65535.")
material: str | None = Field(None, description="PBR | Shaded | All | None.")
geometry_file_format: str | None = Field(None, description="glb | usdz | fbx | obj | stl.")
texture_mode: str | None = Field(None, description="legacy | extreme-low | low | medium | high.")
mesh_mode: str | None = Field(None, description="Raw (triangular) | Quad.")
quality_override: int | None = Field(None, description="Mesh face count override.")
geometry_instruct_mode: str | None = Field(None, description="faithful | creative.")
bbox_condition: list[int] | None = Field(None, description="Bounding box [Width(Y), Height(Z), Length(X)] in cm.")
height: int | None = Field(None, description="Approximate model height in cm.")
TAPose: bool | None = Field(None, description="T/A pose for human-like models.")
hd_texture: bool | None = Field(None, description="Enhanced texture quality.")
texture_delight: bool | None = Field(None, description="Remove baked lighting from textures.")
is_micro: bool | None = Field(None, description="Micro detail (Extreme-High only).")
use_original_alpha: bool | None = Field(None, description="Preserve image transparency.")
preview_render: bool | None = Field(None, description="Generate high-quality preview render.")
addons: list[str] | None = Field(None, description='Optional addons, e.g. ["HighPack"].')
TAPose: Optional[bool] = Field(None, description="")
class GenerateJobsData(BaseModel):
uuids: list[str] = Field(..., description="str LIST")
uuids: List[str] = Field(..., description="str LIST")
subscription_key: str = Field(..., description="subscription key")
class Rodin3DGenerateResponse(BaseModel):
message: str | None = Field(None, description="Return message.")
prompt: str | None = Field(None, description="Generated Prompt from image.")
submit_time: str | None = Field(None, description="Submit Time")
uuid: str | None = Field(None, description="Task str")
jobs: GenerateJobsData | None = Field(None, description="Details of jobs")
message: Optional[str] = Field(None, description="Return message.")
prompt: Optional[str] = Field(None, description="Generated Prompt from image.")
submit_time: Optional[str] = Field(None, description="Submit Time")
uuid: Optional[str] = Field(None, description="Task str")
jobs: Optional[GenerateJobsData] = Field(None, description="Details of jobs")
class JobStatus(str, Enum):
"""
Status for jobs
"""
Done = "Done"
Failed = "Failed"
Generating = "Generating"
Waiting = "Waiting"
class Rodin3DCheckStatusRequest(BaseModel):
subscription_key: str = Field(..., description="subscription from generate endpoint")
class JobItem(BaseModel):
uuid: str = Field(..., description="uuid")
status: JobStatus = Field(..., description="Status Currently")
status: JobStatus = Field(...,description="Status Currently")
class Rodin3DCheckStatusResponse(BaseModel):
jobs: list[JobItem] = Field(..., description="Job status List")
jobs: List[JobItem] = Field(..., description="Job status List")
class Rodin3DDownloadRequest(BaseModel):
task_uuid: str = Field(..., description="Task str")
class RodinResourceItem(BaseModel):
url: str = Field(..., description="Download Url")
name: str = Field(..., description="File name with ext")
class Rodin3DDownloadResponse(BaseModel):
items: list[RodinResourceItem] = Field(..., alias="list", description="Source List")
list: List[RodinResourceItem] = Field(..., description="Source List")

View File

@ -9,11 +9,8 @@ from comfy_api_nodes.apis.anthropic import (
AnthropicMessage,
AnthropicMessagesRequest,
AnthropicMessagesResponse,
AnthropicOutputConfig,
AnthropicResponseTextBlock,
AnthropicRole,
AnthropicTextContent,
AnthropicThinkingConfig,
)
from comfy_api_nodes.util import (
ApiEndpoint,
@ -35,29 +32,15 @@ CLAUDE_MODELS: dict[str, str] = {
"Haiku 4.5": "claude-haiku-4-5-20251001",
}
_THINKING_UNSUPPORTED = {"Haiku 4.5"}
# Models that use the newer "adaptive" thinking mode (Opus 4.7 requires it; older models keep the explicit budget API).
# Anthropic decides the actual budget when adaptive is used, based on the `output_config.effort` hint.
_ADAPTIVE_THINKING_MODELS = {"Opus 4.7", "Opus 4.6", "Sonnet 4.6"}
# Budget mode (Sonnet 4.5): effort -> reasoning budget in tokens. Must be < max_tokens.
# Sized so even the "high" budget fits comfortably under the default max_tokens=32768.
_REASONING_BUDGET: dict[str, int] = {
"low": 2048,
"medium": 8192,
"high": 16384,
}
_REASONING_EFFORTS = ["off", "low", "medium", "high"]
def _claude_model_inputs(model_label: str):
inputs: list = [
def _claude_model_inputs():
return [
IO.Int.Input(
"max_tokens",
default=32768,
min=4096,
max=64000,
tooltip="Maximum number of tokens to generate (includes reasoning tokens when enabled).",
default=16000,
min=32,
max=32000,
tooltip="Maximum number of tokens to generate before stopping.",
advanced=True,
),
IO.Float.Input(
@ -66,24 +49,10 @@ def _claude_model_inputs(model_label: str):
min=0.0,
max=1.0,
step=0.01,
tooltip=(
"Controls randomness. 0.0 is deterministic, 1.0 is most random. "
"Ignored for Opus 4.7 and any model when reasoning_effort is set."
),
tooltip="Controls randomness. 0.0 is deterministic, 1.0 is most random. Ignored for Opus 4.7.",
advanced=True,
),
]
if model_label not in _THINKING_UNSUPPORTED:
inputs.append(
IO.Combo.Input(
"reasoning_effort",
options=_REASONING_EFFORTS,
default="off",
tooltip="Extended thinking effort. 'off' disables reasoning.",
advanced=True,
)
)
return inputs
def _model_price_per_million(model: str) -> tuple[float, float] | None:
@ -126,11 +95,7 @@ def calculate_tokens_price(response: AnthropicMessagesResponse) -> float | None:
def _get_text_from_response(response: AnthropicMessagesResponse) -> str:
if not response.content:
return ""
# Thinking blocks are silently dropped — we never want reasoning in the output.
return "\n".join(
block.text for block in response.content
if isinstance(block, AnthropicResponseTextBlock) and block.text
)
return "\n".join(block.text for block in response.content if block.text)
async def _build_image_content_blocks(
@ -168,10 +133,7 @@ class ClaudeNode(IO.ComfyNode):
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(label, _claude_model_inputs(label))
for label in CLAUDE_MODELS
],
options=[IO.DynamicCombo.Option(label, _claude_model_inputs()) for label in CLAUDE_MODELS],
tooltip="The Claude model used to generate the response.",
),
IO.Int.Input(
@ -245,29 +207,8 @@ class ClaudeNode(IO.ComfyNode):
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
model_label = model["model"]
max_tokens = model.get("max_tokens", 32768)
reasoning_effort = model.get("reasoning_effort", "off")
thinking_enabled = reasoning_effort not in ("off", None) and model_label not in _THINKING_UNSUPPORTED
# Anthropic requires temperature to be unset (defaults to 1.0) when thinking is enabled.
# Opus 4.7 also rejects user-supplied temperature.
if thinking_enabled or model_label == "Opus 4.7":
temperature = None
else:
temperature = model.get("temperature", 1.0)
thinking_cfg: AnthropicThinkingConfig | None = None
output_cfg: AnthropicOutputConfig | None = None
if thinking_enabled:
if model_label in _ADAPTIVE_THINKING_MODELS:
# Adaptive mode - Anthropic chooses the budget based on effort hint
thinking_cfg = AnthropicThinkingConfig(type="adaptive")
output_cfg = AnthropicOutputConfig(effort=reasoning_effort)
else:
# Budget mode (Sonnet 4.5). Leave at least 1024 tokens for the actual response
budget = _REASONING_BUDGET[reasoning_effort]
budget = min(budget, max(1024, max_tokens - 1024))
thinking_cfg = AnthropicThinkingConfig(type="enabled", budget_tokens=budget)
max_tokens = model["max_tokens"]
temperature = None if model_label == "Opus 4.7" else model["temperature"]
image_tensors: list[Input.Image] = [t for t in (images or {}).values() if t is not None]
if sum(get_number_of_images(t) for t in image_tensors) > CLAUDE_MAX_IMAGES:
@ -288,8 +229,6 @@ class ClaudeNode(IO.ComfyNode):
messages=[AnthropicMessage(role=AnthropicRole.user, content=content)],
system=system_prompt or None,
temperature=temperature,
thinking=thinking_cfg,
output_config=output_cfg,
),
price_extractor=calculate_tokens_price,
)

View File

@ -43,16 +43,15 @@ from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_image_tensor,
download_url_to_video_output,
downscale_video_to_max_pixels,
get_number_of_images,
image_tensor_pair_to_batch,
poll_op,
resize_video_to_pixel_budget,
sync_op,
upload_audio_to_comfyapi,
upload_image_to_comfyapi,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
upscale_video_to_min_pixels,
validate_image_aspect_ratio,
validate_image_dimensions,
validate_string,
@ -111,13 +110,12 @@ def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: st
max_px = limits.get("max")
if min_px and pixels < min_px:
raise ValueError(
f"Reference video {index} is too small: {w}x{h} = {pixels:,} total pixels. "
f"Minimum for this model is {min_px:,} total pixels."
f"Reference video {index} is too small: {w}x{h} = {pixels:,}px. " f"Minimum is {min_px:,}px for this model."
)
if max_px and pixels > max_px:
raise ValueError(
f"Reference video {index} is too large: {w}x{h} = {pixels:,} total pixels. "
f"Maximum for this model is {max_px:,} total pixels. Try downscaling the video."
f"Reference video {index} is too large: {w}x{h} = {pixels:,}px. "
f"Maximum is {max_px:,}px for this model. Try downscaling the video."
)
@ -1678,14 +1676,14 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
"first_frame_asset_id",
default="",
tooltip="Seedance asset_id to use as the first frame. "
"Mutually exclusive with the first_frame image input.",
"Mutually exclusive with the first_frame image input.",
optional=True,
),
IO.String.Input(
"last_frame_asset_id",
default="",
tooltip="Seedance asset_id to use as the last frame. "
"Mutually exclusive with the last_frame image input.",
"Mutually exclusive with the last_frame image input.",
optional=True,
),
IO.Int.Input(
@ -1867,20 +1865,11 @@ def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16
IO.Boolean.Input(
"auto_downscale",
default=False,
advanced=True,
optional=True,
tooltip="Automatically downscale reference videos that exceed the model's pixel budget "
"for the selected resolution. Aspect ratio is preserved; videos already within limits are untouched.",
),
IO.Boolean.Input(
"auto_upscale",
default=False,
advanced=True,
optional=True,
tooltip="Automatically upscale reference videos that are below the model's minimum pixel count "
"for the selected resolution. Aspect ratio is preserved; videos already meeting the minimum are "
"untouched. Note: upscaling a low-resolution source does not add real detail and may produce "
"lower-quality generations.",
),
IO.Autogrow.Input(
"reference_assets",
template=IO.Autogrow.TemplateNames(
@ -2041,13 +2030,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
max_px = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {}).get(model["resolution"], {}).get("max")
if max_px:
for key in reference_videos:
reference_videos[key] = downscale_video_to_max_pixels(reference_videos[key], max_px)
if model.get("auto_upscale") and reference_videos:
min_px = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {}).get(model["resolution"], {}).get("min")
if min_px:
for key in reference_videos:
reference_videos[key] = upscale_video_to_min_pixels(reference_videos[key], min_px)
reference_videos[key] = resize_video_to_pixel_budget(reference_videos[key], max_px)
total_video_duration = 0.0
for i, key in enumerate(reference_videos, 1):

View File

@ -1,374 +0,0 @@
"""API Nodes for OpenRouter LLM chat completions."""
from dataclasses import dataclass
from typing import Literal
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.openrouter import (
OpenRouterChatRequest,
OpenRouterChatResponse,
OpenRouterContentBlock,
OpenRouterImageContent,
OpenRouterImageUrl,
OpenRouterMessage,
OpenRouterReasoningConfig,
OpenRouterTextContent,
OpenRouterVideoContent,
OpenRouterVideoUrl,
OpenRouterWebSearchOptions,
)
from comfy_api_nodes.util import (
ApiEndpoint,
get_number_of_images,
sync_op,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_string,
)
OPENROUTER_CHAT_ENDPOINT = "/proxy/openrouter/api/v1/chat/completions"
Profile = Literal["standard", "reasoning", "frontier_reasoning", "perplexity", "perplexity_reasoning"]
@dataclass(frozen=True)
class _ModelSpec:
slug: str # exact OpenRouter model id
profile: Profile
price_in: float # USD per token (prompt)
price_out: float # USD per token (completion)
max_images: int = 0 # 0 = no image input; otherwise max URL-passed images supported
max_videos: int = 0 # 0 = no video input; otherwise max URL-passed videos supported
MODELS: list[_ModelSpec] = [
_ModelSpec("anthropic/claude-opus-4.7", "frontier_reasoning", 0.000005, 0.000025, max_images=20),
_ModelSpec("openai/gpt-5.5-pro", "frontier_reasoning", 0.00003, 0.00018, max_images=20),
_ModelSpec("openai/gpt-5.5", "frontier_reasoning", 0.000005, 0.00003, max_images=20),
_ModelSpec("google/gemini-3.5-flash", "reasoning", 0.0000015, 0.000009, max_images=20, max_videos=4),
_ModelSpec("x-ai/grok-4.20", "reasoning", 0.00000125, 0.0000025, max_images=20),
_ModelSpec("x-ai/grok-4.3", "reasoning", 0.00000125, 0.0000025, max_images=20),
_ModelSpec("deepseek/deepseek-v4-pro", "reasoning", 0.000000435, 0.00000087),
_ModelSpec("deepseek/deepseek-v4-flash", "reasoning", 0.000000112, 0.000000224),
_ModelSpec("deepseek/deepseek-v3.2", "reasoning", 0.000000252, 0.000000378),
_ModelSpec("qwen/qwen3.6-max-preview", "reasoning", 0.00000104, 0.00000624),
_ModelSpec("qwen/qwen3.6-plus", "reasoning", 0.000000325, 0.00000195, max_images=10, max_videos=4),
_ModelSpec("qwen/qwen3.6-flash", "reasoning", 0.0000001875, 0.000001125, max_images=10, max_videos=4),
_ModelSpec("mistralai/mistral-large-2512", "standard", 0.0000005, 0.0000015, max_images=8),
_ModelSpec("mistralai/mistral-medium-3-5", "reasoning", 0.0000015, 0.0000075, max_images=8),
_ModelSpec("z-ai/glm-4.6", "reasoning", 0.00000043, 0.00000174),
_ModelSpec("z-ai/glm-5", "reasoning", 0.0000006, 0.00000192),
_ModelSpec("moonshotai/kimi-k2.6", "reasoning", 0.00000073, 0.00000349, max_images=10),
_ModelSpec("moonshotai/kimi-k2-thinking", "reasoning", 0.0000006, 0.0000025),
_ModelSpec("perplexity/sonar-pro", "perplexity", 0.000003, 0.000015),
_ModelSpec("perplexity/sonar-reasoning-pro", "perplexity_reasoning", 0.000002, 0.000008),
_ModelSpec("perplexity/sonar-deep-research", "perplexity_reasoning", 0.000002, 0.000008),
]
_MODELS_BY_SLUG: dict[str, _ModelSpec] = {m.slug: m for m in MODELS}
_REASONING_EFFORTS = ["off", "low", "medium", "high"]
_SEARCH_CONTEXT_SIZES = ["low", "medium", "high"]
def _reasoning_extra_inputs() -> list:
return [
IO.Combo.Input(
"reasoning_effort",
options=_REASONING_EFFORTS,
default="off",
tooltip="Reasoning effort. 'off' disables reasoning entirely.",
advanced=True,
),
]
def _perplexity_extra_inputs() -> list:
return [
IO.Combo.Input(
"search_context_size",
options=_SEARCH_CONTEXT_SIZES,
default="medium",
tooltip="How much web search context to retrieve. Larger = more grounded but slower/pricier.",
advanced=True,
),
]
def _profile_inputs(profile: Profile) -> list:
if profile == "standard":
return []
if profile in ("reasoning", "frontier_reasoning"):
return _reasoning_extra_inputs()
if profile == "perplexity":
return _perplexity_extra_inputs()
if profile == "perplexity_reasoning":
return _perplexity_extra_inputs() + _reasoning_extra_inputs()
raise ValueError(f"Unknown profile: {profile}")
def _media_inputs(spec: _ModelSpec) -> list:
extras: list = []
if spec.max_images > 0:
extras.append(
IO.Autogrow.Input(
"images",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("image"),
names=[f"image_{i}" for i in range(1, spec.max_images + 1)],
min=0,
),
tooltip=f"Optional reference image(s) — up to {spec.max_images}. Sent as URLs.",
)
)
if spec.max_videos > 0:
extras.append(
IO.Autogrow.Input(
"videos",
template=IO.Autogrow.TemplateNames(
IO.Video.Input("video"),
names=[f"video_{i}" for i in range(1, spec.max_videos + 1)],
min=0,
),
tooltip=f"Optional reference video(s) — up to {spec.max_videos}. Sent as URLs.",
)
)
return extras
def _inputs_for_model(spec: _ModelSpec) -> list:
return _profile_inputs(spec.profile) + _media_inputs(spec)
def _build_model_options() -> list[IO.DynamicCombo.Option]:
return [IO.DynamicCombo.Option(spec.slug, _inputs_for_model(spec)) for spec in MODELS]
def _calculate_price(response: OpenRouterChatResponse) -> float | None:
if response.usage and response.usage.cost is not None:
return float(response.usage.cost)
return None
def _price_badge_jsonata() -> str:
rates_pairs = []
for spec in MODELS:
prompt_per_1k = spec.price_in * 1000
completion_per_1k = spec.price_out * 1000
rates_pairs.append(f' "{spec.slug}": [{prompt_per_1k:.8g}, {completion_per_1k:.8g}]')
rates_block = ",\n".join(rates_pairs)
return (
"(\n"
" $rates := {\n"
f"{rates_block}\n"
" };\n"
" $r := $lookup($rates, widgets.model);\n"
" $r ? {\n"
' "type": "list_usd",\n'
' "usd": $r,\n'
' "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }\n'
' } : {"type": "text", "text": "Token-based"}\n'
")"
)
async def _build_image_blocks(
cls: type[IO.ComfyNode], spec: _ModelSpec, images: list[Input.Image]
) -> list[OpenRouterImageContent]:
urls = await upload_images_to_comfyapi(
cls,
images,
max_images=spec.max_images,
total_pixels=2048 * 2048,
mime_type="image/png",
wait_label="Uploading reference images",
)
return [OpenRouterImageContent(image_url=OpenRouterImageUrl(url=url)) for url in urls]
async def _build_video_blocks(cls: type[IO.ComfyNode], videos: list[Input.Video]) -> list[OpenRouterVideoContent]:
blocks: list[OpenRouterVideoContent] = []
total = len(videos)
for idx, video in enumerate(videos):
label = "Uploading reference video"
if total > 1:
label = f"{label} ({idx + 1}/{total})"
url = await upload_video_to_comfyapi(cls, video, wait_label=label)
blocks.append(OpenRouterVideoContent(video_url=OpenRouterVideoUrl(url=url)))
return blocks
def _user_message(prompt: str, media_blocks: list[OpenRouterContentBlock]) -> OpenRouterMessage:
if not media_blocks:
return OpenRouterMessage(role="user", content=prompt)
blocks: list[OpenRouterContentBlock] = list(media_blocks)
blocks.append(OpenRouterTextContent(text=prompt))
return OpenRouterMessage(role="user", content=blocks)
def _build_messages(
system_prompt: str, prompt: str, media_blocks: list[OpenRouterContentBlock]
) -> list[OpenRouterMessage]:
messages: list[OpenRouterMessage] = []
if system_prompt:
messages.append(OpenRouterMessage(role="system", content=system_prompt))
messages.append(_user_message(prompt, media_blocks))
return messages
def _build_request(
slug: str,
system_prompt: str,
prompt: str,
media_blocks: list[OpenRouterContentBlock],
*,
seed: int,
reasoning_effort: str | None,
search_context_size: str | None,
) -> OpenRouterChatRequest:
reasoning_cfg: OpenRouterReasoningConfig | None = None
if reasoning_effort and reasoning_effort != "off":
# exclude=True asks providers to reason internally but not return the trace
reasoning_cfg = OpenRouterReasoningConfig(effort=reasoning_effort, exclude=True)
web_search_cfg: OpenRouterWebSearchOptions | None = None
if search_context_size:
web_search_cfg = OpenRouterWebSearchOptions(search_context_size=search_context_size)
return OpenRouterChatRequest(
model=slug,
messages=_build_messages(system_prompt, prompt, media_blocks),
seed=seed if seed > 0 else None,
reasoning=reasoning_cfg,
web_search_options=web_search_cfg,
)
def _extract_text(response: OpenRouterChatResponse) -> str:
if response.error:
code = response.error.code if response.error.code is not None else "unknown"
raise ValueError(f"OpenRouter error ({code}): {response.error.message or 'no message'}")
if not response.choices:
raise ValueError("Empty response from OpenRouter (no choices).")
message = response.choices[0].message
if not message:
raise ValueError("Empty response from OpenRouter (no message).")
if message.refusal:
raise ValueError(f"Model refused to respond: {message.refusal}")
return message.content or ""
class OpenRouterLLMNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="OpenRouterLLMNode",
display_name="OpenRouter LLM",
category="api node/text/OpenRouter",
essentials_category="Text Generation",
description=(
"Generate text responses through OpenRouter. Routes to a curated set of popular "
"models from xAI, DeepSeek, Qwen, Mistral, Z.AI (GLM), Moonshot (Kimi), and "
"Perplexity Sonar."
),
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text input to the model.",
),
IO.DynamicCombo.Input(
"model",
options=_build_model_options(),
tooltip="The OpenRouter model used to generate the response.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for sampling. Set to 0 to omit. Most models treat this as a hint only.",
),
IO.String.Input(
"system_prompt",
multiline=True,
default="",
optional=True,
advanced=True,
tooltip="Foundational instructions that dictate the model's behavior.",
),
],
outputs=[IO.String.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr=_price_badge_jsonata(),
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
slug: str = model["model"]
spec = _MODELS_BY_SLUG.get(slug)
if spec is None:
raise ValueError(f"Unknown OpenRouter model: {slug}")
reasoning_effort: str | None = model.get("reasoning_effort")
search_context_size: str | None = model.get("search_context_size")
image_tensors: list[Input.Image] = [t for t in (model.get("images") or {}).values() if t is not None]
if image_tensors and sum(get_number_of_images(t) for t in image_tensors) > spec.max_images:
raise ValueError(f"Up to {spec.max_images} images are supported for {slug}.")
video_inputs: list[Input.Video] = [v for v in (model.get("videos") or {}).values() if v is not None]
if video_inputs and len(video_inputs) > spec.max_videos:
raise ValueError(f"Up to {spec.max_videos} videos are supported for {slug}.")
media_blocks: list[OpenRouterContentBlock] = []
if image_tensors:
media_blocks.extend(await _build_image_blocks(cls, spec, image_tensors))
if video_inputs:
media_blocks.extend(await _build_video_blocks(cls, video_inputs))
request = _build_request(
slug,
system_prompt,
prompt,
media_blocks,
seed=seed,
reasoning_effort=reasoning_effort,
search_context_size=search_context_size,
)
response = await sync_op(
cls,
ApiEndpoint(path=OPENROUTER_CHAT_ENDPOINT, method="POST"),
response_model=OpenRouterChatResponse,
data=request,
price_extractor=_calculate_price,
)
return IO.NodeOutput(_extract_text(response))
class OpenRouterExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [OpenRouterLLMNode]
async def comfy_entrypoint() -> OpenRouterExtension:
return OpenRouterExtension()

View File

@ -5,37 +5,32 @@ Rodin API docs: https://developer.hyper3d.ai/
"""
from inspect import cleandoc
import folder_paths as comfy_paths
import os
import logging
import math
import os
from inspect import cleandoc
from io import BytesIO
from typing import Any
import aiohttp
from PIL import Image
from typing_extensions import override
import folder_paths as comfy_paths
from comfy_api.latest import IO, ComfyExtension, Types
from PIL import Image
from comfy_api_nodes.apis.rodin import (
JobStatus,
Rodin3DGenerateRequest,
Rodin3DGenerateResponse,
Rodin3DCheckStatusRequest,
Rodin3DCheckStatusResponse,
Rodin3DDownloadRequest,
Rodin3DDownloadResponse,
Rodin3DGen25Request,
Rodin3DGenerateRequest,
Rodin3DGenerateResponse,
JobStatus,
)
from comfy_api_nodes.util import (
sync_op,
poll_op,
ApiEndpoint,
download_url_to_bytesio,
download_url_to_file_3d,
poll_op,
sync_op,
validate_string,
)
from comfy_api.latest import ComfyExtension, IO, Types
COMMON_PARAMETERS = [
IO.Int.Input(
@ -56,30 +51,40 @@ COMMON_PARAMETERS = [
]
_QUALITY_MESH_OPTIONS: dict[str, tuple[str, int]] = {
"4K-Quad": ("Quad", 4000),
"8K-Quad": ("Quad", 8000),
"18K-Quad": ("Quad", 18000),
"50K-Quad": ("Quad", 50000),
"200K-Quad": ("Quad", 200000),
"2K-Triangle": ("Raw", 2000),
"20K-Triangle": ("Raw", 20000),
"150K-Triangle": ("Raw", 150000),
"200K-Triangle": ("Raw", 200000),
"500K-Triangle": ("Raw", 500000),
"1M-Triangle": ("Raw", 1000000),
}
def get_quality_mode(poly_count):
polycount = poly_count.split("-")
poly = polycount[1]
count = polycount[0]
if poly == "Triangle":
mesh_mode = "Raw"
elif poly == "Quad":
mesh_mode = "Quad"
else:
mesh_mode = "Quad"
if count == "4K":
quality_override = 4000
elif count == "8K":
quality_override = 8000
elif count == "18K":
quality_override = 18000
elif count == "50K":
quality_override = 50000
elif count == "2K":
quality_override = 2000
elif count == "20K":
quality_override = 20000
elif count == "150K":
quality_override = 150000
elif count == "500K":
quality_override = 500000
else:
quality_override = 18000
return mesh_mode, quality_override
def get_quality_mode(poly_count: str) -> tuple[str, int]:
"""Map a polygon-count preset like '18K-Quad' to (mesh_mode, quality_override).
Falls back to ('Quad', 18000) for unknown labels; legacy parity.
"""
return _QUALITY_MESH_OPTIONS.get(poly_count, ("Quad", 18000))
def tensor_to_filelike(tensor, max_pixels: int = 2048 * 2048):
def tensor_to_filelike(tensor, max_pixels: int = 2048*2048):
"""
Converts a PyTorch tensor to a file-like object.
@ -91,8 +96,8 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048 * 2048):
- io.BytesIO: A file-like object containing the image data.
"""
array = tensor.cpu().numpy()
array = (array * 255).astype("uint8")
image = Image.fromarray(array, "RGB")
array = (array * 255).astype('uint8')
image = Image.fromarray(array, 'RGB')
original_width, original_height = image.size
original_pixels = original_width * original_height
@ -107,7 +112,7 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048 * 2048):
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
img_byte_arr = BytesIO()
image.save(img_byte_arr, format="PNG") # PNG is used for lossless compression
image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression
img_byte_arr.seek(0)
return img_byte_arr
@ -140,9 +145,11 @@ async def create_generate_task(
TAPose=ta_pose,
),
files=[
("images", open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image))
for image in images
if image is not None
(
"images",
open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image)
)
for image in images if image is not None
],
content_type="multipart/form-data",
)
@ -170,7 +177,6 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
return "DONE"
return "Generating"
def extract_progress(response: Rodin3DCheckStatusResponse) -> int | None:
if not response.jobs:
return None
@ -208,7 +214,7 @@ async def download_files(url_list, task_uuid: str) -> tuple[str | None, Types.Fi
model_file_path = None
file_3d = None
for i in url_list.items:
for i in url_list.list:
file_path = os.path.join(save_path, i.name)
if i.name.lower().endswith(".glb"):
model_file_path = os.path.join(result_folder_name, i.name)
@ -483,16 +489,7 @@ class Rodin3D_Gen2(IO.ComfyNode):
IO.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True),
IO.Combo.Input(
"Polygon_count",
options=[
"4K-Quad",
"8K-Quad",
"18K-Quad",
"50K-Quad",
"2K-Triangle",
"20K-Triangle",
"150K-Triangle",
"500K-Triangle",
],
options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"],
default="500K-Triangle",
optional=True,
),
@ -545,566 +542,6 @@ class Rodin3D_Gen2(IO.ComfyNode):
return IO.NodeOutput(model_path, file_3d)
def _rodin_multipart_parser(data: dict[str, Any]) -> aiohttp.FormData:
"""Convert a Rodin request dict to an aiohttp form, fixing bool/list serialization.
Booleans --> "true"/"false". Lists --> one field per element.
"""
form = aiohttp.FormData(default_to_multipart=True)
for key, value in data.items():
if value is None:
continue
if isinstance(value, bool):
form.add_field(key, "true" if value else "false")
elif isinstance(value, list):
for item in value:
form.add_field(key, str(item))
elif isinstance(value, (bytes, bytearray)):
form.add_field(key, value)
else:
form.add_field(key, str(value))
return form
async def _create_gen25_task(
cls: type[IO.ComfyNode],
request: Rodin3DGen25Request,
images: list | None,
) -> tuple[str, str]:
"""Submit a Gen-2.5 generate job; returns (task_uuid, subscription_key)."""
if images is not None and len(images) > 5:
raise ValueError("Rodin Gen-2.5 supports at most 5 input images.")
files = None
if images:
files = [
(
"images",
open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image),
)
for image in images
if image is not None
]
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/rodin/api/v2/rodin", method="POST"),
response_model=Rodin3DGenerateResponse,
data=request,
files=files,
content_type="multipart/form-data",
multipart_parser=_rodin_multipart_parser,
)
if not response.uuid or not response.jobs or not response.jobs.subscription_key:
raise RuntimeError(f"Rodin Gen-2.5 submit failed: message={response.message!r}")
return response.uuid, response.jobs.subscription_key
_PREVIEWABLE_3D_EXTS = {".glb", ".obj", ".fbx", ".stl", ".gltf"}
async def _download_gen25_files(
download_list: Rodin3DDownloadResponse,
task_uuid: str,
geometry_file_format: str,
) -> Types.File3D | None:
"""Download every file in the list; return the File3D matching the chosen format."""
folder_name = f"Rodin3D_Gen25_{task_uuid}"
save_dir = os.path.join(comfy_paths.get_output_directory(), folder_name)
os.makedirs(save_dir, exist_ok=True)
target_ext = f".{geometry_file_format.lower().lstrip('.')}"
file_3d: Types.File3D | None = None
for item in download_list.items:
file_path = os.path.join(save_dir, item.name)
ext = os.path.splitext(item.name.lower())[1]
# Prefer the file matching the user's chosen format; fall back below.
if file_3d is None and ext == target_ext and ext in _PREVIEWABLE_3D_EXTS:
file_3d = await download_url_to_file_3d(item.url, target_ext.lstrip("."))
with open(file_path, "wb") as f:
f.write(file_3d.get_bytes())
continue
await download_url_to_bytesio(item.url, file_path)
# If the chosen format wasn't found, surface any model file we did get.
if file_3d is None:
for item in download_list.items:
ext = os.path.splitext(item.name.lower())[1]
if ext in _PREVIEWABLE_3D_EXTS:
file_3d = await download_url_to_file_3d(item.url, ext.lstrip("."))
break
return file_3d
_MODE_REGULAR = "Regular"
_MODE_FAST = "Fast"
_MODE_EXTREME_HIGH = "Extreme-High"
_REGULAR_POLY_OPTIONS = [
"Default",
"4K-Quad",
"8K-Quad",
"18K-Quad",
"50K-Quad",
"2K-Triangle",
"20K-Triangle",
"150K-Triangle",
"500K-Triangle",
"1M-Triangle",
]
_TEXTURE_MODE_OPTIONS = ["Default", "legacy", "extreme-low", "low", "medium", "high"]
_GEOMETRY_FORMAT_OPTIONS = ["glb", "fbx", "obj", "stl"]
_MATERIAL_OPTIONS = ["PBR", "Shaded", "All", "None"]
def _build_mode_input(name: str = "mode") -> IO.DynamicCombo.Input:
return IO.DynamicCombo.Input(
name,
options=[
IO.DynamicCombo.Option(
_MODE_REGULAR,
[
IO.Combo.Input(
"tier",
options=["Gen-2.5-Low", "Gen-2.5-Medium", "Gen-2.5-High"],
default="Gen-2.5-High",
tooltip="Quality tier. Higher tiers produce higher-fidelity geometry.",
),
IO.Combo.Input(
"polygon_count",
options=_REGULAR_POLY_OPTIONS,
default="Default",
tooltip="Preset face count. 'Default' uses the server's default for the selected tier.",
),
IO.Boolean.Input(
"creative",
default=False,
tooltip="Creative mode (Medium/High only). Enhances generative robustness.",
),
],
),
IO.DynamicCombo.Option(
_MODE_FAST,
[
IO.Combo.Input(
"tier",
options=[
"Gen-2.5-Extreme-Low",
"Gen-2.5-Low",
"Gen-2.5-Medium",
"Gen-2.5-High",
],
default="Gen-2.5-Low",
),
IO.Int.Input(
"mesh_faces",
default=20000,
min=1000,
max=20000,
display_mode=IO.NumberDisplay.number,
tooltip="Mesh face count (1K-20K in Fast mode).",
),
],
),
IO.DynamicCombo.Option(
_MODE_EXTREME_HIGH,
[
IO.Combo.Input("mesh_mode", options=["Raw", "Quad"], default="Raw"),
IO.Int.Input(
"mesh_faces",
default=1000000,
min=20000,
max=2000000,
display_mode=IO.NumberDisplay.number,
tooltip=(
"Mesh face count. Raw mode: 20K-2M. "
"Quad mode: keep under 200K (upstream may reject higher values)."
),
),
IO.Boolean.Input(
"is_micro",
default=False,
tooltip="Enable micro detail (Extreme-High only).",
),
IO.Boolean.Input(
"creative",
default=False,
tooltip="Creative mode. Enhances generative robustness.",
),
],
),
],
tooltip=(
"Generation mode. Regular = balanced. Fast = 1K-20K faces for rapid prototyping. "
"Extreme-High = 20K-2M faces with optional micro details."
),
)
def _build_common_inputs(*, include_image_only: bool) -> list:
inputs: list = [
IO.Combo.Input("material", options=_MATERIAL_OPTIONS, default="Shaded"),
IO.Combo.Input("geometry_file_format", options=_GEOMETRY_FORMAT_OPTIONS, default="glb"),
IO.Combo.Input(
"texture_mode",
options=_TEXTURE_MODE_OPTIONS,
default="Default",
optional=True,
tooltip="Texture quality preset. 'Default' uses the server's default for the selected tier.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=65535,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
optional=True,
),
IO.Boolean.Input(
"TAPose", default=False, optional=True, advanced=True, tooltip="T/A pose for human-like models."
),
IO.Boolean.Input(
"hd_texture", default=False, optional=True, advanced=True, tooltip="High-quality texture enhancement."
),
IO.Boolean.Input(
"texture_delight",
default=False,
optional=True,
advanced=True,
tooltip="Remove baked lighting from textures.",
),
]
if include_image_only:
inputs.append(
IO.Boolean.Input(
"use_original_alpha",
default=False,
optional=True,
advanced=True,
tooltip="Preserve image transparency.",
)
)
inputs.extend(
[
IO.Boolean.Input(
"addon_highpack",
default=False,
optional=True,
advanced=True,
tooltip="HighPack addon: 4K textures and ~16x faces in Quad mode.",
),
IO.Int.Input(
"bbox_width",
default=0,
min=0,
max=300,
display_mode=IO.NumberDisplay.number,
optional=True,
advanced=True,
tooltip="Bounding-box width (Y axis). Set to 0 with the others to skip bbox.",
),
IO.Int.Input(
"bbox_height",
default=0,
min=0,
max=300,
display_mode=IO.NumberDisplay.number,
optional=True,
advanced=True,
tooltip="Bounding-box height (Z axis).",
),
IO.Int.Input(
"bbox_length",
default=0,
min=0,
max=300,
display_mode=IO.NumberDisplay.number,
optional=True,
advanced=True,
tooltip="Bounding-box length (X axis).",
),
IO.Int.Input(
"height_cm",
default=0,
min=0,
max=10000,
display_mode=IO.NumberDisplay.number,
optional=True,
advanced=True,
tooltip="Approximate model height in centimeters (0 to skip).",
),
]
)
return inputs
_PRICE_EXPR = """
(
$baseCredits := widgets.mode = "extreme-high" ? 1.0 : 0.5;
$addonCredits := widgets.addon_highpack ? 1.0 : 0.0;
$total := ($baseCredits * 1.5) + ($addonCredits * 0.8);
{"type":"usd","usd": $total}
)
"""
def _resolve_mode_params(mode_input: dict) -> dict:
"""Translate the DynamicCombo `mode` payload into Gen-2.5 request fields.
Returns a dict with: tier, quality_override, mesh_mode, geometry_instruct_mode, is_micro.
Missing keys mean "do not send" (so we don't override server defaults).
"""
selected = mode_input["mode"]
out: dict = {}
if selected == _MODE_REGULAR:
out["tier"] = mode_input["tier"]
polygon = mode_input.get("polygon_count", "Default")
if polygon != "Default":
mesh_mode, faces = get_quality_mode(polygon)
out["mesh_mode"] = mesh_mode
out["quality_override"] = faces
if mode_input.get("creative"):
out["geometry_instruct_mode"] = "creative"
elif selected == _MODE_FAST:
out["tier"] = mode_input["tier"]
out["mesh_mode"] = "Raw"
out["quality_override"] = int(mode_input["mesh_faces"])
elif selected == _MODE_EXTREME_HIGH:
out["tier"] = "Gen-2.5-Extreme-High"
out["mesh_mode"] = mode_input["mesh_mode"]
out["quality_override"] = int(mode_input["mesh_faces"])
if mode_input.get("is_micro"):
out["is_micro"] = True
if mode_input.get("creative"):
out["geometry_instruct_mode"] = "creative"
return out
def _build_request(
*,
mode_input: dict,
material: str,
geometry_file_format: str,
texture_mode: str,
seed: int,
TAPose: bool,
hd_texture: bool,
texture_delight: bool,
addon_highpack: bool,
bbox_width: int,
bbox_height: int,
bbox_length: int,
height_cm: int,
prompt: str | None = None,
use_original_alpha: bool = False,
) -> Rodin3DGen25Request:
mode_params = _resolve_mode_params(mode_input)
bbox = None
if bbox_width and bbox_height and bbox_length:
bbox = [bbox_width, bbox_height, bbox_length]
return Rodin3DGen25Request(
tier=mode_params["tier"],
prompt=prompt or None,
seed=seed,
material=material,
geometry_file_format=geometry_file_format,
texture_mode=None if texture_mode == "Default" else texture_mode,
mesh_mode=mode_params.get("mesh_mode"),
quality_override=mode_params.get("quality_override"),
geometry_instruct_mode=mode_params.get("geometry_instruct_mode"),
bbox_condition=bbox,
height=height_cm or None,
TAPose=TAPose or None,
hd_texture=hd_texture or None,
texture_delight=texture_delight or None,
is_micro=mode_params.get("is_micro"),
use_original_alpha=use_original_alpha or None,
addons=["HighPack"] if addon_highpack else None,
)
class Rodin3D_Gen25_Image(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Rodin3D_Gen25_Image",
display_name="Rodin 3D Gen-2.5 - Image to 3D",
category="api node/3d/Rodin",
description=(
"Generate a 3D model from 1-5 reference images via Rodin Gen-2.5. "
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."
),
inputs=[
IO.Autogrow.Input(
"images",
template=IO.Autogrow.TemplatePrefix(IO.Image.Input("image"), prefix="image", min=1, max=5),
tooltip="1-5 images. The first image is used for materials when multi-view.",
),
_build_mode_input(),
*_build_common_inputs(include_image_only=True),
],
outputs=[IO.File3DAny.Output(display_name="model_file")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["mode", "addon_highpack"]),
expr=_PRICE_EXPR,
),
)
@classmethod
async def execute(
cls,
images: IO.Autogrow.Type,
mode: dict,
material: str,
geometry_file_format: str,
texture_mode: str,
seed: int,
TAPose: bool,
hd_texture: bool,
texture_delight: bool,
use_original_alpha: bool,
addon_highpack: bool,
bbox_width: int,
bbox_height: int,
bbox_length: int,
height_cm: int,
) -> IO.NodeOutput:
image_tensors = [img for img in images.values() if img is not None]
if not image_tensors:
raise ValueError("Rodin Gen-2.5 Image-to-3D requires at least one image.")
# Flatten multi-image tensors into individual frames; the API accepts each as a separate part.
flat_images: list = []
for tensor in image_tensors:
if hasattr(tensor, "shape") and len(tensor.shape) == 4:
for i in range(tensor.shape[0]):
flat_images.append(tensor[i])
else:
flat_images.append(tensor)
if len(flat_images) > 5:
raise ValueError(f"Rodin Gen-2.5 accepts at most 5 images; received {len(flat_images)}.")
request = _build_request(
mode_input=mode,
material=material,
geometry_file_format=geometry_file_format,
texture_mode=texture_mode,
seed=seed,
TAPose=TAPose,
hd_texture=hd_texture,
texture_delight=texture_delight,
addon_highpack=addon_highpack,
bbox_width=bbox_width,
bbox_height=bbox_height,
bbox_length=bbox_length,
height_cm=height_cm,
prompt=None,
use_original_alpha=use_original_alpha,
)
task_uuid, subscription_key = await _create_gen25_task(cls, request, flat_images)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
file_3d = await _download_gen25_files(download_list, task_uuid, geometry_file_format)
return IO.NodeOutput(file_3d)
class Rodin3D_Gen25_Text(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Rodin3D_Gen25_Text",
display_name="Rodin 3D Gen-2.5 - Text to 3D",
category="api node/3d/Rodin",
description=(
"Generate a 3D model from a text prompt via Rodin Gen-2.5. "
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."
),
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for the 3D model.",
),
_build_mode_input(),
*_build_common_inputs(include_image_only=False),
],
outputs=[IO.File3DAny.Output(display_name="model_file")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["mode", "addon_highpack"]),
expr=_PRICE_EXPR,
),
)
@classmethod
async def execute(
cls,
prompt: str,
mode: dict,
material: str,
geometry_file_format: str,
texture_mode: str,
seed: int,
TAPose: bool,
hd_texture: bool,
texture_delight: bool,
addon_highpack: bool,
bbox_width: int,
bbox_height: int,
bbox_length: int,
height_cm: int,
) -> IO.NodeOutput:
validate_string(prompt, field_name="prompt", min_length=1, max_length=2500)
request = _build_request(
mode_input=mode,
material=material,
geometry_file_format=geometry_file_format,
texture_mode=texture_mode,
seed=seed,
TAPose=TAPose,
hd_texture=hd_texture,
texture_delight=texture_delight,
addon_highpack=addon_highpack,
bbox_width=bbox_width,
bbox_height=bbox_height,
bbox_length=bbox_length,
height_cm=height_cm,
prompt=prompt,
)
task_uuid, subscription_key = await _create_gen25_task(cls, request, images=None)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
file_3d = await _download_gen25_files(download_list, task_uuid, geometry_file_format)
return IO.NodeOutput(file_3d)
class Rodin3DExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -1114,8 +551,6 @@ class Rodin3DExtension(ComfyExtension):
Rodin3D_Smooth,
Rodin3D_Sketch,
Rodin3D_Gen2,
Rodin3D_Gen25_Image,
Rodin3D_Gen25_Text,
]

View File

@ -16,17 +16,16 @@ from .conversions import (
convert_mask_to_image,
downscale_image_tensor,
downscale_image_tensor_by_max_side,
downscale_video_to_max_pixels,
image_tensor_pair_to_batch,
pil_to_bytesio,
resize_mask_to_image,
resize_video_to_pixel_budget,
tensor_to_base64_string,
tensor_to_bytesio,
tensor_to_pil,
text_filepath_to_base64_string,
text_filepath_to_data_uri,
trim_video,
upscale_video_to_min_pixels,
video_to_base64_string,
)
from .download_helpers import (
@ -89,17 +88,16 @@ __all__ = [
"convert_mask_to_image",
"downscale_image_tensor",
"downscale_image_tensor_by_max_side",
"downscale_video_to_max_pixels",
"image_tensor_pair_to_batch",
"pil_to_bytesio",
"resize_mask_to_image",
"resize_video_to_pixel_budget",
"tensor_to_base64_string",
"tensor_to_bytesio",
"tensor_to_pil",
"text_filepath_to_base64_string",
"text_filepath_to_data_uri",
"trim_video",
"upscale_video_to_min_pixels",
"video_to_base64_string",
# Validation utilities
"get_image_dimensions",

View File

@ -415,48 +415,14 @@ def trim_video(video: Input.Video, duration_sec: float) -> Input.Video:
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
def downscale_video_to_max_pixels(video: Input.Video, max_pixels: int) -> Input.Video:
"""Downscale a video to fit within ``max_pixels`` (w * h), preserving aspect ratio.
def resize_video_to_pixel_budget(video: Input.Video, total_pixels: int) -> Input.Video:
"""Downscale a video to fit within ``total_pixels`` (w * h), preserving aspect ratio.
Returns the original video object untouched when it already fits. Preserves frame rate, duration, and audio.
Aspect ratio is preserved up to a fraction of a percent (even-dim rounding).
"""
src_w, src_h = video.get_dimensions()
scale_dims = _compute_downscale_dims(src_w, src_h, max_pixels)
if scale_dims is None:
return video
return _apply_video_scale(video, scale_dims)
def _compute_upscale_dims(src_w: int, src_h: int, total_pixels: int) -> tuple[int, int] | None:
"""Return upscaled (w, h) with even dims meeting at least ``total_pixels``, or None if already large enough.
Source aspect ratio is preserved; output may drift by a fraction of a percent because both dimensions
are rounded up to even values (many codecs require divisible-by-2). The result is guaranteed to be at
least ``total_pixels``.
"""
pixels = src_w * src_h
if pixels >= total_pixels:
return None
scale = math.sqrt(total_pixels / pixels)
new_w = math.ceil(src_w * scale)
new_h = math.ceil(src_h * scale)
if new_w % 2:
new_w += 1
if new_h % 2:
new_h += 1
return new_w, new_h
def upscale_video_to_min_pixels(video: Input.Video, min_pixels: int) -> Input.Video:
"""Upscale a video to meet at least ``min_pixels`` (w * h), preserving aspect ratio.
Returns the original video object untouched when it already meets the minimum. Preserves frame rate,
duration, and audio. Aspect ratio is preserved up to a fraction of a percent (even-dim rounding).
Note: upscaling a low-resolution source does not add real detail; downstream model quality may suffer.
"""
src_w, src_h = video.get_dimensions()
scale_dims = _compute_upscale_dims(src_w, src_h, min_pixels)
scale_dims = _compute_downscale_dims(src_w, src_h, total_pixels)
if scale_dims is None:
return video
return _apply_video_scale(video, scale_dims)

View File

@ -543,7 +543,7 @@ class AudioConcat(IO.ComfyNode):
return IO.Schema(
node_id="AudioConcat",
search_aliases=["join audio", "combine audio", "append audio"],
display_name="Concatenate Audio",
display_name="Audio Concat",
description="Concatenates the audio1 to audio2 in the specified direction.",
category="audio",
inputs=[
@ -597,7 +597,7 @@ class AudioMerge(IO.ComfyNode):
return IO.Schema(
node_id="AudioMerge",
search_aliases=["mix audio", "overlay audio", "layer audio"],
display_name="Merge Audio",
display_name="Audio Merge",
description="Combine two audio tracks by overlaying their waveforms.",
category="audio",
inputs=[
@ -667,9 +667,8 @@ class AudioAdjustVolume(IO.ComfyNode):
return IO.Schema(
node_id="AudioAdjustVolume",
search_aliases=["audio gain", "loudness", "audio level"],
display_name="Adjust Audio Volume",
display_name="Audio Adjust Volume",
category="audio",
description="Adjust the volume of the audio by a specified amount in decibels (dB).",
inputs=[
IO.Audio.Input("audio"),
IO.Int.Input(

View File

@ -47,10 +47,8 @@ class LoadImageDataSetFromFolderNode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LoadImageDataSetFromFolder",
search_aliases=["load folder", "load from folder", "load dataset", "load images", "import dataset"],
display_name="Load Image (from Folder)",
category="image",
description="Load a dataset of images from a specified folder and return a list of images. Supported formats: PNG, JPG, JPEG, WEBP.",
display_name="Load Image Dataset from Folder",
category="dataset",
is_experimental=True,
inputs=[
io.Combo.Input(
@ -86,16 +84,14 @@ class LoadImageTextDataSetFromFolderNode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LoadImageTextDataSetFromFolder",
search_aliases=["load folder", "load from folder", "load dataset", "load images", "import dataset"],
display_name="Load Image-Text (from Folder)",
category="image",
description="Load a dataset of pairs of images and text captions from a specified folder and return them as a list. Supported formats: PNG, JPG, JPEG, WEBP.",
display_name="Load Image and Text Dataset from Folder",
category="dataset",
is_experimental=True,
inputs=[
io.Combo.Input(
"folder",
options=folder_paths.get_input_subfolders(),
tooltip="The folder to load images and text captions from.",
tooltip="The folder to load images from.",
)
],
outputs=[
@ -210,10 +206,8 @@ class SaveImageDataSetToFolderNode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SaveImageDataSetToFolder",
search_aliases=["save folder", "save to folder", "save dataset", "save images", "export dataset"],
display_name="Save Image (to Folder) (DEPRECATED)",
category="image",
description="Save a dataset of images to a specified folder. Supported formats: PNG.",
display_name="Save Image Dataset to Folder",
category="dataset",
is_experimental=True,
is_output_node=True,
is_input_list=True, # Receive images as list
@ -232,7 +226,6 @@ class SaveImageDataSetToFolderNode(io.ComfyNode):
),
],
outputs=[],
is_deprecated=True, # This node is redundant and superseded by existing Save Image nodes where the target folder can be specified in the filename_prefix
)
@classmethod
@ -253,20 +246,14 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SaveImageTextDataSetToFolder",
search_aliases=["save folder", "save to folder", "save dataset", "save images", "save text", "export dataset"],
display_name="Save Image-Text (to Folder)",
category="image",
description="Save a dataset of pairs of images and text captions to a specified folder. Images are saved as PNG files and captions are saved as TXT files with the same filename_prefix.",
display_name="Save Image and Text Dataset to Folder",
category="dataset",
is_experimental=True,
is_output_node=True,
is_input_list=True, # Receive both images and texts as lists
inputs=[
io.Image.Input("images", tooltip="List of images to save."),
io.String.Input("texts",
optional=True,
force_input=True,
tooltip="List of text captions to save."
),
io.String.Input("texts", tooltip="List of text captions to save."),
io.String.Input(
"folder_name",
default="dataset",
@ -283,7 +270,7 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
)
@classmethod
def execute(cls, images, folder_name, filename_prefix, texts=None):
def execute(cls, images, texts, folder_name, filename_prefix):
# Extract scalar values
folder_name = folder_name[0]
filename_prefix = filename_prefix[0]
@ -292,12 +279,11 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
saved_files = save_images_to_folder(images, output_dir, filename_prefix)
# Save captions
if texts:
for idx, (filename, caption) in enumerate(zip(saved_files, texts)):
caption_filename = filename.replace(".png", ".txt")
caption_path = os.path.join(output_dir, caption_filename)
with open(caption_path, "w", encoding="utf-8") as f:
f.write(caption)
for idx, (filename, caption) in enumerate(zip(saved_files, texts)):
caption_filename = filename.replace(".png", ".txt")
caption_path = os.path.join(output_dir, caption_filename)
with open(caption_path, "w", encoding="utf-8") as f:
f.write(caption)
logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.")
return io.NodeOutput()
@ -328,13 +314,11 @@ class ImageProcessingNode(io.ComfyNode):
Child classes should set:
node_id: Unique node identifier (required)
search_aliases: List of search aliases (optional)
display_name: Display name (optional, defaults to node_id)
description: Node description (optional)
extra_inputs: List of additional io.Input objects beyond "images" (optional)
is_group_process: None (auto-detect), True (group), or False (individual) (optional)
is_output_list: True (list output) or False (single output) (optional, default True)
is_deprecated: True if the node is deprecated (optional, default False)
Child classes must implement ONE of:
_process(cls, image, **kwargs) -> tensor (for single-item processing)
@ -342,13 +326,12 @@ class ImageProcessingNode(io.ComfyNode):
"""
node_id = None
search_aliases = []
display_name = None
description = None
extra_inputs = []
is_group_process = None # None = auto-detect, True/False = explicit
is_output_list = None # None = auto-detect based on processing mode
is_deprecated = False
@classmethod
def _detect_processing_mode(cls):
"""Detect whether this node uses group or individual processing.
@ -419,10 +402,8 @@ class ImageProcessingNode(io.ComfyNode):
return io.Schema(
node_id=cls.node_id,
search_aliases=cls.search_aliases,
display_name=cls.display_name or cls.node_id,
category=cls.category,
description=cls.description,
category="dataset/image",
is_experimental=True,
is_input_list=is_group, # True for group, False for individual
inputs=inputs,
@ -491,13 +472,11 @@ class TextProcessingNode(io.ComfyNode):
Child classes should set:
node_id: Unique node identifier (required)
search_aliases: List of search aliases (optional)
display_name: Display name (optional, defaults to node_id)
description: Node description (optional)
extra_inputs: List of additional io.Input objects beyond "texts" (optional)
is_group_process: None (auto-detect), True (group), or False (individual) (optional)
is_output_list: True (list output) or False (single output) (optional, default True)
is_deprecated: True if the node is deprecated (optional, default False)
Child classes must implement ONE of:
_process(cls, text, **kwargs) -> str (for single-item processing)
@ -505,13 +484,12 @@ class TextProcessingNode(io.ComfyNode):
"""
node_id = None
search_aliases = []
display_name = None
description = None
extra_inputs = []
is_group_process = None # None = auto-detect, True/False = explicit
is_output_list = None # None = auto-detect based on processing mode
is_deprecated = False
@classmethod
def _detect_processing_mode(cls):
"""Detect whether this node uses group or individual processing.
@ -649,17 +627,15 @@ class TextProcessingNode(io.ComfyNode):
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
node_id = "ResizeImagesByShorterEdge"
display_name = "Resize Images by Shorter Edge (DEPRECATED)"
category = "image/transform"
description = "Resize images so that the shorter edge matches the specified dimension while preserving aspect ratio."
is_deprecated = True # This node is superseded by Resize Image/Mask with resize_type = scale shorter dimension
display_name = "Resize Images by Shorter Edge"
description = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio."
extra_inputs = [
io.Int.Input(
"shorter_edge",
default=512,
min=1,
max=8192,
tooltip="Target dimension for the shorter edge.",
tooltip="Target length for the shorter edge.",
),
]
@ -679,17 +655,15 @@ class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
node_id = "ResizeImagesByLongerEdge"
display_name = "Resize Images by Longer Edge (DEPRECATED)"
category = "image/transform"
description = "Resize images so that the longer edge matches the specified dimension while preserving aspect ratio."
is_deprecated = True # This node is superseded by Resize Image/Mask with resize_type = scale longer dimension
display_name = "Resize Images by Longer Edge"
description = "Resize images so that the longer edge matches the specified length while preserving aspect ratio."
extra_inputs = [
io.Int.Input(
"longer_edge",
default=1024,
min=1,
max=8192,
tooltip="Target dimension for the longer edge.",
tooltip="Target length for the longer edge.",
),
]
@ -712,10 +686,8 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
class CenterCropImagesNode(ImageProcessingNode):
node_id = "CenterCropImages"
search_aliases=["crop", "cut", "trim"]
display_name="Crop Image (Center)"
category="image/transform"
description = "Center crop an image to the specified dimensions."
display_name = "Center Crop Images"
description = "Center crop all images to the specified dimensions."
extra_inputs = [
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
@ -734,11 +706,10 @@ class CenterCropImagesNode(ImageProcessingNode):
class RandomCropImagesNode(ImageProcessingNode):
node_id = "RandomCropImages"
search_aliases=["crop", "cut", "trim"]
display_name = "Crop Image (Random)"
category="image/transform"
description = "Randomly crop an image to the specified dimensions."
display_name = "Random Crop Images"
description = (
"Randomly crop all images to the specified dimensions (for data augmentation)."
)
extra_inputs = [
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
@ -763,9 +734,7 @@ class RandomCropImagesNode(ImageProcessingNode):
class NormalizeImagesNode(ImageProcessingNode):
node_id = "NormalizeImages"
search_aliases=["normalize", "normalize colors"]
display_name = "Normalize Image Colors"
category = "image/color"
display_name = "Normalize Images"
description = "Normalize images using mean and standard deviation."
extra_inputs = [
io.Float.Input(
@ -793,10 +762,8 @@ class NormalizeImagesNode(ImageProcessingNode):
class AdjustBrightnessNode(ImageProcessingNode):
node_id = "AdjustBrightness"
search_aliases=["brightness"]
display_name = "Adjust Brightness"
category="image/adjustments"
description = "Adjust the brightness of an image."
description = "Adjust brightness of all images."
extra_inputs = [
io.Float.Input(
"factor",
@ -814,10 +781,8 @@ class AdjustBrightnessNode(ImageProcessingNode):
class AdjustContrastNode(ImageProcessingNode):
node_id = "AdjustContrast"
search_aliases=["contrast"]
display_name = "Adjust Contrast"
category="image/adjustments"
description = "Adjust the contrast of an image."
description = "Adjust contrast of all images."
extra_inputs = [
io.Float.Input(
"factor",
@ -835,10 +800,8 @@ class AdjustContrastNode(ImageProcessingNode):
class ShuffleDatasetNode(ImageProcessingNode):
node_id = "ShuffleDataset"
search_aliases=["shuffle", "randomize", "mix"]
display_name = "Shuffle Images List"
category = "image/batch"
description = "Randomly shuffle the order of images in a list."
display_name = "Shuffle Image Dataset"
description = "Randomly shuffle the order of images in the dataset."
is_group_process = True # Requires full list to shuffle
extra_inputs = [
io.Int.Input(
@ -860,15 +823,13 @@ class ShuffleImageTextDatasetNode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ShuffleImageTextDataset",
search_aliases=["shuffle", "randomize", "mix"],
display_name = "Shuffle Pairs of Image-Text",
category = "image/batch",
description = "Randomly shuffle the order of pairs of image-text in a list.",
display_name="Shuffle Image-Text Dataset",
category="dataset/image",
is_experimental=True,
is_input_list=True,
inputs=[
io.Image.Input("images", tooltip="List of images to shuffle."),
io.String.Input("texts", tooltip="List of texts to shuffle.", force_input=True),
io.String.Input("texts", tooltip="List of texts to shuffle."),
io.Int.Input(
"seed",
default=0,
@ -904,11 +865,8 @@ class ShuffleImageTextDatasetNode(io.ComfyNode):
class TextToLowercaseNode(TextProcessingNode):
node_id = "TextToLowercase"
search_aliases=["lowercase"]
display_name = "Convert Text to Lowercase (DEPRECATED)"
category = "text"
description = "Convert text to lowercase."
is_deprecated = True # This node is superseded by the Convert Text Case node
display_name = "Text to Lowercase"
description = "Convert all texts to lowercase."
@classmethod
def _process(cls, text):
@ -917,11 +875,8 @@ class TextToLowercaseNode(TextProcessingNode):
class TextToUppercaseNode(TextProcessingNode):
node_id = "TextToUppercase"
search_aliases=["uppercase"]
display_name = "Convert Text to Uppercase (DEPRECATED)"
category = "text"
description = "Convert text to uppercase."
is_deprecated = True # This node is superseded by the Convert Text Case node
display_name = "Text to Uppercase"
description = "Convert all texts to uppercase."
@classmethod
def _process(cls, text):
@ -930,10 +885,8 @@ class TextToUppercaseNode(TextProcessingNode):
class TruncateTextNode(TextProcessingNode):
node_id = "TruncateText"
search_aliases=["truncate", "cut", "shorten"]
display_name = "Truncate Text"
category = "text"
description = "Truncate text to a maximum length."
description = "Truncate all texts to a maximum length."
extra_inputs = [
io.Int.Input(
"max_length", default=77, min=1, max=10000, tooltip="Maximum text length."
@ -947,10 +900,8 @@ class TruncateTextNode(TextProcessingNode):
class AddTextPrefixNode(TextProcessingNode):
node_id = "AddTextPrefix"
display_name = "Add Text Prefix (DEPRECATED)"
category = "text"
display_name = "Add Text Prefix"
description = "Add a prefix to all texts."
is_deprecated = True # This node is superseded by the Concatenate Text node
extra_inputs = [
io.String.Input("prefix", default="", tooltip="Prefix to add."),
]
@ -962,10 +913,8 @@ class AddTextPrefixNode(TextProcessingNode):
class AddTextSuffixNode(TextProcessingNode):
node_id = "AddTextSuffix"
display_name = "Add Text Suffix (DEPRECATED)"
category = "text"
display_name = "Add Text Suffix"
description = "Add a suffix to all texts."
is_deprecated = True # This node is superseded by the Concatenate Text node
extra_inputs = [
io.String.Input("suffix", default="", tooltip="Suffix to add."),
]
@ -977,10 +926,8 @@ class AddTextSuffixNode(TextProcessingNode):
class ReplaceTextNode(TextProcessingNode):
node_id = "ReplaceText"
display_name = "Replace Text (DEPRECATED)"
category = "text"
display_name = "Replace Text"
description = "Replace text in all texts."
is_deprecated = True # This node is superseded by the other Replace Text node
extra_inputs = [
io.String.Input("find", default="", tooltip="Text to find."),
io.String.Input("replace", default="", tooltip="Text to replace with."),
@ -993,10 +940,8 @@ class ReplaceTextNode(TextProcessingNode):
class StripWhitespaceNode(TextProcessingNode):
node_id = "StripWhitespace"
display_name = "Strip Whitespace (DEPRECATED)"
category = "text"
display_name = "Strip Whitespace"
description = "Strip leading and trailing whitespace from all texts."
is_deprecated = True # This node is superseded by the Trim Text node
@classmethod
def _process(cls, text):
@ -1007,13 +952,11 @@ class StripWhitespaceNode(TextProcessingNode):
class ImageDeduplicationNode(ImageProcessingNode):
"""Remove duplicate or very similar images from a list using perceptual hashing."""
"""Remove duplicate or very similar images from the dataset using perceptual hashing."""
node_id = "ImageDeduplication"
search_aliases=["deduplicate", "remove duplicates", "similarity filter"]
display_name = "Deduplicate Images"
category = "image/batch"
description = "Remove duplicate or very similar images from a list."
display_name = "Image Deduplication"
description = "Remove duplicate or very similar images from the dataset."
is_group_process = True # Requires full list to compare images
extra_inputs = [
io.Float.Input(
@ -1083,9 +1026,7 @@ class ImageGridNode(ImageProcessingNode):
"""Combine multiple images into a single grid/collage."""
node_id = "ImageGrid"
search_aliases=["grid", "collage", "combine"]
display_name = "Make Image Grid"
category="image/batch"
display_name = "Image Grid"
description = "Arrange multiple images into a grid layout."
is_group_process = True # Requires full list to create grid
is_output_list = False # Outputs single grid image
@ -1161,12 +1102,9 @@ class MergeImageListsNode(ImageProcessingNode):
"""Merge multiple image lists into a single list."""
node_id = "MergeImageLists"
search_aliases=["list", "merge list", "make list"]
display_name = "Merge Image Lists (DEPRECATED)"
category = "image/batch"
display_name = "Merge Image Lists"
description = "Concatenate multiple image lists into one."
is_group_process = True # Receives images as list
is_deprecated = True # This node is superseded by the Create List node
@classmethod
def _group_process(cls, images):
@ -1181,11 +1119,9 @@ class MergeTextListsNode(TextProcessingNode):
"""Merge multiple text lists into a single list."""
node_id = "MergeTextLists"
display_name = "Merge Text Lists (DEPRECATED)"
category = "text"
display_name = "Merge Text Lists"
description = "Concatenate multiple text lists into one."
is_group_process = True # Receives texts as list
is_deprecated = True # This node is superseded by the Create List node
@classmethod
def _group_process(cls, texts):
@ -1206,10 +1142,8 @@ class ResolutionBucket(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ResolutionBucket",
search_aliases=["bucket by resolution", "group by resolution", "batch by resolution"],
display_name="Resolution Bucket",
category="training",
description="Group latents and conditionings into buckets",
category="dataset",
is_experimental=True,
is_input_list=True,
inputs=[
@ -1302,8 +1236,7 @@ class MakeTrainingDataset(io.ComfyNode):
node_id="MakeTrainingDataset",
search_aliases=["encode dataset"],
display_name="Make Training Dataset",
category="training",
description="Encode images with VAE and texts with CLIP to create a training dataset of latents and conditionings.",
category="dataset",
is_experimental=True,
is_input_list=True, # images and texts as lists
inputs=[
@ -1318,7 +1251,6 @@ class MakeTrainingDataset(io.ComfyNode):
"texts",
optional=True,
tooltip="List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).",
force_input=True
),
],
outputs=[
@ -1388,10 +1320,9 @@ class SaveTrainingDataset(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SaveTrainingDataset",
search_aliases=["export dataset", "save dataset"],
search_aliases=["export training data"],
display_name="Save Training Dataset",
category="training",
description="Save encoded training dataset (latents + conditioning) to disk for efficient loading during training.",
category="dataset",
is_experimental=True,
is_output_node=True,
is_input_list=True, # Receive lists
@ -1493,8 +1424,7 @@ class LoadTrainingDataset(io.ComfyNode):
node_id="LoadTrainingDataset",
search_aliases=["import dataset", "training data"],
display_name="Load Training Dataset",
category="training",
description="Load encoded training dataset (latents + conditioning) from disk for use in training.",
category="dataset",
is_experimental=True,
inputs=[
io.String.Input(

View File

@ -419,17 +419,15 @@ class VoxelToMeshBasic(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="VoxelToMeshBasic",
display_name="Voxel to Mesh (Basic) (DEPRECATED)",
display_name="Voxel to Mesh (Basic)",
category="3d",
description="Converts a voxel grid to a mesh.",
is_deprecated=True, # This node is superseded by the Voxel To Mesh node
inputs=[
IO.Voxel.Input("voxel"),
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
],
outputs=[
IO.Mesh.Output(),
],
]
)
@classmethod
@ -455,10 +453,9 @@ class VoxelToMesh(IO.ComfyNode):
node_id="VoxelToMesh",
display_name="Voxel to Mesh",
category="3d",
description="Converts a voxel grid to a mesh.",
inputs=[
IO.Voxel.Input("voxel"),
IO.Combo.Input("algorithm", options=["surface net", "basic"]),
IO.Combo.Input("algorithm", options=["surface net", "basic"], advanced=True),
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
],
outputs=[

View File

@ -55,10 +55,9 @@ class ImageCropV2(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ImageCropV2",
search_aliases=["crop", "cut", "trim"],
search_aliases=["trim"],
display_name="Crop Image",
category="image/transform",
description = "Crop an image to the specified dimensions.",
essentials_category="Image Tools",
has_intermediate_output=True,
inputs=[

View File

@ -8,82 +8,6 @@ from comfy_api.latest import _io
MISSING = object()
class NotNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ComfyNotNode",
display_name="Not",
category="utils/logic",
description="Logical NOT operation. Returns true if the value is falsy. Uses Python's rules for truthiness.",
search_aliases=["invert", "toggle", "negate", "flip boolean"],
inputs=[
io.AnyType.Input("value"),
],
outputs=[
io.Boolean.Output(),
],
)
@classmethod
def execute(cls, value) -> io.NodeOutput:
return io.NodeOutput(not value)
class AndNode(io.ComfyNode):
@classmethod
def define_schema(cls):
template = io.Autogrow.TemplatePrefix(
input=io.AnyType.Input("value"),
prefix="value",
min=1,
)
return io.Schema(
node_id="ComfyAndNode",
display_name="And",
category="utils/logic",
description="Logical AND operation. Returns true if all of the values are truthy. Uses Python's rules for truthiness.",
search_aliases=["all", "every"],
inputs=[
io.Autogrow.Input("values", template=template),
],
outputs=[
io.Boolean.Output(),
],
)
@classmethod
def execute(cls, values: io.Autogrow.Type) -> io.NodeOutput:
return io.NodeOutput(all(values.values()))
class OrNode(io.ComfyNode):
@classmethod
def define_schema(cls):
template = io.Autogrow.TemplatePrefix(
input=io.AnyType.Input("value"),
prefix="value",
min=1,
)
return io.Schema(
node_id="ComfyOrNode",
display_name="Or",
category="utils/logic",
description="Logical OR operation. Returns true if any of the values are truthy. Uses Python's rules for truthiness.",
search_aliases=["any", "some"],
inputs=[
io.Autogrow.Input("values", template=template),
],
outputs=[
io.Boolean.Output(),
],
)
@classmethod
def execute(cls, values: io.Autogrow.Type) -> io.NodeOutput:
return io.NodeOutput(any(values.values()))
class SwitchNode(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -91,7 +15,7 @@ class SwitchNode(io.ComfyNode):
return io.Schema(
node_id="ComfySwitchNode",
display_name="Switch",
category="utils/logic",
category="logic",
is_experimental=True,
inputs=[
io.Boolean.Input("switch"),
@ -122,7 +46,7 @@ class SoftSwitchNode(io.ComfyNode):
return io.Schema(
node_id="ComfySoftSwitchNode",
display_name="Soft Switch",
category="utils/logic",
category="logic",
is_experimental=True,
inputs=[
io.Boolean.Input("switch"),
@ -212,7 +136,7 @@ class DCTestNode(io.ComfyNode):
return io.Schema(
node_id="DCTestNode",
display_name="DCTest",
category="utils/logic",
category="logic",
is_output_node=True,
inputs=[io.DynamicCombo.Input("combo", options=[
io.DynamicCombo.Option("option1", [io.String.Input("string")]),
@ -250,7 +174,7 @@ class AutogrowNamesTestNode(io.ComfyNode):
return io.Schema(
node_id="AutogrowNamesTestNode",
display_name="AutogrowNamesTest",
category="utils/logic",
category="logic",
inputs=[
_io.Autogrow.Input("autogrow", template=template)
],
@ -270,7 +194,7 @@ class AutogrowPrefixTestNode(io.ComfyNode):
return io.Schema(
node_id="AutogrowPrefixTestNode",
display_name="AutogrowPrefixTest",
category="utils/logic",
category="logic",
inputs=[
_io.Autogrow.Input("autogrow", template=template)
],
@ -289,7 +213,7 @@ class ComboOutputTestNode(io.ComfyNode):
return io.Schema(
node_id="ComboOptionTestNode",
display_name="ComboOptionTest",
category="utils/logic",
category="logic",
inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]),
io.Combo.Input("combo2", options=["option4", "option5", "option6"])],
outputs=[io.Combo.Output(), io.Combo.Output()],
@ -306,7 +230,7 @@ class ConvertStringToComboNode(io.ComfyNode):
node_id="ConvertStringToComboNode",
search_aliases=["string to dropdown", "text to combo"],
display_name="Convert String to Combo",
category="utils/logic",
category="logic",
inputs=[io.String.Input("string")],
outputs=[io.Combo.Output()],
)
@ -322,7 +246,7 @@ class InvertBooleanNode(io.ComfyNode):
node_id="InvertBooleanNode",
search_aliases=["not", "toggle", "negate", "flip boolean"],
display_name="Invert Boolean",
category="utils/logic",
category="logic",
inputs=[io.Boolean.Input("boolean")],
outputs=[io.Boolean.Output()],
)
@ -337,9 +261,6 @@ class LogicExtension(ComfyExtension):
return [
SwitchNode,
CustomComboNode,
NotNode,
AndNode,
OrNode,
# SoftSwitchNode,
# ConvertStringToComboNode,
# DCTestNode,

View File

@ -11,8 +11,8 @@ class LTXVAudioVAELoader(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="LTXVAudioVAELoader",
display_name="Load LTXV Audio VAE",
category="loaders",
display_name="LTXV Audio VAE Loader",
category="audio",
inputs=[
io.Combo.Input(
"ckpt_name",
@ -40,7 +40,7 @@ class LTXVAudioVAEEncode(VAEEncodeAudio):
return io.Schema(
node_id="LTXVAudioVAEEncode",
display_name="LTXV Audio VAE Encode",
category="latent/audio",
category="audio",
inputs=[
io.Audio.Input("audio", tooltip="The audio to be encoded."),
io.Vae.Input(
@ -63,7 +63,7 @@ class LTXVAudioVAEDecode(io.ComfyNode):
return io.Schema(
node_id="LTXVAudioVAEDecode",
display_name="LTXV Audio VAE Decode",
category="latent/audio",
category="audio",
inputs=[
io.Latent.Input("samples", tooltip="The latent to be decoded."),
io.Vae.Input(

View File

@ -70,7 +70,7 @@ class MathExpressionNode(io.ComfyNode):
return io.Schema(
node_id="ComfyMathExpression",
display_name="Math Expression",
category="utils",
category="logic",
search_aliases=[
"expression", "formula", "calculate", "calculator",
"eval", "math",

View File

@ -28,7 +28,7 @@ from comfy_extras.mediapipe.face_landmarker import FaceLandmarker
from comfy_extras.mediapipe.face_geometry import transformation_matrix_from_detection
FaceDetectionType = io.Custom("FACE_DETECTION_MODEL")
FaceLandmarkerType = io.Custom("FACE_LANDMARKER")
FaceLandmarksType = io.Custom("FACE_LANDMARKS")
_CANONICAL_KEYS = ("canonical_vertices", "procrustes_indices", "procrustes_weights")
@ -204,19 +204,18 @@ class LoadMediaPipeFaceLandmarker(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LoadMediaPipeFaceLandmarker",
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"],
display_name="Load Face Detection Model (MediaPipe)",
display_name="Load MediaPipe Face Landmarker",
category="loaders",
inputs=[
io.Combo.Input("model_name", options=folder_paths.get_filename_list("detection"),
tooltip="Face detection model from models/detection/."),
io.Combo.Input("model_name", options=folder_paths.get_filename_list("mediapipe"),
tooltip="Face Landmarker safetensors from models/mediapipe/."),
],
outputs=[FaceDetectionType.Output()],
outputs=[FaceLandmarkerType.Output()],
)
@classmethod
def execute(cls, model_name) -> io.NodeOutput:
sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("detection", model_name), safe_load=True)
sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("mediapipe", model_name), safe_load=True)
wrapper = FaceLandmarkerModel(sd)
return io.NodeOutput(wrapper)
@ -235,12 +234,10 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="MediaPipeFaceLandmarker",
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"],
display_name="Detect Face Landmarks (MediaPipe)",
display_name="MediaPipe Face Landmarker",
category="image/detection",
description="Detects facial landmarks using MediaPipe model.",
inputs=[
FaceDetectionType.Input("face_detection_model"),
FaceLandmarkerType.Input("face_landmarker"),
io.Image.Input("image"),
io.Combo.Input("detector_variant", options=["short", "full", "both"], default="short",
tooltip="Face detector range. 'short' is tuned for close-up faces "
@ -264,9 +261,9 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
)
@classmethod
def execute(cls, face_detection_model, image, detector_variant, num_faces, min_confidence,
def execute(cls, face_landmarker, image, detector_variant, num_faces, min_confidence,
missing_frame_fallback) -> io.NodeOutput:
canonical = face_detection_model.canonical_data
canonical = face_landmarker.canonical_data
img_np = _image_to_uint8(image)
B, H, W = img_np.shape[:3]
chunk = 16
@ -279,7 +276,7 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
with tqdm(total=B, desc=f"MediaPipe Face Landmarker ({variant})") as tq:
for i in range(0, B, chunk):
end = min(i + chunk, B)
res.extend(face_detection_model.detect_batch(
res.extend(face_landmarker.detect_batch(
[img_np[bi] for bi in range(i, end)],
num_faces=int(num_faces),
score_thresh=float(min_confidence),
@ -309,7 +306,7 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
per_bb.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1, "label": "face", "score": float(f["score"])})
bboxes.append(per_bb)
return io.NodeOutput({"frames": frames, "image_size": (H, W),
"connection_sets": face_detection_model.connection_sets}, bboxes)
"connection_sets": face_landmarker.connection_sets}, bboxes)
# Topology keys unioned by the 'all' connections preset (contour parts + irises + nose).
@ -335,10 +332,8 @@ class MediaPipeFaceMeshVisualize(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="MediaPipeFaceMeshVisualize",
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection", "visualize"],
display_name="Visualize Face Landmarks (MediaPipe)",
display_name="MediaPipe Face Mesh Visualize",
category="image/detection",
description="Draws face landmarks mesh on the input image.",
inputs=[
FaceLandmarksType.Input("face_landmarks"),
io.Image.Input("image", optional=True, tooltip="If not connected, a black canvas will be used."),
@ -448,10 +443,8 @@ class MediaPipeFaceMask(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="MediaPipeFaceMask",
search_aliases=["face", "facial", "mediapipe", "face mask", "blazeface", "face detection", "visualize"],
display_name="Draw Face Mask (MediaPipe)",
display_name="MediaPipe Face Mask",
category="image/detection",
description="Draws a mask from face landmarks.",
inputs=[
FaceLandmarksType.Input("face_landmarks"),
io.DynamicCombo.Input(

View File

@ -103,10 +103,8 @@ class MoGePanoramaInference(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="MoGePanoramaInference",
search_aliases=["moge", "panorama", "depth", "geometry", "depth estimation", "geometry estimation"],
display_name="Run MoGe Panorama Inference",
display_name="MoGe Panorama Inference",
category="image/geometry_estimation",
description="Run MoGe on an equirectangular panorama by splitting it into 12 perspective views, running inference on each, and merging the results into a single depth map.",
inputs=[
MoGeModelType.Input("moge_model"),
io.Image.Input("image", tooltip="Equirectangular panorama (any aspect)."),
@ -224,9 +222,7 @@ class MoGeInference(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="MoGeInference",
search_aliases=["moge", "depth", "geometry", "depth estimation", "geometry estimation"],
display_name="Run MoGe Inference",
description="Run MoGe on a single image to estimate depth and geometry.",
display_name="MoGe Inference",
category="image/geometry_estimation",
inputs=[
MoGeModelType.Input("moge_model"),
@ -281,9 +277,7 @@ class MoGeRender(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="MoGeRender",
search_aliases=["moge", "render", "geometry", "depth", "normal"],
display_name="Render MoGe Geometry",
description="Render a depth map or normal map from geometry data",
display_name="MoGe Render",
category="image/geometry_estimation",
inputs=[
MoGeGeometry.Input("moge_geometry"),
@ -348,9 +342,7 @@ class MoGePointMapToMesh(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="MoGePointMapToMesh",
search_aliases=["moge", "mesh", "geometry", "point map"],
display_name="Convert MoGe Point Map to Mesh",
description="Convert a MoGe point map into a 3D mesh.",
display_name="MoGe Point Map to Mesh",
category="image/geometry_estimation",
inputs=[
MoGeGeometry.Input("moge_geometry"),

View File

@ -14,7 +14,7 @@ class CreateList(io.ComfyNode):
return io.Schema(
node_id="CreateList",
display_name="Create List",
category="utils",
category="logic",
is_input_list=True,
search_aliases=["Image Iterator", "Text Iterator", "Iterator"],
inputs=[io.Autogrow.Input("inputs", template=template_autogrow)],

View File

@ -60,7 +60,7 @@ folder_names_and_paths["geometry_estimation"] = ([os.path.join(models_dir, "geom
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
folder_names_and_paths["detection"] = ([os.path.join(models_dir, "detection")], supported_pt_extensions)
folder_names_and_paths["mediapipe"] = ([os.path.join(models_dir, "mediapipe")], supported_pt_extensions)
output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp")

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.43.18
comfyui-workflow-templates==0.9.82
comfyui-workflow-templates==0.9.79
comfyui-embedded-docs==0.5.0
torch
torchsde
@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0
filelock
av>=14.2.0
comfy-kitchen>=0.2.8
comfy-aimdo==0.4.5
comfy-aimdo==0.4.3
requests
simpleeval>=1.0.0
blake3

View File

@ -21,6 +21,7 @@ from app.assets.database.queries import (
get_reference_ids_by_ids,
ensure_tags_exist,
add_tags_to_reference,
set_reference_tags,
)
from app.assets.helpers import get_utc_now
@ -159,6 +160,153 @@ class TestListReferencesPage:
assert refs[0].name == "large"
class TestTagRetrievalOrder:
"""End-to-end check: tags written through the public write paths come
back from the public read paths in insertion order rather than the
composite-PK alphabetical order SQLite would otherwise impose.
Each test deliberately picks tag names that would sort differently
under alphabetical vs insertion order, so an alphabetical regression
fails loudly.
"""
def _make_ref(self, session: Session) -> AssetReference:
asset = _make_asset(session, "h1")
return _make_reference(session, asset, name="x.bin")
def test_set_reference_tags_preserves_input_order_in_list(self, session: Session):
ref = self._make_ref(session)
# "checkpoints" < "models" alphabetically; if added_at stagger
# works, list_references_page returns insertion order.
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
session.commit()
_, tag_map, _ = list_references_page(session)
assert tag_map[ref.id] == ["models", "checkpoints"]
def test_set_reference_tags_preserves_input_order_in_fetch(self, session: Session):
ref = self._make_ref(session)
# Subpath tag sorts before "models" alphabetically.
set_reference_tags(
session,
reference_id=ref.id,
tags=["models", "diffusers/kolors/text_encoder"],
)
session.commit()
result = fetch_reference_asset_and_tags(session, ref.id)
assert result is not None
_, _, tags = result
# Bucket-prefix expansion appends the standalone `diffusers` token
# at path-tier (microsecond stagger) so FE set-membership filters
# match nested category paths.
assert tags == ["models", "diffusers/kolors/text_encoder", "diffusers"]
def test_add_tags_to_reference_lands_after_path_tags(self, session: Session):
ref = self._make_ref(session)
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
session.commit()
# "aaa-..." sorts before both path tags alphabetically. If added_at
# stagger is missing, alphabetic tiebreak would hoist it to tags[0].
add_tags_to_reference(
session, reference_id=ref.id, tags=["aaa-user-tag"], origin="manual"
)
session.commit()
_, tag_map, _ = list_references_page(session)
assert tag_map[ref.id] == ["models", "checkpoints", "aaa-user-tag"]
def test_multi_tag_batch_lands_after_path_tags(self, session: Session):
ref = self._make_ref(session)
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
session.commit()
# Three user tags inserted in non-alphabetical input order. Per-tag
# microsecond stagger should preserve at least the "user batch is
# after path tags" property; within the user batch insertion order
# is also preserved.
add_tags_to_reference(
session,
reference_id=ref.id,
tags=["zzz-z", "favorite", "experiment-q4"],
origin="manual",
)
session.commit()
_, tag_map, _ = list_references_page(session)
tags = tag_map[ref.id]
assert tags[0:2] == ["models", "checkpoints"]
assert set(tags[2:]) == {"zzz-z", "favorite", "experiment-q4"}
def test_user_batch_lands_after_path_batch_under_clock_collision(
self, session: Session, monkeypatch: pytest.MonkeyPatch
):
"""Windows-specific race: when two back-to-back commits share the
same datetime.now() microsecond, the path-tier and user-tier
added_at values used to collide and alphabetic tiebreak would
hoist user tags ahead of path tags. The fix reads
max(existing_added_at) for the reference and seeds the next batch
past it, deterministically restoring insertion order.
This test simulates the collision by pinning get_utc_now() so the
platform-dependent race becomes a platform-independent failure.
"""
ref = self._make_ref(session)
from datetime import datetime
from app.assets.database import queries as queries_pkg
from app.assets.database.queries import tags as tags_module
frozen = datetime(2026, 1, 1, 0, 0, 0)
monkeypatch.setattr(tags_module, "get_utc_now", lambda: frozen)
monkeypatch.setattr(queries_pkg, "get_utc_now", lambda: frozen, raising=False)
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
session.commit()
# Same frozen timestamp — without the max(existing) seed, the
# user batch would share added_at with the path batch and
# `aaa-user-tag` would sort to position 0 via the alphabetic
# tiebreaker.
add_tags_to_reference(
session, reference_id=ref.id, tags=["aaa-user-tag"], origin="manual"
)
session.commit()
_, tag_map, _ = list_references_page(session)
assert tag_map[ref.id] == ["models", "checkpoints", "aaa-user-tag"]
def test_remove_then_add_does_not_disrupt_path_tag_positions(
self, session: Session
):
ref = self._make_ref(session)
set_reference_tags(
session,
reference_id=ref.id,
tags=["models", "loras/my/custom/path"],
)
session.commit()
add_tags_to_reference(session, reference_id=ref.id, tags=["temp-tag"])
session.commit()
from app.assets.database.queries import remove_tags_from_reference
remove_tags_from_reference(session, reference_id=ref.id, tags=["temp-tag"])
session.commit()
add_tags_to_reference(session, reference_id=ref.id, tags=["second-tag"])
session.commit()
_, tag_map, _ = list_references_page(session)
# `loras` is expanded from the nested category path; user-added
# tags trail behind it via the microsecond stagger.
assert tag_map[ref.id] == [
"models",
"loras/my/custom/path",
"loras",
"second-tag",
]
class TestFetchReferenceAssetAndTags:
def test_returns_none_for_nonexistent(self, session: Session):
result = fetch_reference_asset_and_tags(session, "nonexistent")

View File

@ -160,6 +160,120 @@ class TestAddTagsToReference:
add_tags_to_reference(session, reference_id="nonexistent", tags=["x"])
class TestBucketPrefixExpansion:
"""The standalone bucket token must appear in the asset's tag set for
nested category paths so FE filters like
`include_tags=models,checkpoints` continue to match.
"""
def test_set_reference_tags_inserts_bucket_for_nested_path(
self, session: Session
):
asset = _make_asset(session, "hash-nested")
ref = _make_reference(session, asset)
result = set_reference_tags(
session,
reference_id=ref.id,
tags=["models", "checkpoints/flux"],
)
session.commit()
assert set(result.total) == {"models", "checkpoints/flux", "checkpoints"}
stored = get_reference_tags(session, reference_id=ref.id)
# tag[1] keeps the slash-joined positional contract; the standalone
# bucket lands after it via path-tier microsecond stagger so user
# tags remain at the tail.
assert stored[:3] == ["models", "checkpoints/flux", "checkpoints"]
def test_set_reference_tags_idempotent_on_replay(self, session: Session):
asset = _make_asset(session, "hash-replay")
ref = _make_reference(session, asset)
set_reference_tags(
session,
reference_id=ref.id,
tags=["models", "checkpoints/flux"],
)
# Replay with the same caller-supplied set; expansion is already
# baked in, so nothing should be added or removed.
result = set_reference_tags(
session,
reference_id=ref.id,
tags=["models", "checkpoints/flux"],
)
session.commit()
assert result.added == []
assert result.removed == []
assert set(result.total) == {"models", "checkpoints/flux", "checkpoints"}
def test_add_tags_to_reference_expands_bucket(self, session: Session):
asset = _make_asset(session, "hash-add")
ref = _make_reference(session, asset)
result = add_tags_to_reference(
session,
reference_id=ref.id,
tags=["loras/style/v2"],
)
session.commit()
assert set(result.added) == {"loras/style/v2", "loras"}
stored = get_reference_tags(session, reference_id=ref.id)
assert "loras" in stored
assert "loras/style/v2" in stored
def test_add_tags_does_not_duplicate_existing_bucket(self, session: Session):
asset = _make_asset(session, "hash-dedupe")
ref = _make_reference(session, asset)
add_tags_to_reference(
session, reference_id=ref.id, tags=["models", "checkpoints"]
)
result = add_tags_to_reference(
session, reference_id=ref.id, tags=["checkpoints/flux"]
)
session.commit()
# `checkpoints` was already there from the first add; only the
# slash-joined token is genuinely new.
assert result.added == ["checkpoints/flux"]
assert "checkpoints" in result.already_present
def test_flat_category_is_unaffected(self, session: Session):
asset = _make_asset(session, "hash-flat")
ref = _make_reference(session, asset)
result = set_reference_tags(
session,
reference_id=ref.id,
tags=["models", "checkpoints"],
)
session.commit()
assert set(result.total) == {"models", "checkpoints"}
assert get_reference_tags(session, reference_id=ref.id) == [
"models",
"checkpoints",
]
def test_unknown_prefix_passes_through(self, session: Session):
asset = _make_asset(session, "hash-user")
ref = _make_reference(session, asset)
# `my-org` isn't a registered bucket — the slash-joined user tag
# should not trigger bucket expansion.
result = set_reference_tags(
session,
reference_id=ref.id,
tags=["my-org/team-a"],
)
session.commit()
assert result.total == ["my-org/team-a"]
class TestRemoveTagsFromReference:
def test_removes_tags(self, session: Session):
asset = _make_asset(session, "hash1")

View File

@ -4,7 +4,7 @@ from pathlib import Path
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag
from app.assets.services.bulk_ingest import SeedAssetSpec, batch_insert_seed_assets
@ -102,6 +102,82 @@ class TestBatchInsertSeedAssets:
assert asset.mime_type == expected_mime, f"Expected {expected_mime} for {filename}, got {asset.mime_type}"
class TestBucketPrefixExpansionOnIngest:
"""Path-scanning ingest must persist the standalone bucket token for
nested category paths so the FE set-membership filter
(`include_tags=models,checkpoints`) matches assets organized into
subfolders (`models/checkpoints/flux/foo.safetensors`).
"""
def test_nested_path_inserts_standalone_bucket(
self, session: Session, temp_dir: Path
):
file_path = temp_dir / "flux.safetensors"
file_path.write_bytes(b"content")
specs: list[SeedAssetSpec] = [
{
"abs_path": str(file_path),
"size_bytes": 7,
"mtime_ns": 1234567890000000000,
"info_name": "flux",
# Shape emitted by get_name_and_tags_from_asset_path for a
# nested model path.
"tags": ["models", "checkpoints/flux"],
"fname": "flux.safetensors",
"metadata": None,
"hash": None,
"mime_type": "application/safetensors",
}
]
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
assert result.inserted_refs == 1
ref = session.query(AssetReference).filter_by(name="flux").one()
stored = [
row.tag_name
for row in session.query(AssetReferenceTag)
.filter_by(asset_reference_id=ref.id)
.order_by(AssetReferenceTag.added_at.asc())
.all()
]
assert stored == ["models", "checkpoints/flux", "checkpoints"]
def test_flat_path_remains_two_tags(
self, session: Session, temp_dir: Path
):
file_path = temp_dir / "vanilla.safetensors"
file_path.write_bytes(b"content")
specs: list[SeedAssetSpec] = [
{
"abs_path": str(file_path),
"size_bytes": 7,
"mtime_ns": 1234567890000000000,
"info_name": "vanilla",
"tags": ["models", "checkpoints"],
"fname": "vanilla.safetensors",
"metadata": None,
"hash": None,
"mime_type": "application/safetensors",
}
]
batch_insert_seed_assets(session, specs=specs, owner_id="")
ref = session.query(AssetReference).filter_by(name="vanilla").one()
stored = {
row.tag_name
for row in session.query(AssetReferenceTag)
.filter_by(asset_reference_id=ref.id)
.all()
}
# Dedupe means flat layouts don't pick up a redundant `checkpoints`
# row — tag[1] already serves both positional and set-membership.
assert stored == {"models", "checkpoints"}
class TestMetadataExtraction:
def test_extracts_mime_type_for_model_files(self, temp_dir: Path):
"""Verify metadata extraction returns correct mime_type for model files."""

View File

@ -6,7 +6,11 @@ from unittest.mock import patch
import pytest
from app.assets.services.path_utils import get_asset_category_and_relative_path
from app.assets.services.path_utils import (
get_asset_category_and_relative_path,
get_name_and_tags_from_asset_path,
resolve_destination_from_tags,
)
@pytest.fixture
@ -38,6 +42,50 @@ def fake_dirs():
}
@pytest.fixture
def fake_dirs_multi_bucket():
"""Variant fixture with multiple model buckets (checkpoints + diffusers + loras)."""
with tempfile.TemporaryDirectory() as root:
root_path = Path(root)
input_dir = root_path / "input"
output_dir = root_path / "output"
temp_dir = root_path / "temp"
checkpoints_dir = root_path / "models" / "checkpoints"
diffusers_dir = root_path / "models" / "diffusers"
loras_dir = root_path / "models" / "loras"
for d in (
input_dir,
output_dir,
temp_dir,
checkpoints_dir,
diffusers_dir,
loras_dir,
):
d.mkdir(parents=True)
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
mock_fp.get_input_directory.return_value = str(input_dir)
mock_fp.get_output_directory.return_value = str(output_dir)
mock_fp.get_temp_directory.return_value = str(temp_dir)
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[
("checkpoints", [str(checkpoints_dir)]),
("diffusers", [str(diffusers_dir)]),
("loras", [str(loras_dir)]),
],
):
yield {
"input": input_dir,
"output": output_dir,
"temp": temp_dir,
"checkpoints": checkpoints_dir,
"diffusers": diffusers_dir,
"loras": loras_dir,
}
class TestGetAssetCategoryAndRelativePath:
def test_input_file(self, fake_dirs):
f = fake_dirs["input"] / "photo.png"
@ -79,3 +127,161 @@ class TestGetAssetCategoryAndRelativePath:
def test_unknown_path_raises(self, fake_dirs):
with pytest.raises(ValueError, match="not within"):
get_asset_category_and_relative_path("/some/random/path.png")
class TestGetNameAndTagsFromAssetPath:
"""tags collapse the parent subpath into a single slash-joined tag.
Consumers should be able to read ``tags[1]`` as a stable category
identifier regardless of how deep the file lives in the bucket.
"""
def test_flat_input(self, fake_dirs_multi_bucket):
f = fake_dirs_multi_bucket["input"] / "photo.png"
f.touch()
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "photo.png"
assert tags == ["input"]
def test_flat_output(self, fake_dirs_multi_bucket):
f = fake_dirs_multi_bucket["output"] / "result_00001.png"
f.touch()
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "result_00001.png"
assert tags == ["output"]
def test_flat_models_checkpoint(self, fake_dirs_multi_bucket):
f = fake_dirs_multi_bucket["checkpoints"] / "flux.safetensors"
f.touch()
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "flux.safetensors"
assert tags == ["models", "checkpoints"]
def test_diffusers_nested_subpath_slash_joined(self, fake_dirs_multi_bucket):
"""Diffusers components live in nested directories — the full subpath
must collapse into one tag so consumers can look up the model category
via tags[1] regardless of nesting depth.
The subpath is lowercased to match the canonicalization
:func:`ensure_tags_exist` applies on the write side; without that,
the asset_reference_tags.tag_name FK to tags.name would fail for
any path containing uppercase letters.
"""
nested = (
fake_dirs_multi_bucket["diffusers"]
/ "Kolors"
/ "text_encoder"
)
nested.mkdir(parents=True)
f = nested / "model.safetensors"
f.touch()
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "model.safetensors"
assert tags == ["models", "diffusers/kolors/text_encoder"]
def test_deep_lora_user_subpath_slash_joined(self, fake_dirs_multi_bucket):
"""User-created subdirectories under a model bucket also collapse to a
single tag rather than one tag per directory."""
nested = (
fake_dirs_multi_bucket["loras"]
/ "my"
/ "custom"
/ "path"
)
nested.mkdir(parents=True)
f = nested / "v0001.safetensors"
f.touch()
name, tags = get_name_and_tags_from_asset_path(str(f))
assert name == "v0001.safetensors"
assert tags == ["models", "loras/my/custom/path"]
class TestResolveDestinationFromTags:
"""resolve_destination_from_tags must accept both the legacy
one-tag-per-directory shape and the new slash-joined shape so that an
upload using the tags it just read back from /api/assets round-trips
to the right on-disk destination.
"""
@pytest.fixture
def resolve_dirs(self):
with tempfile.TemporaryDirectory() as root:
root_path = Path(root)
input_dir = root_path / "input"
output_dir = root_path / "output"
checkpoints_dir = root_path / "models" / "checkpoints"
diffusers_dir = root_path / "models" / "diffusers"
loras_dir = root_path / "models" / "loras"
for d in (input_dir, output_dir, checkpoints_dir, diffusers_dir, loras_dir):
d.mkdir(parents=True)
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
mock_fp.get_input_directory.return_value = str(input_dir)
mock_fp.get_output_directory.return_value = str(output_dir)
mock_fp.folder_names_and_paths = {
"checkpoints": ([str(checkpoints_dir)], None),
"diffusers": ([str(diffusers_dir)], None),
"loras": ([str(loras_dir)], None),
}
yield {
"input": input_dir,
"output": output_dir,
"checkpoints": checkpoints_dir,
"diffusers": diffusers_dir,
"loras": loras_dir,
}
def test_models_flat_category(self, resolve_dirs):
base, subdirs = resolve_destination_from_tags(["models", "checkpoints"])
assert base == str(resolve_dirs["checkpoints"])
assert subdirs == []
def test_models_slash_joined_new_shape(self, resolve_dirs):
# The shape get_name_and_tags_from_asset_path now emits.
base, subdirs = resolve_destination_from_tags(
["models", "diffusers/kolors/text_encoder"]
)
assert base == str(resolve_dirs["diffusers"])
assert subdirs == ["kolors", "text_encoder"]
def test_models_legacy_one_tag_per_dir(self, resolve_dirs):
# The legacy shape must still resolve identically.
base, subdirs = resolve_destination_from_tags(
["models", "diffusers", "kolors", "text_encoder"]
)
assert base == str(resolve_dirs["diffusers"])
assert subdirs == ["kolors", "text_encoder"]
def test_models_loras_slash_joined(self, resolve_dirs):
base, subdirs = resolve_destination_from_tags(
["models", "loras/my/custom/path"]
)
assert base == str(resolve_dirs["loras"])
assert subdirs == ["my", "custom", "path"]
def test_input_no_subdir(self, resolve_dirs):
base, subdirs = resolve_destination_from_tags(["input"])
assert base == str(resolve_dirs["input"])
assert subdirs == []
def test_input_slash_joined_subdir(self, resolve_dirs):
base, subdirs = resolve_destination_from_tags(["input", "portraits/2026"])
assert base == str(resolve_dirs["input"])
assert subdirs == ["portraits", "2026"]
def test_output_slash_joined_subdir(self, resolve_dirs):
base, subdirs = resolve_destination_from_tags(["output", "runs/abc"])
assert base == str(resolve_dirs["output"])
assert subdirs == ["runs", "abc"]
def test_unknown_category_rejected(self, resolve_dirs):
with pytest.raises(ValueError, match="unknown model category"):
resolve_destination_from_tags(["models", "not_a_real_category"])
def test_unknown_category_via_slash_joined(self, resolve_dirs):
# First segment of a slash-joined tag must still match a registered category.
with pytest.raises(ValueError, match="unknown model category 'bogus'"):
resolve_destination_from_tags(["models", "bogus/sub/path"])
def test_traversal_in_subdir_rejected(self, resolve_dirs):
with pytest.raises(ValueError, match="invalid path component"):
resolve_destination_from_tags(["models", "checkpoints/..", "evil"])

View File

@ -32,7 +32,7 @@ def test_seed_asset_removed_when_file_is_deleted(
# Verify it is visible via API and carries no hash (seed)
r1 = http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
params={"include_tags": "unit-tests/syncseed", "name_contains": name},
timeout=120,
)
body1 = r1.json()
@ -52,7 +52,7 @@ def test_seed_asset_removed_when_file_is_deleted(
# It should disappear (AssetInfo and seed Asset gone)
r2 = http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
params={"include_tags": "unit-tests/syncseed", "name_contains": name},
timeout=120,
)
body2 = r2.json()
@ -332,7 +332,7 @@ def test_fastpass_removes_stale_state_row_no_missing(
rl = http.get(
api_base + "/api/assets",
params={"include_tags": f"unit-tests,{scope}"},
params={"include_tags": f"unit-tests/{scope}"},
timeout=120,
)
bl = rl.json()

View File

@ -280,9 +280,15 @@ def test_metadata_filename_is_set_for_seed_asset_without_hash(
trigger_sync_seed_assets(http, api_base)
# Scanner emits tags as ``[root, "<dir1>/<dir2>/..."]`` — the second tag
# is the slash-joined parent subpath. For ``<root>/unit-tests/<scope>/a/b/<name>``
# the second tag is ``"unit-tests/<scope>/a/b"``.
r1 = http.get(
api_base + "/api/assets",
params={"include_tags": f"unit-tests,{scope}", "name_contains": name},
params={
"include_tags": f"unit-tests/{scope}/a/b",
"name_contains": name,
},
timeout=120,
)
body = r1.json()

View File

@ -0,0 +1,69 @@
"""Unit tests for app.assets.helpers."""
from app.assets.helpers import expand_bucket_prefixes
class TestExpandBucketPrefixes:
def test_flat_category_unchanged(self):
# `checkpoints` is already a standalone token, no expansion needed.
assert expand_bucket_prefixes(["models", "checkpoints"]) == [
"models",
"checkpoints",
]
def test_nested_category_inserts_bucket(self):
# Path-derived shape for `models/checkpoints/flux/foo.safetensors` —
# the standalone bucket has to be present so the FE set-membership
# filter (`include_tags=models,checkpoints`) matches the asset.
assert expand_bucket_prefixes(["models", "checkpoints/flux"]) == [
"models",
"checkpoints/flux",
"checkpoints",
]
def test_deeply_nested_only_first_segment_expands(self):
# Only the FIRST slash segment ever gets emitted as a standalone —
# intermediate path segments don't have routing significance.
assert expand_bucket_prefixes(
["models", "diffusers/kolors/text_encoder"]
) == ["models", "diffusers/kolors/text_encoder", "diffusers"]
def test_unknown_prefix_does_not_expand(self):
# Free-form user labels with slashes whose first segment is not a
# registered bucket pass through opaquely.
assert expand_bucket_prefixes(["models", "my-org/team-a"]) == [
"models",
"my-org/team-a",
]
def test_idempotent(self):
# Re-applying the helper is a no-op once the bucket is in the set.
expanded = expand_bucket_prefixes(["models", "checkpoints/flux"])
assert expand_bucket_prefixes(expanded) == expanded
def test_does_not_duplicate_existing_bucket(self):
# If the caller already supplied the standalone bucket, don't add a
# second copy.
assert expand_bucket_prefixes(
["models", "checkpoints/flux", "checkpoints"]
) == ["models", "checkpoints/flux", "checkpoints"]
def test_preserves_caller_order(self):
# User tags after path tags must stay after; the inserted bucket
# token slots in immediately after its slash-joined parent so the
# microsecond stagger lands it at path-tier before user-tier.
assert expand_bucket_prefixes(
["models", "loras/style", "favorite", "v2"]
) == ["models", "loras/style", "loras", "favorite", "v2"]
def test_empty_input(self):
assert expand_bucket_prefixes([]) == []
def test_input_root_with_subpath_no_expansion(self):
# `portraits` isn't a registered model category, so the input
# subpath stays opaque (FE filter doesn't have a checkpoint-loader
# analogue for input subfolders).
assert expand_bucket_prefixes(["input", "portraits/2026"]) == [
"input",
"portraits/2026",
]

View File

@ -29,7 +29,10 @@ def create_seed_file(comfy_tmp_base_dir: Path):
def find_asset(http: requests.Session, api_base: str):
"""Query API for assets matching scope and optional name."""
def _find(scope: str, name: str | None = None) -> list[dict]:
params = {"include_tags": f"unit-tests,{scope}"}
# Scanner now emits tags as ``[root, "<dir1>/<dir2>/..."]`` rather than
# one tag per directory. For files at ``<root>/unit-tests/<scope>/...``
# the second tag is exactly ``"unit-tests/<scope>"``.
params = {"include_tags": f"unit-tests/{scope}"}
if name:
params["name_contains"] = name
r = http.get(f"{api_base}/api/assets", params=params, timeout=120)
@ -138,4 +141,7 @@ def test_special_chars_in_path_escaped_correctly(
trigger_sync_seed_assets(http, api_base)
trigger_sync_seed_assets(http, api_base)
assert find_asset(scope.split("/")[0], fp.name), "Asset with special chars should survive"
# Scanner emits the full parent subpath as a single slash-joined tag, so
# the lookup tag is ``unit-tests/<scope>`` even when <scope> itself
# contains a slash (parent + special-char dirname).
assert find_asset(scope, fp.name), "Asset with special chars should survive"

View File

@ -0,0 +1,135 @@
"""HTTP-layer smoke test: user-added tags via POST /api/assets/{id}/tags
land after path tags when read back via GET /api/assets.
Exercises the full route handler -> service -> query path that the unit
tests at tests-unit/assets_test/queries/test_asset_info.py only cover at
the service layer.
"""
import json
import pytest
import requests
@pytest.fixture
def smoke_asset(http: requests.Session, api_base: str):
"""Upload a single asset into models/checkpoints/unit-tests/smoke
and delete it on teardown."""
name = "smoke_user_tag.safetensors"
tags = ["models", "checkpoints", "unit-tests", "smoke"]
files = {"file": (name, b"S" * 4096, "application/octet-stream")}
form_data = {
"tags": json.dumps(tags),
"name": name,
"user_metadata": json.dumps({}),
}
r = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
assert r.status_code == 201, r.text
body = r.json()
yield body
http.delete(
f"{api_base}/api/assets/{body['id']}?delete_content=true", timeout=30
)
def _fetch_asset_tags(http, api_base, ref_id):
r = http.get(f"{api_base}/api/assets/{ref_id}", timeout=30)
assert r.status_code == 200, r.text
return r.json()["tags"]
def test_user_tag_lands_after_path_tags_via_http(
http: requests.Session, api_base: str, smoke_asset: dict
):
ref_id = smoke_asset["id"]
initial_tags = _fetch_asset_tags(http, api_base, ref_id)
# Path tags should already be at the front in upload order.
assert initial_tags[:2] == ["models", "checkpoints"]
# Add a user tag that would jump to position 0 under alphabetical sort.
r = http.post(
f"{api_base}/api/assets/{ref_id}/tags",
json={"tags": ["aaa-user-tag"]},
timeout=30,
)
assert r.status_code in (200, 201), r.text
tags_after = _fetch_asset_tags(http, api_base, ref_id)
# Path tags must still be at the front; user tag goes to the end.
assert tags_after[0] == "models"
assert tags_after[1] == "checkpoints"
assert "aaa-user-tag" in tags_after
assert tags_after[-1] == "aaa-user-tag"
def test_user_tag_batch_lands_after_path_tags_via_http(
http: requests.Session, api_base: str, smoke_asset: dict
):
ref_id = smoke_asset["id"]
# Add three user tags in a single request, in non-alphabetical input
# order. They should all land after the path tags (microsecond stagger
# in set_reference_tags / add_tags_to_reference is what makes this
# work — without it, "aaa" would jump to position 0).
r = http.post(
f"{api_base}/api/assets/{ref_id}/tags",
json={"tags": ["zzz-z", "favorite", "aaa-experiment"]},
timeout=30,
)
assert r.status_code in (200, 201), r.text
tags_after = _fetch_asset_tags(http, api_base, ref_id)
assert tags_after[0] == "models"
assert tags_after[1] == "checkpoints"
user_tail = tags_after[len({"models", "checkpoints", "unit-tests", "smoke"}):]
assert set(user_tail) >= {"zzz-z", "favorite", "aaa-experiment"}
# Critically: alphabetical sort would put 'aaa-experiment' at position 0.
assert tags_after.index("aaa-experiment") > tags_after.index("models")
assert tags_after.index("aaa-experiment") > tags_after.index("checkpoints")
@pytest.fixture
def nested_checkpoint_asset(http: requests.Session, api_base: str):
"""Upload a checkpoint at the slash-joined path shape cloud emits
(`models/checkpoints/flux/...`), then delete it on teardown.
"""
name = "nested_checkpoint.safetensors"
tags = ["models", "checkpoints/flux"]
files = {"file": (name, b"S" * 4096, "application/octet-stream")}
form_data = {
"tags": json.dumps(tags),
"name": name,
"user_metadata": json.dumps({}),
}
r = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
assert r.status_code == 201, r.text
body = r.json()
yield body
http.delete(
f"{api_base}/api/assets/{body['id']}?delete_content=true", timeout=30
)
def test_nested_checkpoint_satisfies_fe_set_filter(
http: requests.Session, api_base: str, nested_checkpoint_asset: dict
):
"""The case Simon flagged: a nested-path checkpoint must still match
`include_tags=models,checkpoints` — the FE combo-widget filter.
"""
ref_id = nested_checkpoint_asset["id"]
stored = _fetch_asset_tags(http, api_base, ref_id)
# tag[1] keeps cloud's slash-joined positional contract; tag[2] holds
# the standalone bucket the FE filter looks for.
assert stored[:3] == ["models", "checkpoints/flux", "checkpoints"]
# The actual FE query — exact set-membership across both tokens.
r = http.get(
f"{api_base}/api/assets",
params=[("include_tags", "models"), ("include_tags", "checkpoints")],
timeout=30,
)
assert r.status_code == 200, r.text
returned_ids = {a["id"] for a in r.json()["assets"]}
assert ref_id in returned_ids