mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-23 09:38:08 +08:00
Compare commits
14 Commits
matt/opena
...
matt/be-94
| Author | SHA1 | Date | |
|---|---|---|---|
| 2015bbb54a | |||
| 37764dc40c | |||
| cc62f2a9e8 | |||
| e4508af5e4 | |||
| 1c30b374de | |||
| ab47c85f95 | |||
| 9a7f580b37 | |||
| 33a57cc9e8 | |||
| 9b0042d78c | |||
| d0258ae53d | |||
| df1f6a7fcc | |||
| 39abd769b1 | |||
| 5a70aeebe8 | |||
| 28f60ccea5 |
519
.github/workflows/backport_release.yaml
vendored
519
.github/workflows/backport_release.yaml
vendored
@ -1,519 +0,0 @@
|
||||
name: Backport Release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
commit:
|
||||
description: 'Full 40-char SHA of the tip commit of the backport source branch (the PR head commit that passed tests). The branch is resolved from this SHA and must be unique.'
|
||||
required: true
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: read
|
||||
checks: read
|
||||
|
||||
jobs:
|
||||
backport-release:
|
||||
name: Create backport release
|
||||
runs-on: ubuntu-latest
|
||||
environment: backport release
|
||||
|
||||
steps:
|
||||
- name: Generate GitHub App token
|
||||
id: app-token
|
||||
uses: actions/create-github-app-token@bcd2ba49218906704ab6c1aa796996da409d3eb1
|
||||
with:
|
||||
app-id: ${{ secrets.FEN_RELEASE_APP_ID }}
|
||||
private-key: ${{ secrets.FEN_RELEASE_PRIVATE_KEY }}
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
|
||||
with:
|
||||
token: ${{ steps.app-token.outputs.token }}
|
||||
fetch-depth: 0
|
||||
fetch-tags: true
|
||||
|
||||
- name: Configure git
|
||||
run: |
|
||||
git config user.name "fen-release[bot]"
|
||||
git config user.email "fen-release[bot]@users.noreply.github.com"
|
||||
|
||||
- name: Resolve source branch from commit SHA
|
||||
id: resolve
|
||||
env:
|
||||
SOURCE_COMMIT: ${{ inputs.commit }}
|
||||
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Require a full 40-char lowercase-hex SHA. Short SHAs are ambiguous
|
||||
# and we will be comparing this value against API responses (PR head
|
||||
# SHA, ref tips) that always return the full form.
|
||||
if [[ ! "${SOURCE_COMMIT}" =~ ^[0-9a-f]{40}$ ]]; then
|
||||
echo "::error::Input commit '${SOURCE_COMMIT}' is not a full 40-char lowercase hex SHA."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Fetch all remote branches so we can search for which one(s) point
|
||||
# at this SHA. `actions/checkout` with fetch-depth: 0 fetches full
|
||||
# history of the checked-out ref but does not necessarily populate
|
||||
# every refs/remotes/origin/*, so do it explicitly.
|
||||
git fetch --prune origin '+refs/heads/*:refs/remotes/origin/*'
|
||||
|
||||
# Verify the commit actually exists in this repo's object DB.
|
||||
if ! git cat-file -e "${SOURCE_COMMIT}^{commit}" 2>/dev/null; then
|
||||
echo "::error::Commit ${SOURCE_COMMIT} was not found in the repository."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Find every remote branch whose tip == SOURCE_COMMIT. Exactly one
|
||||
# branch must point at it. If zero, the commit isn't anyone's tip
|
||||
# (likely stale, force-pushed past, or never the PR head). If more
|
||||
# than one, the (branch -> SHA) mapping is ambiguous and we refuse
|
||||
# to guess — the operator must give us a unique branch to release.
|
||||
mapfile -t matching_branches < <(
|
||||
git for-each-ref \
|
||||
--format='%(refname:strip=3)' \
|
||||
--points-at="${SOURCE_COMMIT}" \
|
||||
refs/remotes/origin/ \
|
||||
| grep -vx 'HEAD' || true
|
||||
)
|
||||
|
||||
if [[ "${#matching_branches[@]}" -eq 0 ]]; then
|
||||
echo "::error::No branch on origin has ${SOURCE_COMMIT} as its tip."
|
||||
echo "::error::Either the branch was updated after you copied this SHA, or this commit was never the head of a branch."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ "${#matching_branches[@]}" -gt 1 ]]; then
|
||||
echo "::error::More than one branch on origin has ${SOURCE_COMMIT} as its tip; cannot pick one:"
|
||||
for b in "${matching_branches[@]}"; do
|
||||
echo "::error:: - ${b}"
|
||||
done
|
||||
echo "::error::Refusing to proceed with an ambiguous source branch."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
source_branch="${matching_branches[0]}"
|
||||
|
||||
if [[ "${source_branch}" == "${DEFAULT_BRANCH}" ]]; then
|
||||
echo "::error::Source branch must not be the default branch ('${DEFAULT_BRANCH}')."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Resolved commit ${SOURCE_COMMIT} to branch '${source_branch}'."
|
||||
echo "source_branch=${source_branch}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Determine latest stable release
|
||||
id: latest
|
||||
env:
|
||||
GH_TOKEN: ${{ steps.app-token.outputs.token }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# List all tags matching vMAJOR.MINOR.PATCH and pick the highest by numeric
|
||||
# comparison of each component. We DO NOT use `sort -V` because it treats
|
||||
# v0.19.99 as higher than v0.20.1.
|
||||
latest_tag="$(
|
||||
git tag --list 'v[0-9]*.[0-9]*.[0-9]*' \
|
||||
| grep -E '^v[0-9]+\.[0-9]+\.[0-9]+$' \
|
||||
| awk -F'[v.]' '{ printf "%010d %010d %010d %s\n", $2, $3, $4, $0 }' \
|
||||
| sort -k1,1n -k2,2n -k3,3n \
|
||||
| tail -n1 \
|
||||
| awk '{print $4}'
|
||||
)"
|
||||
|
||||
if [[ -z "${latest_tag}" ]]; then
|
||||
echo "::error::No stable release tags (vMAJOR.MINOR.PATCH) were found."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Parse components
|
||||
ver="${latest_tag#v}"
|
||||
major="${ver%%.*}"
|
||||
rest="${ver#*.}"
|
||||
minor="${rest%%.*}"
|
||||
patch="${rest#*.}"
|
||||
|
||||
new_patch=$((patch + 1))
|
||||
new_version="v${major}.${minor}.${new_patch}"
|
||||
release_branch="release/v${major}.${minor}"
|
||||
|
||||
latest_sha="$(git rev-list -n 1 "refs/tags/${latest_tag}")"
|
||||
|
||||
echo "latest_tag=${latest_tag}" >> "$GITHUB_OUTPUT"
|
||||
echo "latest_sha=${latest_sha}" >> "$GITHUB_OUTPUT"
|
||||
echo "major=${major}" >> "$GITHUB_OUTPUT"
|
||||
echo "minor=${minor}" >> "$GITHUB_OUTPUT"
|
||||
echo "patch=${patch}" >> "$GITHUB_OUTPUT"
|
||||
echo "new_version=${new_version}" >> "$GITHUB_OUTPUT"
|
||||
echo "new_version_no_v=${major}.${minor}.${new_patch}" >> "$GITHUB_OUTPUT"
|
||||
echo "release_branch=${release_branch}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
echo "Latest stable release: ${latest_tag} (${latest_sha})"
|
||||
echo "New version will be: ${new_version}"
|
||||
echo "Release branch: ${release_branch}"
|
||||
|
||||
- name: Validate source branch is cut directly from the latest stable release
|
||||
env:
|
||||
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
|
||||
SOURCE_COMMIT: ${{ inputs.commit }}
|
||||
LATEST_TAG_SHA: ${{ steps.latest.outputs.latest_sha }}
|
||||
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Use the user-provided SHA directly rather than re-resolving the branch
|
||||
# tip — the resolve step already proved the branch tip equals SOURCE_COMMIT,
|
||||
# and pinning to the SHA here makes the rest of the job TOCTOU-safe against
|
||||
# someone pushing to the branch mid-run.
|
||||
source_sha="${SOURCE_COMMIT}"
|
||||
|
||||
# Walking first-parent from the source tip must reach LATEST_TAG_SHA.
|
||||
# We capture rev-list into a variable and grep against a here-string
|
||||
# rather than piping `rev-list | grep -q`: under `set -o pipefail`,
|
||||
# `grep -q` would exit on first match and SIGPIPE the still-streaming
|
||||
# `rev-list`, propagating exit 141 as a spurious "not found".
|
||||
first_parent_chain="$(git rev-list --first-parent "${source_sha}")"
|
||||
if ! grep -Fxq "${LATEST_TAG_SHA}" <<< "${first_parent_chain}"; then
|
||||
echo "::error::Source branch '${SOURCE_BRANCH}' is not cut from '${LATEST_TAG}'."
|
||||
echo "::error::Its first-parent history does not include ${LATEST_TAG_SHA}."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Additionally, every commit added on top of the tag (the set we are
|
||||
# about to publish) must itself be a descendant of the tag along
|
||||
# first-parent — i.e. no sibling commits from master sneak in via a
|
||||
# non-first-parent path. Enforce by requiring that the symmetric
|
||||
# difference is empty in one direction: commits in source that are
|
||||
# NOT first-parent-reachable from source starting at the tag.
|
||||
# We do this by intersecting:
|
||||
# A = commits reachable from source but not from tag (full DAG)
|
||||
# B = commits on the first-parent chain from source down to tag
|
||||
# and requiring A == B.
|
||||
all_added="$(git rev-list "${LATEST_TAG_SHA}..${source_sha}" | sort)"
|
||||
first_parent_added="$(
|
||||
git rev-list --first-parent "${LATEST_TAG_SHA}..${source_sha}" | sort
|
||||
)"
|
||||
|
||||
if [[ "${all_added}" != "${first_parent_added}" ]]; then
|
||||
echo "::error::Source branch '${SOURCE_BRANCH}' contains commits not on its first-parent chain from '${LATEST_TAG}'."
|
||||
echo "::error::This usually means the branch was cut from master (not from the tag) or contains a merge from master."
|
||||
echo "Commits reachable but not on first-parent chain:"
|
||||
comm -23 <(printf '%s\n' "${all_added}") <(printf '%s\n' "${first_parent_added}") \
|
||||
| while read -r sha; do
|
||||
echo " $(git log -1 --format='%h %s' "${sha}")"
|
||||
done
|
||||
exit 1
|
||||
fi
|
||||
|
||||
added_count="$(printf '%s\n' "${all_added}" | grep -c . || true)"
|
||||
echo "Source branch is cut directly from ${LATEST_TAG} with ${added_count} commit(s) on top."
|
||||
|
||||
- name: Validate PR exists, is open, named correctly, has latest commit, and checks pass
|
||||
env:
|
||||
GH_TOKEN: ${{ steps.app-token.outputs.token }}
|
||||
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
|
||||
SOURCE_COMMIT: ${{ inputs.commit }}
|
||||
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
|
||||
REPO: ${{ github.repository }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
expected_title="ComfyUI backport release ${NEW_VERSION}"
|
||||
|
||||
# Find open PRs from this branch into master. The --state open filter
|
||||
# is load-bearing: a closed/merged PR with passing checks must not be
|
||||
# accepted as authorization for a new release.
|
||||
pr_json="$(
|
||||
gh pr list \
|
||||
--repo "${REPO}" \
|
||||
--state open \
|
||||
--head "${SOURCE_BRANCH}" \
|
||||
--base master \
|
||||
--json number,title,headRefOid,state \
|
||||
--limit 10
|
||||
)"
|
||||
|
||||
pr_count="$(echo "${pr_json}" | jq 'length')"
|
||||
if [[ "${pr_count}" -eq 0 ]]; then
|
||||
echo "::error::No open PR found from '${SOURCE_BRANCH}' into 'master'. The PR must exist and be open."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Pick the PR matching the expected title
|
||||
pr_number="$(echo "${pr_json}" | jq -r --arg t "${expected_title}" '
|
||||
map(select(.title == $t)) | .[0].number // empty
|
||||
')"
|
||||
pr_head_sha="$(echo "${pr_json}" | jq -r --arg t "${expected_title}" '
|
||||
map(select(.title == $t)) | .[0].headRefOid // empty
|
||||
')"
|
||||
|
||||
if [[ -z "${pr_number}" ]]; then
|
||||
echo "::error::No open PR from '${SOURCE_BRANCH}' into 'master' is titled '${expected_title}'."
|
||||
echo "Found PRs:"
|
||||
echo "${pr_json}" | jq -r '.[] | " #\(.number): \(.title)"'
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# The PR's current head commit must equal the SHA the operator gave us.
|
||||
# This is what closes the door on releasing stale code: if anyone has
|
||||
# pushed to the branch since the operator validated tests passed, the
|
||||
# PR head will have advanced past SOURCE_COMMIT and we abort. (The
|
||||
# resolve step already proved the branch tip == SOURCE_COMMIT; this
|
||||
# ties that same SHA to the PR that authorizes the release.)
|
||||
if [[ "${pr_head_sha}" != "${SOURCE_COMMIT}" ]]; then
|
||||
echo "::error::PR #${pr_number} head commit is ${pr_head_sha}, but the operator-provided commit is ${SOURCE_COMMIT}."
|
||||
echo "::error::The PR has new commits since this release was authorized. Re-run with the new head SHA after verifying its checks."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Found open PR #${pr_number} titled '${expected_title}' at head ${pr_head_sha} (matches operator-provided commit)."
|
||||
|
||||
# Verify all check runs on the head commit have completed successfully.
|
||||
# A check is considered passing if conclusion is success, neutral, or skipped.
|
||||
checks_json="$(
|
||||
gh api \
|
||||
--paginate \
|
||||
"repos/${REPO}/commits/${pr_head_sha}/check-runs" \
|
||||
--jq '.check_runs[] | {name: .name, status: .status, conclusion: .conclusion}'
|
||||
)"
|
||||
|
||||
if [[ -z "${checks_json}" ]]; then
|
||||
echo "::error::No check runs found on PR head commit ${pr_head_sha}."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Check runs on ${pr_head_sha}:"
|
||||
echo "${checks_json}" | jq -s '.'
|
||||
|
||||
failing="$(echo "${checks_json}" | jq -s '
|
||||
map(select(
|
||||
.status != "completed"
|
||||
or (.conclusion as $c
|
||||
| ["success","neutral","skipped"]
|
||||
| index($c) | not)
|
||||
))
|
||||
')"
|
||||
|
||||
failing_count="$(echo "${failing}" | jq 'length')"
|
||||
if [[ "${failing_count}" -gt 0 ]]; then
|
||||
echo "::error::One or more checks have not passed on PR head commit ${pr_head_sha}:"
|
||||
echo "${failing}" | jq -r '.[] | " - \(.name): status=\(.status) conclusion=\(.conclusion)"'
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "All checks have passed on ${pr_head_sha}."
|
||||
|
||||
- name: Prepare release branch
|
||||
id: prepare
|
||||
env:
|
||||
GH_TOKEN: ${{ steps.app-token.outputs.token }}
|
||||
REPO: ${{ github.repository }}
|
||||
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
|
||||
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
|
||||
LATEST_TAG_SHA: ${{ steps.latest.outputs.latest_sha }}
|
||||
PATCH: ${{ steps.latest.outputs.patch }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Try to fetch the release branch. If patch == 0, it shouldn't exist yet
|
||||
# and we'll create it from the latest stable tag. If patch > 0, it must
|
||||
# already exist and its tip must equal the latest stable tag commit (i.e.
|
||||
# the previous patch release).
|
||||
if git ls-remote --exit-code --heads origin "${RELEASE_BRANCH}" >/dev/null 2>&1; then
|
||||
echo "Release branch '${RELEASE_BRANCH}' already exists on origin."
|
||||
git fetch origin "refs/heads/${RELEASE_BRANCH}:refs/remotes/origin/${RELEASE_BRANCH}"
|
||||
git checkout -B "${RELEASE_BRANCH}" "refs/remotes/origin/${RELEASE_BRANCH}"
|
||||
|
||||
current_tip="$(git rev-parse HEAD)"
|
||||
if [[ "${current_tip}" != "${LATEST_TAG_SHA}" ]]; then
|
||||
echo "::error::Release branch '${RELEASE_BRANCH}' tip (${current_tip}) is not at the latest stable release '${LATEST_TAG}' (${LATEST_TAG_SHA})."
|
||||
echo "::error::Refusing to release on top of a divergent branch."
|
||||
exit 1
|
||||
fi
|
||||
echo "branch_existed=true" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
if [[ "${PATCH}" != "0" ]]; then
|
||||
echo "::error::Release branch '${RELEASE_BRANCH}' does not exist on origin, but the latest stable release '${LATEST_TAG}' has patch=${PATCH} (>0). This is inconsistent."
|
||||
exit 1
|
||||
fi
|
||||
echo "Release branch '${RELEASE_BRANCH}' does not exist. Creating from ${LATEST_TAG}."
|
||||
git checkout -B "${RELEASE_BRANCH}" "refs/tags/${LATEST_TAG}"
|
||||
echo "branch_existed=false" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Fast-forward merge source branch into release branch
|
||||
env:
|
||||
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
|
||||
SOURCE_COMMIT: ${{ inputs.commit }}
|
||||
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# --ff-only guarantees no merge commit is created. If a fast-forward is
|
||||
# not possible (i.e. the release branch has commits the source branch
|
||||
# doesn't), the merge will fail and we abort. Because we already validated
|
||||
# that the source branch is rooted on the latest stable tag, and the
|
||||
# release branch tip equals that same tag, this fast-forward should
|
||||
# always succeed for a well-formed backport branch.
|
||||
#
|
||||
# We merge the operator-provided SHA, not the branch ref, so a push to
|
||||
# the branch in the window between resolve and now cannot smuggle new
|
||||
# commits into the release.
|
||||
if ! git merge --ff-only "${SOURCE_COMMIT}"; then
|
||||
echo "::error::Cannot fast-forward '${RELEASE_BRANCH}' to ${SOURCE_COMMIT} (tip of '${SOURCE_BRANCH}'). A merge commit would be required. Aborting."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Fast-forwarded '${RELEASE_BRANCH}' to ${SOURCE_COMMIT} (tip of '${SOURCE_BRANCH}')."
|
||||
|
||||
- name: Bump version files
|
||||
env:
|
||||
NEW_VERSION_NO_V: ${{ steps.latest.outputs.new_version_no_v }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if [[ ! -f comfyui_version.py ]]; then
|
||||
echo "::error::comfyui_version.py not found in repo root."
|
||||
exit 1
|
||||
fi
|
||||
if [[ ! -f pyproject.toml ]]; then
|
||||
echo "::error::pyproject.toml not found in repo root."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Replace the version string in comfyui_version.py.
|
||||
# Expected format: __version__ = "X.Y.Z"
|
||||
python3 - "$NEW_VERSION_NO_V" <<'PY'
|
||||
import re, sys, pathlib
|
||||
new = sys.argv[1]
|
||||
|
||||
p = pathlib.Path("comfyui_version.py")
|
||||
src = p.read_text()
|
||||
new_src, n = re.subn(
|
||||
r'(__version__\s*=\s*[\'"])[^\'"]+([\'"])',
|
||||
lambda m: f'{m.group(1)}{new}{m.group(2)}',
|
||||
src,
|
||||
count=1,
|
||||
)
|
||||
if n != 1:
|
||||
sys.exit("Could not find __version__ assignment in comfyui_version.py")
|
||||
p.write_text(new_src)
|
||||
|
||||
p = pathlib.Path("pyproject.toml")
|
||||
src = p.read_text()
|
||||
# Replace the first `version = "..."` inside [project] or [tool.poetry].
|
||||
new_src, n = re.subn(
|
||||
r'(?m)^(version\s*=\s*")[^"]+(")',
|
||||
lambda m: f'{m.group(1)}{new}{m.group(2)}',
|
||||
src,
|
||||
count=1,
|
||||
)
|
||||
if n != 1:
|
||||
sys.exit("Could not find version assignment in pyproject.toml")
|
||||
p.write_text(new_src)
|
||||
PY
|
||||
|
||||
echo "Updated version to ${NEW_VERSION_NO_V} in comfyui_version.py and pyproject.toml."
|
||||
git --no-pager diff -- comfyui_version.py pyproject.toml
|
||||
|
||||
- name: Commit version bump and tag release
|
||||
env:
|
||||
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
git add comfyui_version.py pyproject.toml
|
||||
git commit -m "ComfyUI ${NEW_VERSION}"
|
||||
|
||||
if git rev-parse -q --verify "refs/tags/${NEW_VERSION}" >/dev/null; then
|
||||
echo "::error::Tag ${NEW_VERSION} already exists locally."
|
||||
exit 1
|
||||
fi
|
||||
git tag "${NEW_VERSION}"
|
||||
|
||||
- name: Verify tag does not already exist on origin
|
||||
env:
|
||||
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
if git ls-remote --exit-code --tags origin "refs/tags/${NEW_VERSION}" >/dev/null 2>&1; then
|
||||
echo "::error::Tag ${NEW_VERSION} already exists on origin. Aborting."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Push release branch and tag
|
||||
env:
|
||||
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
|
||||
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Push the branch first, then the tag. Atomic-ish: if the branch push
|
||||
# fails we never publish the tag.
|
||||
git push origin "refs/heads/${RELEASE_BRANCH}:refs/heads/${RELEASE_BRANCH}"
|
||||
git push origin "refs/tags/${NEW_VERSION}"
|
||||
|
||||
echo "Released ${NEW_VERSION} on ${RELEASE_BRANCH}."
|
||||
|
||||
- name: Delete remote source branch
|
||||
env:
|
||||
GH_TOKEN: ${{ steps.app-token.outputs.token }}
|
||||
REPO: ${{ github.repository }}
|
||||
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
|
||||
SOURCE_COMMIT: ${{ inputs.commit }}
|
||||
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
|
||||
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Belt-and-braces: the resolve step already refuses the default branch,
|
||||
# but never delete the default or the release branch under any
|
||||
# circumstances.
|
||||
if [[ "${SOURCE_BRANCH}" == "${DEFAULT_BRANCH}" || "${SOURCE_BRANCH}" == "${RELEASE_BRANCH}" ]]; then
|
||||
echo "::error::Refusing to delete '${SOURCE_BRANCH}' (matches default or release branch)."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Delete the source branch on origin, but only if its tip is still the
|
||||
# SHA we released from. If someone pushed new commits to it after we
|
||||
# resolved it, leave it alone — those commits would be silently lost.
|
||||
current_tip="$(git ls-remote origin "refs/heads/${SOURCE_BRANCH}" | awk '{print $1}')"
|
||||
if [[ -z "${current_tip}" ]]; then
|
||||
echo "Source branch '${SOURCE_BRANCH}' no longer exists on origin; nothing to delete."
|
||||
exit 0
|
||||
fi
|
||||
if [[ "${current_tip}" != "${SOURCE_COMMIT}" ]]; then
|
||||
echo "::warning::Source branch '${SOURCE_BRANCH}' tip (${current_tip}) no longer matches released commit (${SOURCE_COMMIT}). Leaving it in place."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
git push origin --delete "refs/heads/${SOURCE_BRANCH}"
|
||||
echo "Deleted remote branch '${SOURCE_BRANCH}'."
|
||||
|
||||
- name: Summary
|
||||
if: always()
|
||||
env:
|
||||
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
|
||||
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
|
||||
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
|
||||
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
|
||||
SOURCE_COMMIT: ${{ inputs.commit }}
|
||||
run: |
|
||||
# SOURCE_BRANCH is empty if the resolve step never produced an output
|
||||
# (e.g. the workflow failed in or before that step). Show a placeholder
|
||||
# in that case so the summary table still renders cleanly.
|
||||
source_branch_display="${SOURCE_BRANCH:-(unresolved)}"
|
||||
{
|
||||
echo "## Backport release"
|
||||
echo ""
|
||||
echo "| Field | Value |"
|
||||
echo "|---|---|"
|
||||
echo "| Source commit | \`${SOURCE_COMMIT}\` |"
|
||||
echo "| Source branch | \`${source_branch_display}\` |"
|
||||
echo "| Previous stable | \`${LATEST_TAG}\` |"
|
||||
echo "| New version | \`${NEW_VERSION}\` |"
|
||||
echo "| Release branch | \`${RELEASE_BRANCH}\` |"
|
||||
} >> "$GITHUB_STEP_SUMMARY"
|
||||
@ -20,7 +20,7 @@
|
||||
[website-url]: https://www.comfy.org/
|
||||
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
|
||||
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
|
||||
[discord-url]: https://discord.com/invite/comfyorg
|
||||
[discord-url]: https://www.comfy.org/discord
|
||||
[twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI
|
||||
[twitter-url]: https://x.com/ComfyUI
|
||||
|
||||
|
||||
@ -39,6 +39,7 @@ from app.assets.services import (
|
||||
update_asset_metadata,
|
||||
upload_from_temp_path,
|
||||
)
|
||||
from app.assets.services.cursor import InvalidCursorError
|
||||
from app.assets.services.tagging import list_tag_histogram
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
@ -172,7 +173,7 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
|
||||
user_metadata=result.ref.user_metadata or {},
|
||||
metadata=result.ref.system_metadata,
|
||||
job_id=result.ref.job_id,
|
||||
prompt_id=result.ref.job_id, # deprecated: mirrors job_id for cloud compat
|
||||
prompt_id=result.ref.job_id, # deprecated alias of job_id, kept for compatibility
|
||||
created_at=result.ref.created_at,
|
||||
updated_at=result.ref.updated_at,
|
||||
last_access_time=result.ref.last_access_time,
|
||||
@ -209,24 +210,37 @@ async def list_assets_route(request: web.Request) -> web.Response:
|
||||
order_candidate = (q.order or "desc").lower()
|
||||
order = order_candidate if order_candidate in {"asc", "desc"} else "desc"
|
||||
|
||||
result = list_assets_page(
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
offset=q.offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
try:
|
||||
result = list_assets_page(
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
offset=q.offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
after=q.after,
|
||||
)
|
||||
except InvalidCursorError as e:
|
||||
return _build_error_response(400, "INVALID_CURSOR", str(e))
|
||||
|
||||
summaries = [_build_asset_response(item) for item in result.items]
|
||||
|
||||
# has_more semantics differ by mode:
|
||||
# - cursor mode: a non-empty next_cursor means there are more results.
|
||||
# - offset mode: derived from total - (offset + page size).
|
||||
if q.after is not None:
|
||||
has_more = result.next_cursor is not None
|
||||
else:
|
||||
has_more = (q.offset + len(summaries)) < result.total
|
||||
|
||||
payload = schemas_out.AssetsList(
|
||||
assets=summaries,
|
||||
total=result.total,
|
||||
has_more=(q.offset + len(summaries)) < result.total,
|
||||
has_more=has_more,
|
||||
next_cursor=result.next_cursor,
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
|
||||
|
||||
|
||||
@ -59,6 +59,11 @@ class ListAssetsQuery(BaseModel):
|
||||
|
||||
limit: conint(ge=1, le=500) = 20
|
||||
offset: conint(ge=0) = 0
|
||||
# Opaque keyset cursor. When supplied, `offset` is ignored. Cursor pagination
|
||||
# is supported for sort values `created_at`, `updated_at`, `name`, `size`.
|
||||
# Supplying `after` together with `sort=last_access_time` returns
|
||||
# 400 INVALID_CURSOR; that sort only supports offset/limit.
|
||||
after: str | None = None
|
||||
|
||||
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = (
|
||||
"created_at"
|
||||
|
||||
@ -40,6 +40,8 @@ class AssetsList(BaseModel):
|
||||
assets: list[Asset]
|
||||
total: int
|
||||
has_more: bool
|
||||
# Opaque cursor for the next page. Omitted when there are no more results.
|
||||
next_cursor: str | None = None
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
|
||||
|
||||
@ -266,9 +266,18 @@ def list_references_page(
|
||||
metadata_filter: dict | None = None,
|
||||
sort: str | None = None,
|
||||
order: str | None = None,
|
||||
after_cursor_value: object | None = None,
|
||||
after_cursor_id: str | None = None,
|
||||
) -> tuple[list[AssetReference], dict[str, list[str]], int]:
|
||||
"""List references with pagination, filtering, and sorting.
|
||||
|
||||
When ``after_cursor_value``/``after_cursor_id`` are supplied the query uses
|
||||
keyset pagination — ``offset`` is ignored and a WHERE clause selects rows
|
||||
strictly after the given ``(sort_col, id)`` position in the active sort
|
||||
direction. The cursor value must already be typed for the column
|
||||
(datetime for time sorts, int for size, str for name); the caller decodes
|
||||
the opaque cursor string and resolves to the typed value.
|
||||
|
||||
Returns (references, tag_map, total_count).
|
||||
"""
|
||||
base = (
|
||||
@ -297,9 +306,31 @@ def list_references_page(
|
||||
"size": Asset.size_bytes,
|
||||
}
|
||||
sort_col = sort_map.get(sort, AssetReference.created_at)
|
||||
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
|
||||
descending = order == "desc"
|
||||
|
||||
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||
# Keyset WHERE: (sort_col, id) strictly less-than / greater-than the cursor.
|
||||
# Equivalent to: sort_col <op> v OR (sort_col = v AND id <op> cursor_id).
|
||||
if after_cursor_value is not None and after_cursor_id is not None:
|
||||
if descending:
|
||||
keyset = sa.or_(
|
||||
sort_col < after_cursor_value,
|
||||
sa.and_(sort_col == after_cursor_value, AssetReference.id < after_cursor_id),
|
||||
)
|
||||
else:
|
||||
keyset = sa.or_(
|
||||
sort_col > after_cursor_value,
|
||||
sa.and_(sort_col == after_cursor_value, AssetReference.id > after_cursor_id),
|
||||
)
|
||||
base = base.where(keyset)
|
||||
|
||||
# Secondary ORDER BY id (matching the primary direction) gives the keyset
|
||||
# comparison a deterministic tiebreaker on duplicate sort_col values.
|
||||
id_exp = AssetReference.id.desc() if descending else AssetReference.id.asc()
|
||||
sort_exp = sort_col.desc() if descending else sort_col.asc()
|
||||
|
||||
base = base.order_by(sort_exp, id_exp).limit(limit)
|
||||
if after_cursor_id is None:
|
||||
base = base.offset(offset)
|
||||
|
||||
count_stmt = (
|
||||
select(sa.func.count())
|
||||
|
||||
@ -1,8 +1,19 @@
|
||||
import contextlib
|
||||
import mimetypes
|
||||
import os
|
||||
from datetime import timezone
|
||||
from typing import Sequence
|
||||
|
||||
from app.assets.services.cursor import (
|
||||
CursorPayload,
|
||||
InvalidCursorError,
|
||||
decode_cursor,
|
||||
decode_cursor_int,
|
||||
decode_cursor_time,
|
||||
encode_cursor,
|
||||
encode_cursor_from_time,
|
||||
)
|
||||
|
||||
|
||||
from app.assets.database.models import Asset
|
||||
from app.assets.database.queries import (
|
||||
@ -242,6 +253,11 @@ def get_asset_by_hash(asset_hash: str) -> AssetData | None:
|
||||
return extract_asset_data(asset)
|
||||
|
||||
|
||||
# Sort fields that support cursor pagination. `last_access_time` is not
|
||||
# in this list — it falls back to offset/limit.
|
||||
_CURSOR_SORT_FIELDS = ("created_at", "updated_at", "name", "size")
|
||||
|
||||
|
||||
def list_assets_page(
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
@ -252,7 +268,39 @@ def list_assets_page(
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
after: str | None = None,
|
||||
) -> ListAssetsResult:
|
||||
"""List assets with optional cursor pagination.
|
||||
|
||||
When ``after`` is supplied it overrides ``offset``. The cursor's sort field
|
||||
must match ``sort`` and be in the cursor-supported allowlist; mismatches
|
||||
raise InvalidCursorError so the handler can map to 400 INVALID_CURSOR.
|
||||
"""
|
||||
cursor_value: object | None = None
|
||||
cursor_id: str | None = None
|
||||
# Mint next_cursor on every page where the sort is cursor-supported, not
|
||||
# only when the request itself arrived with a cursor. Otherwise a first
|
||||
# request (no `after`) returns next_cursor=None and the client can never
|
||||
# enter cursor mode.
|
||||
mint_cursor = sort in _CURSOR_SORT_FIELDS
|
||||
|
||||
if after is not None:
|
||||
if sort not in _CURSOR_SORT_FIELDS:
|
||||
raise InvalidCursorError(
|
||||
f"cursor pagination is not supported for sort={sort!r}"
|
||||
)
|
||||
payload = decode_cursor(after, _CURSOR_SORT_FIELDS, expected_order=order)
|
||||
if payload.sort_field != sort:
|
||||
raise InvalidCursorError(
|
||||
f"cursor sort field {payload.sort_field!r} does not match request sort {sort!r}"
|
||||
)
|
||||
cursor_value, cursor_id = _resolve_cursor_value(payload), payload.id
|
||||
|
||||
# Over-fetch by one row so we can distinguish "exactly `limit` rows total
|
||||
# remaining" from "more rows past this page" without a second query. Drop
|
||||
# the sentinel before returning.
|
||||
fetch_limit = limit + 1 if mint_cursor else limit
|
||||
|
||||
with create_session() as session:
|
||||
refs, tag_map, total = list_references_page(
|
||||
session,
|
||||
@ -261,12 +309,22 @@ def list_assets_page(
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
limit=fetch_limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
after_cursor_value=cursor_value,
|
||||
after_cursor_id=cursor_id,
|
||||
)
|
||||
|
||||
next_cursor: str | None = None
|
||||
if mint_cursor and len(refs) > limit:
|
||||
# There's at least one more row past this page — mint a cursor from
|
||||
# the last row of the page (i.e. index `limit - 1`, since we
|
||||
# over-fetched), and drop the sentinel.
|
||||
next_cursor = _encode_next_cursor(refs[limit - 1], sort, order)
|
||||
refs = refs[:limit]
|
||||
|
||||
items: list[AssetSummaryData] = []
|
||||
for ref in refs:
|
||||
items.append(
|
||||
@ -277,7 +335,39 @@ def list_assets_page(
|
||||
)
|
||||
)
|
||||
|
||||
return ListAssetsResult(items=items, total=total)
|
||||
return ListAssetsResult(items=items, total=total, next_cursor=next_cursor)
|
||||
|
||||
|
||||
def _resolve_cursor_value(payload: CursorPayload) -> object:
|
||||
"""Map a decoded cursor payload to a column-typed Python value."""
|
||||
if payload.sort_field in ("created_at", "updated_at"):
|
||||
# DB stores naive UTC; strip tzinfo so the comparison binds against a
|
||||
# `TIMESTAMP WITHOUT TIME ZONE` column without an offset shift.
|
||||
return decode_cursor_time(payload).replace(tzinfo=None)
|
||||
if payload.sort_field == "size":
|
||||
return decode_cursor_int(payload)
|
||||
return payload.value # name, str-typed
|
||||
|
||||
|
||||
def _encode_next_cursor(ref, sort: str, order: str) -> str | None:
|
||||
"""Mint a cursor pointing at *ref* for the given sort dimension.
|
||||
|
||||
Returns None when the boundary row carries a NULL sort value (e.g. an asset
|
||||
record whose size_bytes hasn't been backfilled). Continuing pagination
|
||||
across a NULL boundary is undefined under keyset ordering — better to
|
||||
truncate cleanly here than to mint a cursor that mis-positions.
|
||||
"""
|
||||
if sort == "name":
|
||||
return encode_cursor("name", ref.name, ref.id, order=order)
|
||||
if sort == "size":
|
||||
if ref.asset is None or ref.asset.size_bytes is None:
|
||||
return None
|
||||
return encode_cursor("size", str(ref.asset.size_bytes), ref.id, order=order)
|
||||
# created_at / updated_at — DB datetimes are naive UTC; attach tz before encoding.
|
||||
value = ref.created_at if sort == "created_at" else ref.updated_at
|
||||
if value is None:
|
||||
return None
|
||||
return encode_cursor_from_time(sort, value.replace(tzinfo=timezone.utc), ref.id, order=order)
|
||||
|
||||
|
||||
def resolve_hash_to_path(
|
||||
|
||||
225
app/assets/services/cursor.py
Normal file
225
app/assets/services/cursor.py
Normal file
@ -0,0 +1,225 @@
|
||||
"""Opaque keyset-pagination cursor for /api/assets.
|
||||
|
||||
Payload JSON uses short keys to keep the encoded length small:
|
||||
|
||||
{"s": <sort_field>, "v": <value>, "id": <id>, "o": <order>}
|
||||
|
||||
The `o` key binds the cursor to the sort direction it was minted under,
|
||||
so replaying a `desc` cursor against an `asc` request fails with
|
||||
``INVALID_CURSOR`` rather than silently walking the wrong direction.
|
||||
`o` is mandatory on every payload — a cursor without it is rejected as
|
||||
malformed.
|
||||
|
||||
Encoding is base64url with no padding. JSON serialization escapes `<`,
|
||||
`>`, `&`, U+2028, and U+2029 in encoded string values so asset names
|
||||
containing those characters produce a stable, byte-identical wire form
|
||||
across any compatible implementation of the same payload format.
|
||||
|
||||
Time values are serialized as Unix microseconds (UTC) — microsecond
|
||||
precision is sufficient to round-trip the timestamps stored by the
|
||||
database without rounding rows in the same millisecond bucket.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Iterable, Optional
|
||||
|
||||
|
||||
class InvalidCursorError(ValueError):
|
||||
"""Raised on a malformed, oversized, or unsupported-sort-field cursor.
|
||||
|
||||
Map to a 400 response with code ``INVALID_CURSOR`` at the handler.
|
||||
"""
|
||||
|
||||
|
||||
# Wire-format length caps. Cursors are user-controlled, so caps protect the
|
||||
# decode path from oversized allocations and downstream SQL predicates from
|
||||
# unbounded strings.
|
||||
#
|
||||
# MAX_CURSOR_VALUE_LENGTH is 512 to fit the `AssetReference.name` column max
|
||||
# (`String(512)`) — otherwise a long-named asset would mint a cursor the same
|
||||
# server then refuses on the next request.
|
||||
MAX_ENCODED_CURSOR_LENGTH = 1024
|
||||
MAX_CURSOR_VALUE_LENGTH = 512
|
||||
MAX_CURSOR_ID_LENGTH = 128
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CursorPayload:
|
||||
sort_field: str
|
||||
value: str
|
||||
id: str
|
||||
order: str
|
||||
|
||||
|
||||
_VALID_ORDERS = ("asc", "desc")
|
||||
|
||||
|
||||
def encode_cursor(sort_field: str, value: str, id: str, order: str = "desc") -> str:
|
||||
"""Encode a cursor payload as a base64url (no-padding) string.
|
||||
|
||||
`order` binds the cursor to the sort direction it was minted under so a
|
||||
later request with a flipped `order` query parameter is rejected with
|
||||
``INVALID_CURSOR`` rather than silently walking the wrong direction.
|
||||
"""
|
||||
if order not in _VALID_ORDERS:
|
||||
raise InvalidCursorError(f"order must be one of {_VALID_ORDERS}, got {order!r}")
|
||||
# Symmetric input validation: the encoder must reject anything the
|
||||
# decoder rejects, or the same server will mint cursors it then 400s on
|
||||
# the next request.
|
||||
if not id:
|
||||
raise InvalidCursorError("id must be non-empty")
|
||||
if len(id) > MAX_CURSOR_ID_LENGTH:
|
||||
raise InvalidCursorError("id exceeds maximum length")
|
||||
if len(value) > MAX_CURSOR_VALUE_LENGTH:
|
||||
raise InvalidCursorError("value exceeds maximum length")
|
||||
payload = {"s": sort_field, "v": value, "id": id, "o": order}
|
||||
raw = json.dumps(payload, separators=(",", ":"), ensure_ascii=False)
|
||||
# Match the default JSON escaping of HTML-significant characters and JS
|
||||
# line/paragraph separators (U+2028 / U+2029) so an asset name carrying
|
||||
# any of them encodes to identical bytes across runtimes. None of these
|
||||
# characters appear in JSON structural syntax, so a global replace on the
|
||||
# serialized output can only touch encoded values. Use explicit \uXXXX
|
||||
# escapes for U+2028 / U+2029 so the source survives any editor / git
|
||||
# tooling that normalizes invisible separators.
|
||||
raw = (
|
||||
raw.replace("<", "\\u003c")
|
||||
.replace(">", "\\u003e")
|
||||
.replace("&", "\\u0026")
|
||||
.replace("\u2028", "\\u2028")
|
||||
.replace("\u2029", "\\u2029")
|
||||
)
|
||||
encoded = base64.urlsafe_b64encode(raw.encode("utf-8")).rstrip(b"=").decode("ascii")
|
||||
# Final wire-size guard: the per-field caps above are char-counted, but the
|
||||
# wire cap applies to the base64url of the UTF-8-encoded, escape-expanded
|
||||
# payload. A value full of multibyte or HTML-significant characters (e.g.
|
||||
# 512 \u00d7 "\u00e9" or 512 \u00d7 "<") inflates well past MAX_ENCODED_CURSOR_LENGTH even
|
||||
# though it passes the char-count check. Refuse to mint a cursor the decoder
|
||||
# on the next request would reject.
|
||||
if len(encoded) > MAX_ENCODED_CURSOR_LENGTH:
|
||||
raise InvalidCursorError("encoded cursor exceeds maximum length")
|
||||
return encoded
|
||||
|
||||
|
||||
def encode_cursor_from_time(sort_field: str, t: datetime, id: str, order: str = "desc") -> str:
|
||||
"""Encode a time-typed cursor at Unix microsecond precision.
|
||||
|
||||
Accepts an aware datetime (any timezone) and normalizes to UTC. Naive
|
||||
datetimes are rejected so callers can't accidentally encode the local
|
||||
wall-clock value of a UTC-stored timestamp.
|
||||
"""
|
||||
if t.tzinfo is None:
|
||||
raise ValueError("encode_cursor_from_time requires an aware datetime")
|
||||
micros = _datetime_to_unix_micros(t.astimezone(timezone.utc))
|
||||
return encode_cursor(sort_field, str(micros), id, order=order)
|
||||
|
||||
|
||||
def decode_cursor(
|
||||
cursor: str,
|
||||
allowed_sort_fields: Iterable[str],
|
||||
expected_order: str | None = None,
|
||||
) -> CursorPayload:
|
||||
"""Parse an opaque cursor.
|
||||
|
||||
``allowed_sort_fields`` is the endpoint's accepted sort-field list — a
|
||||
cursor carrying a field outside this set is rejected so a cursor minted
|
||||
for one column can't be replayed against another (e.g. a ``created_at``
|
||||
timestamp string compared against a ``name`` column).
|
||||
|
||||
``expected_order`` (``"asc"``/``"desc"``), when supplied, must match the
|
||||
payload's ``o`` field. ``o`` is required on every payload; a cursor
|
||||
missing it is rejected as malformed.
|
||||
|
||||
Passing no allowed fields rejects every cursor.
|
||||
"""
|
||||
if len(cursor) > MAX_ENCODED_CURSOR_LENGTH:
|
||||
raise InvalidCursorError("cursor exceeds maximum length")
|
||||
|
||||
try:
|
||||
# urlsafe_b64decode requires correct padding; we strip on encode, so
|
||||
# restore the trailing '=' pad here.
|
||||
padding = "=" * (-len(cursor) % 4)
|
||||
raw = base64.urlsafe_b64decode(cursor + padding)
|
||||
except (ValueError, base64.binascii.Error) as e:
|
||||
raise InvalidCursorError(f"encoding: {e}") from e
|
||||
|
||||
try:
|
||||
decoded = json.loads(raw)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||
raise InvalidCursorError(f"payload: {e}") from e
|
||||
|
||||
if not isinstance(decoded, dict):
|
||||
raise InvalidCursorError("payload: expected object")
|
||||
|
||||
sort_field = decoded.get("s")
|
||||
value = decoded.get("v")
|
||||
id = decoded.get("id")
|
||||
order = decoded.get("o")
|
||||
|
||||
if not isinstance(sort_field, str) or not isinstance(value, str) or not isinstance(id, str):
|
||||
raise InvalidCursorError("payload: missing or non-string s/v/id")
|
||||
|
||||
if id == "":
|
||||
raise InvalidCursorError("missing id")
|
||||
if len(id) > MAX_CURSOR_ID_LENGTH:
|
||||
raise InvalidCursorError("id exceeds maximum length")
|
||||
if len(value) > MAX_CURSOR_VALUE_LENGTH:
|
||||
raise InvalidCursorError("value exceeds maximum length")
|
||||
|
||||
if sort_field not in allowed_sort_fields:
|
||||
raise InvalidCursorError(f"unsupported sort field {sort_field!r}")
|
||||
|
||||
if not isinstance(order, str):
|
||||
raise InvalidCursorError("missing or non-string o")
|
||||
if order not in _VALID_ORDERS:
|
||||
raise InvalidCursorError(f"unsupported order {order!r}")
|
||||
if expected_order is not None and order != expected_order:
|
||||
raise InvalidCursorError(
|
||||
f"cursor order {order!r} does not match request order {expected_order!r}"
|
||||
)
|
||||
|
||||
return CursorPayload(sort_field=sort_field, value=value, id=id, order=order)
|
||||
|
||||
|
||||
def decode_cursor_time(payload: Optional[CursorPayload]) -> datetime:
|
||||
"""Parse a time-typed cursor value as Unix microseconds, returning UTC."""
|
||||
if payload is None:
|
||||
raise InvalidCursorError("nil cursor payload")
|
||||
try:
|
||||
micros = int(payload.value)
|
||||
except ValueError as e:
|
||||
raise InvalidCursorError(f"value is not a valid timestamp: {e}") from e
|
||||
try:
|
||||
return _unix_micros_to_datetime(micros)
|
||||
except (OverflowError, OSError, ValueError) as e:
|
||||
# Crafted out-of-range microseconds (e.g. > datetime.MAX_YEAR) blow up
|
||||
# in fromtimestamp / datetime construction. Map to 400, not 500.
|
||||
raise InvalidCursorError(f"value is out of representable range: {e}") from e
|
||||
|
||||
|
||||
def decode_cursor_int(payload: Optional[CursorPayload]) -> int:
|
||||
"""Parse a cursor value as a base-10 integer."""
|
||||
if payload is None:
|
||||
raise InvalidCursorError("nil cursor payload")
|
||||
try:
|
||||
return int(payload.value)
|
||||
except ValueError as e:
|
||||
raise InvalidCursorError(f"value is not a valid integer: {e}") from e
|
||||
|
||||
|
||||
_EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def _datetime_to_unix_micros(t: datetime) -> int:
|
||||
"""Convert an aware UTC datetime to Unix microseconds (integer math)."""
|
||||
delta = t - _EPOCH
|
||||
return (delta.days * 86_400 + delta.seconds) * 1_000_000 + delta.microseconds
|
||||
|
||||
|
||||
def _unix_micros_to_datetime(micros: int) -> datetime:
|
||||
"""Convert Unix microseconds to a UTC datetime, preserving precision."""
|
||||
seconds, micro_remainder = divmod(micros, 1_000_000)
|
||||
return datetime.fromtimestamp(seconds, tz=timezone.utc).replace(microsecond=micro_remainder)
|
||||
@ -71,6 +71,7 @@ class AssetSummaryData:
|
||||
class ListAssetsResult:
|
||||
items: list[AssetSummaryData]
|
||||
total: int
|
||||
next_cursor: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
@ -62,8 +62,6 @@ def get_comfy_package_versions():
|
||||
def check_comfy_packages_versions():
|
||||
"""Warn for every comfy* package whose installed version is below requirements.txt."""
|
||||
from packaging.version import InvalidVersion, parse as parse_pep440
|
||||
outdated_packages = []
|
||||
|
||||
for pkg in get_comfy_package_versions():
|
||||
installed_str = pkg["installed"]
|
||||
required_str = pkg["required"]
|
||||
@ -75,26 +73,19 @@ def check_comfy_packages_versions():
|
||||
logging.error(f"Failed to check {pkg['name']} version: {e}")
|
||||
continue
|
||||
if outdated:
|
||||
outdated_packages.append((pkg["name"], installed_str, required_str))
|
||||
else:
|
||||
logging.info("{} version: {}".format(pkg["name"], installed_str))
|
||||
|
||||
if outdated_packages:
|
||||
package_warnings = "\n".join(
|
||||
f"Installed {name} version {installed} is lower than the recommended version {required}."
|
||||
for name, installed, required in outdated_packages
|
||||
)
|
||||
app.logger.log_startup_warning(
|
||||
f"""
|
||||
app.logger.log_startup_warning(
|
||||
f"""
|
||||
________________________________________________________________________
|
||||
WARNING WARNING WARNING WARNING WARNING
|
||||
|
||||
{package_warnings}
|
||||
Installed {pkg["name"]} version {installed_str} is lower than the recommended version {required_str}.
|
||||
|
||||
{get_missing_requirements_message()}
|
||||
________________________________________________________________________
|
||||
""".strip()
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("{} version: {}".format(pkg["name"], installed_str))
|
||||
|
||||
|
||||
REQUEST_TIMEOUT = 10 # seconds
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from enum import Enum
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@ -9,76 +11,44 @@ class Rodin3DGenerateRequest(BaseModel):
|
||||
material: str = Field(..., description="The material type.")
|
||||
quality_override: int = Field(..., description="The poly count of the mesh.")
|
||||
mesh_mode: str = Field(..., description="It controls the type of faces of generated models.")
|
||||
TAPose: bool | None = Field(None, description="")
|
||||
|
||||
|
||||
class Rodin3DGen25Request(BaseModel):
|
||||
|
||||
tier: str = Field(..., description="Gen-2.5 tier (e.g. Gen-2.5-High).")
|
||||
prompt: str | None = Field(None, description="Required for Text-to-3D; ignored otherwise.")
|
||||
seed: int | None = Field(None, description="0-65535.")
|
||||
material: str | None = Field(None, description="PBR | Shaded | All | None.")
|
||||
geometry_file_format: str | None = Field(None, description="glb | usdz | fbx | obj | stl.")
|
||||
texture_mode: str | None = Field(None, description="legacy | extreme-low | low | medium | high.")
|
||||
mesh_mode: str | None = Field(None, description="Raw (triangular) | Quad.")
|
||||
quality_override: int | None = Field(None, description="Mesh face count override.")
|
||||
geometry_instruct_mode: str | None = Field(None, description="faithful | creative.")
|
||||
bbox_condition: list[int] | None = Field(None, description="Bounding box [Width(Y), Height(Z), Length(X)] in cm.")
|
||||
height: int | None = Field(None, description="Approximate model height in cm.")
|
||||
TAPose: bool | None = Field(None, description="T/A pose for human-like models.")
|
||||
hd_texture: bool | None = Field(None, description="Enhanced texture quality.")
|
||||
texture_delight: bool | None = Field(None, description="Remove baked lighting from textures.")
|
||||
is_micro: bool | None = Field(None, description="Micro detail (Extreme-High only).")
|
||||
use_original_alpha: bool | None = Field(None, description="Preserve image transparency.")
|
||||
preview_render: bool | None = Field(None, description="Generate high-quality preview render.")
|
||||
addons: list[str] | None = Field(None, description='Optional addons, e.g. ["HighPack"].')
|
||||
|
||||
TAPose: Optional[bool] = Field(None, description="")
|
||||
|
||||
class GenerateJobsData(BaseModel):
|
||||
uuids: list[str] = Field(..., description="str LIST")
|
||||
uuids: List[str] = Field(..., description="str LIST")
|
||||
subscription_key: str = Field(..., description="subscription key")
|
||||
|
||||
|
||||
class Rodin3DGenerateResponse(BaseModel):
|
||||
message: str | None = Field(None, description="Return message.")
|
||||
prompt: str | None = Field(None, description="Generated Prompt from image.")
|
||||
submit_time: str | None = Field(None, description="Submit Time")
|
||||
uuid: str | None = Field(None, description="Task str")
|
||||
jobs: GenerateJobsData | None = Field(None, description="Details of jobs")
|
||||
|
||||
message: Optional[str] = Field(None, description="Return message.")
|
||||
prompt: Optional[str] = Field(None, description="Generated Prompt from image.")
|
||||
submit_time: Optional[str] = Field(None, description="Submit Time")
|
||||
uuid: Optional[str] = Field(None, description="Task str")
|
||||
jobs: Optional[GenerateJobsData] = Field(None, description="Details of jobs")
|
||||
|
||||
class JobStatus(str, Enum):
|
||||
"""
|
||||
Status for jobs
|
||||
"""
|
||||
|
||||
Done = "Done"
|
||||
Failed = "Failed"
|
||||
Generating = "Generating"
|
||||
Waiting = "Waiting"
|
||||
|
||||
|
||||
class Rodin3DCheckStatusRequest(BaseModel):
|
||||
subscription_key: str = Field(..., description="subscription from generate endpoint")
|
||||
|
||||
|
||||
class JobItem(BaseModel):
|
||||
uuid: str = Field(..., description="uuid")
|
||||
status: JobStatus = Field(..., description="Status Currently")
|
||||
|
||||
status: JobStatus = Field(...,description="Status Currently")
|
||||
|
||||
class Rodin3DCheckStatusResponse(BaseModel):
|
||||
jobs: list[JobItem] = Field(..., description="Job status List")
|
||||
|
||||
jobs: List[JobItem] = Field(..., description="Job status List")
|
||||
|
||||
class Rodin3DDownloadRequest(BaseModel):
|
||||
task_uuid: str = Field(..., description="Task str")
|
||||
|
||||
|
||||
class RodinResourceItem(BaseModel):
|
||||
url: str = Field(..., description="Download Url")
|
||||
name: str = Field(..., description="File name with ext")
|
||||
|
||||
|
||||
class Rodin3DDownloadResponse(BaseModel):
|
||||
items: list[RodinResourceItem] = Field(..., alias="list", description="Source List")
|
||||
list: List[RodinResourceItem] = Field(..., description="Source List")
|
||||
|
||||
@ -5,37 +5,32 @@ Rodin API docs: https://developer.hyper3d.ai/
|
||||
|
||||
"""
|
||||
|
||||
from inspect import cleandoc
|
||||
import folder_paths as comfy_paths
|
||||
import os
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from inspect import cleandoc
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from PIL import Image
|
||||
from typing_extensions import override
|
||||
|
||||
import folder_paths as comfy_paths
|
||||
from comfy_api.latest import IO, ComfyExtension, Types
|
||||
from PIL import Image
|
||||
from comfy_api_nodes.apis.rodin import (
|
||||
JobStatus,
|
||||
Rodin3DGenerateRequest,
|
||||
Rodin3DGenerateResponse,
|
||||
Rodin3DCheckStatusRequest,
|
||||
Rodin3DCheckStatusResponse,
|
||||
Rodin3DDownloadRequest,
|
||||
Rodin3DDownloadResponse,
|
||||
Rodin3DGen25Request,
|
||||
Rodin3DGenerateRequest,
|
||||
Rodin3DGenerateResponse,
|
||||
JobStatus,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
sync_op,
|
||||
poll_op,
|
||||
ApiEndpoint,
|
||||
download_url_to_bytesio,
|
||||
download_url_to_file_3d,
|
||||
poll_op,
|
||||
sync_op,
|
||||
validate_string,
|
||||
)
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
|
||||
|
||||
COMMON_PARAMETERS = [
|
||||
IO.Int.Input(
|
||||
@ -56,30 +51,40 @@ COMMON_PARAMETERS = [
|
||||
]
|
||||
|
||||
|
||||
_QUALITY_MESH_OPTIONS: dict[str, tuple[str, int]] = {
|
||||
"4K-Quad": ("Quad", 4000),
|
||||
"8K-Quad": ("Quad", 8000),
|
||||
"18K-Quad": ("Quad", 18000),
|
||||
"50K-Quad": ("Quad", 50000),
|
||||
"200K-Quad": ("Quad", 200000),
|
||||
"2K-Triangle": ("Raw", 2000),
|
||||
"20K-Triangle": ("Raw", 20000),
|
||||
"150K-Triangle": ("Raw", 150000),
|
||||
"200K-Triangle": ("Raw", 200000),
|
||||
"500K-Triangle": ("Raw", 500000),
|
||||
"1M-Triangle": ("Raw", 1000000),
|
||||
}
|
||||
def get_quality_mode(poly_count):
|
||||
polycount = poly_count.split("-")
|
||||
poly = polycount[1]
|
||||
count = polycount[0]
|
||||
if poly == "Triangle":
|
||||
mesh_mode = "Raw"
|
||||
elif poly == "Quad":
|
||||
mesh_mode = "Quad"
|
||||
else:
|
||||
mesh_mode = "Quad"
|
||||
|
||||
if count == "4K":
|
||||
quality_override = 4000
|
||||
elif count == "8K":
|
||||
quality_override = 8000
|
||||
elif count == "18K":
|
||||
quality_override = 18000
|
||||
elif count == "50K":
|
||||
quality_override = 50000
|
||||
elif count == "2K":
|
||||
quality_override = 2000
|
||||
elif count == "20K":
|
||||
quality_override = 20000
|
||||
elif count == "150K":
|
||||
quality_override = 150000
|
||||
elif count == "500K":
|
||||
quality_override = 500000
|
||||
else:
|
||||
quality_override = 18000
|
||||
|
||||
return mesh_mode, quality_override
|
||||
|
||||
|
||||
def get_quality_mode(poly_count: str) -> tuple[str, int]:
|
||||
"""Map a polygon-count preset like '18K-Quad' to (mesh_mode, quality_override).
|
||||
|
||||
Falls back to ('Quad', 18000) for unknown labels; legacy parity.
|
||||
"""
|
||||
return _QUALITY_MESH_OPTIONS.get(poly_count, ("Quad", 18000))
|
||||
|
||||
|
||||
def tensor_to_filelike(tensor, max_pixels: int = 2048 * 2048):
|
||||
def tensor_to_filelike(tensor, max_pixels: int = 2048*2048):
|
||||
"""
|
||||
Converts a PyTorch tensor to a file-like object.
|
||||
|
||||
@ -91,8 +96,8 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048 * 2048):
|
||||
- io.BytesIO: A file-like object containing the image data.
|
||||
"""
|
||||
array = tensor.cpu().numpy()
|
||||
array = (array * 255).astype("uint8")
|
||||
image = Image.fromarray(array, "RGB")
|
||||
array = (array * 255).astype('uint8')
|
||||
image = Image.fromarray(array, 'RGB')
|
||||
|
||||
original_width, original_height = image.size
|
||||
original_pixels = original_width * original_height
|
||||
@ -107,7 +112,7 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048 * 2048):
|
||||
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
|
||||
img_byte_arr = BytesIO()
|
||||
image.save(img_byte_arr, format="PNG") # PNG is used for lossless compression
|
||||
image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression
|
||||
img_byte_arr.seek(0)
|
||||
return img_byte_arr
|
||||
|
||||
@ -140,9 +145,11 @@ async def create_generate_task(
|
||||
TAPose=ta_pose,
|
||||
),
|
||||
files=[
|
||||
("images", open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image))
|
||||
for image in images
|
||||
if image is not None
|
||||
(
|
||||
"images",
|
||||
open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image)
|
||||
)
|
||||
for image in images if image is not None
|
||||
],
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
@ -170,7 +177,6 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
|
||||
return "DONE"
|
||||
return "Generating"
|
||||
|
||||
|
||||
def extract_progress(response: Rodin3DCheckStatusResponse) -> int | None:
|
||||
if not response.jobs:
|
||||
return None
|
||||
@ -208,7 +214,7 @@ async def download_files(url_list, task_uuid: str) -> tuple[str | None, Types.Fi
|
||||
model_file_path = None
|
||||
file_3d = None
|
||||
|
||||
for i in url_list.items:
|
||||
for i in url_list.list:
|
||||
file_path = os.path.join(save_path, i.name)
|
||||
if i.name.lower().endswith(".glb"):
|
||||
model_file_path = os.path.join(result_folder_name, i.name)
|
||||
@ -483,16 +489,7 @@ class Rodin3D_Gen2(IO.ComfyNode):
|
||||
IO.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True),
|
||||
IO.Combo.Input(
|
||||
"Polygon_count",
|
||||
options=[
|
||||
"4K-Quad",
|
||||
"8K-Quad",
|
||||
"18K-Quad",
|
||||
"50K-Quad",
|
||||
"2K-Triangle",
|
||||
"20K-Triangle",
|
||||
"150K-Triangle",
|
||||
"500K-Triangle",
|
||||
],
|
||||
options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"],
|
||||
default="500K-Triangle",
|
||||
optional=True,
|
||||
),
|
||||
@ -545,566 +542,6 @@ class Rodin3D_Gen2(IO.ComfyNode):
|
||||
return IO.NodeOutput(model_path, file_3d)
|
||||
|
||||
|
||||
def _rodin_multipart_parser(data: dict[str, Any]) -> aiohttp.FormData:
|
||||
"""Convert a Rodin request dict to an aiohttp form, fixing bool/list serialization.
|
||||
|
||||
Booleans --> "true"/"false". Lists --> one field per element.
|
||||
"""
|
||||
form = aiohttp.FormData(default_to_multipart=True)
|
||||
for key, value in data.items():
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, bool):
|
||||
form.add_field(key, "true" if value else "false")
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
form.add_field(key, str(item))
|
||||
elif isinstance(value, (bytes, bytearray)):
|
||||
form.add_field(key, value)
|
||||
else:
|
||||
form.add_field(key, str(value))
|
||||
return form
|
||||
|
||||
|
||||
async def _create_gen25_task(
|
||||
cls: type[IO.ComfyNode],
|
||||
request: Rodin3DGen25Request,
|
||||
images: list | None,
|
||||
) -> tuple[str, str]:
|
||||
"""Submit a Gen-2.5 generate job; returns (task_uuid, subscription_key)."""
|
||||
|
||||
if images is not None and len(images) > 5:
|
||||
raise ValueError("Rodin Gen-2.5 supports at most 5 input images.")
|
||||
|
||||
files = None
|
||||
if images:
|
||||
files = [
|
||||
(
|
||||
"images",
|
||||
open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image),
|
||||
)
|
||||
for image in images
|
||||
if image is not None
|
||||
]
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/rodin/api/v2/rodin", method="POST"),
|
||||
response_model=Rodin3DGenerateResponse,
|
||||
data=request,
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
multipart_parser=_rodin_multipart_parser,
|
||||
)
|
||||
|
||||
if not response.uuid or not response.jobs or not response.jobs.subscription_key:
|
||||
raise RuntimeError(f"Rodin Gen-2.5 submit failed: message={response.message!r}")
|
||||
return response.uuid, response.jobs.subscription_key
|
||||
|
||||
|
||||
_PREVIEWABLE_3D_EXTS = {".glb", ".obj", ".fbx", ".stl", ".gltf"}
|
||||
|
||||
|
||||
async def _download_gen25_files(
|
||||
download_list: Rodin3DDownloadResponse,
|
||||
task_uuid: str,
|
||||
geometry_file_format: str,
|
||||
) -> Types.File3D | None:
|
||||
"""Download every file in the list; return the File3D matching the chosen format."""
|
||||
|
||||
folder_name = f"Rodin3D_Gen25_{task_uuid}"
|
||||
save_dir = os.path.join(comfy_paths.get_output_directory(), folder_name)
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
target_ext = f".{geometry_file_format.lower().lstrip('.')}"
|
||||
file_3d: Types.File3D | None = None
|
||||
|
||||
for item in download_list.items:
|
||||
file_path = os.path.join(save_dir, item.name)
|
||||
ext = os.path.splitext(item.name.lower())[1]
|
||||
# Prefer the file matching the user's chosen format; fall back below.
|
||||
if file_3d is None and ext == target_ext and ext in _PREVIEWABLE_3D_EXTS:
|
||||
file_3d = await download_url_to_file_3d(item.url, target_ext.lstrip("."))
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_3d.get_bytes())
|
||||
continue
|
||||
await download_url_to_bytesio(item.url, file_path)
|
||||
|
||||
# If the chosen format wasn't found, surface any model file we did get.
|
||||
if file_3d is None:
|
||||
for item in download_list.items:
|
||||
ext = os.path.splitext(item.name.lower())[1]
|
||||
if ext in _PREVIEWABLE_3D_EXTS:
|
||||
file_3d = await download_url_to_file_3d(item.url, ext.lstrip("."))
|
||||
break
|
||||
return file_3d
|
||||
|
||||
|
||||
_MODE_REGULAR = "Regular"
|
||||
_MODE_FAST = "Fast"
|
||||
_MODE_EXTREME_HIGH = "Extreme-High"
|
||||
|
||||
_REGULAR_POLY_OPTIONS = [
|
||||
"Default",
|
||||
"4K-Quad",
|
||||
"8K-Quad",
|
||||
"18K-Quad",
|
||||
"50K-Quad",
|
||||
"2K-Triangle",
|
||||
"20K-Triangle",
|
||||
"150K-Triangle",
|
||||
"500K-Triangle",
|
||||
"1M-Triangle",
|
||||
]
|
||||
|
||||
_TEXTURE_MODE_OPTIONS = ["Default", "legacy", "extreme-low", "low", "medium", "high"]
|
||||
_GEOMETRY_FORMAT_OPTIONS = ["glb", "fbx", "obj", "stl"]
|
||||
_MATERIAL_OPTIONS = ["PBR", "Shaded", "All", "None"]
|
||||
|
||||
|
||||
def _build_mode_input(name: str = "mode") -> IO.DynamicCombo.Input:
|
||||
return IO.DynamicCombo.Input(
|
||||
name,
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
_MODE_REGULAR,
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"tier",
|
||||
options=["Gen-2.5-Low", "Gen-2.5-Medium", "Gen-2.5-High"],
|
||||
default="Gen-2.5-High",
|
||||
tooltip="Quality tier. Higher tiers produce higher-fidelity geometry.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"polygon_count",
|
||||
options=_REGULAR_POLY_OPTIONS,
|
||||
default="Default",
|
||||
tooltip="Preset face count. 'Default' uses the server's default for the selected tier.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"creative",
|
||||
default=False,
|
||||
tooltip="Creative mode (Medium/High only). Enhances generative robustness.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
_MODE_FAST,
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"tier",
|
||||
options=[
|
||||
"Gen-2.5-Extreme-Low",
|
||||
"Gen-2.5-Low",
|
||||
"Gen-2.5-Medium",
|
||||
"Gen-2.5-High",
|
||||
],
|
||||
default="Gen-2.5-Low",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"mesh_faces",
|
||||
default=20000,
|
||||
min=1000,
|
||||
max=20000,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Mesh face count (1K-20K in Fast mode).",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
_MODE_EXTREME_HIGH,
|
||||
[
|
||||
IO.Combo.Input("mesh_mode", options=["Raw", "Quad"], default="Raw"),
|
||||
IO.Int.Input(
|
||||
"mesh_faces",
|
||||
default=1000000,
|
||||
min=20000,
|
||||
max=2000000,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip=(
|
||||
"Mesh face count. Raw mode: 20K-2M. "
|
||||
"Quad mode: keep under 200K (upstream may reject higher values)."
|
||||
),
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"is_micro",
|
||||
default=False,
|
||||
tooltip="Enable micro detail (Extreme-High only).",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"creative",
|
||||
default=False,
|
||||
tooltip="Creative mode. Enhances generative robustness.",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip=(
|
||||
"Generation mode. Regular = balanced. Fast = 1K-20K faces for rapid prototyping. "
|
||||
"Extreme-High = 20K-2M faces with optional micro details."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _build_common_inputs(*, include_image_only: bool) -> list:
|
||||
inputs: list = [
|
||||
IO.Combo.Input("material", options=_MATERIAL_OPTIONS, default="Shaded"),
|
||||
IO.Combo.Input("geometry_file_format", options=_GEOMETRY_FORMAT_OPTIONS, default="glb"),
|
||||
IO.Combo.Input(
|
||||
"texture_mode",
|
||||
options=_TEXTURE_MODE_OPTIONS,
|
||||
default="Default",
|
||||
optional=True,
|
||||
tooltip="Texture quality preset. 'Default' uses the server's default for the selected tier.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=65535,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"TAPose", default=False, optional=True, advanced=True, tooltip="T/A pose for human-like models."
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"hd_texture", default=False, optional=True, advanced=True, tooltip="High-quality texture enhancement."
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"texture_delight",
|
||||
default=False,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Remove baked lighting from textures.",
|
||||
),
|
||||
]
|
||||
if include_image_only:
|
||||
inputs.append(
|
||||
IO.Boolean.Input(
|
||||
"use_original_alpha",
|
||||
default=False,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Preserve image transparency.",
|
||||
)
|
||||
)
|
||||
inputs.extend(
|
||||
[
|
||||
IO.Boolean.Input(
|
||||
"addon_highpack",
|
||||
default=False,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="HighPack addon: 4K textures and ~16x faces in Quad mode.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"bbox_width",
|
||||
default=0,
|
||||
min=0,
|
||||
max=300,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Bounding-box width (Y axis). Set to 0 with the others to skip bbox.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"bbox_height",
|
||||
default=0,
|
||||
min=0,
|
||||
max=300,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Bounding-box height (Z axis).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"bbox_length",
|
||||
default=0,
|
||||
min=0,
|
||||
max=300,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Bounding-box length (X axis).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"height_cm",
|
||||
default=0,
|
||||
min=0,
|
||||
max=10000,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Approximate model height in centimeters (0 to skip).",
|
||||
),
|
||||
]
|
||||
)
|
||||
return inputs
|
||||
|
||||
|
||||
_PRICE_EXPR = """
|
||||
(
|
||||
$baseCredits := widgets.mode = "extreme-high" ? 1.0 : 0.5;
|
||||
$addonCredits := widgets.addon_highpack ? 1.0 : 0.0;
|
||||
$total := ($baseCredits * 1.5) + ($addonCredits * 0.8);
|
||||
{"type":"usd","usd": $total}
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
def _resolve_mode_params(mode_input: dict) -> dict:
|
||||
"""Translate the DynamicCombo `mode` payload into Gen-2.5 request fields.
|
||||
|
||||
Returns a dict with: tier, quality_override, mesh_mode, geometry_instruct_mode, is_micro.
|
||||
Missing keys mean "do not send" (so we don't override server defaults).
|
||||
"""
|
||||
selected = mode_input["mode"]
|
||||
out: dict = {}
|
||||
|
||||
if selected == _MODE_REGULAR:
|
||||
out["tier"] = mode_input["tier"]
|
||||
polygon = mode_input.get("polygon_count", "Default")
|
||||
if polygon != "Default":
|
||||
mesh_mode, faces = get_quality_mode(polygon)
|
||||
out["mesh_mode"] = mesh_mode
|
||||
out["quality_override"] = faces
|
||||
if mode_input.get("creative"):
|
||||
out["geometry_instruct_mode"] = "creative"
|
||||
|
||||
elif selected == _MODE_FAST:
|
||||
out["tier"] = mode_input["tier"]
|
||||
out["mesh_mode"] = "Raw"
|
||||
out["quality_override"] = int(mode_input["mesh_faces"])
|
||||
|
||||
elif selected == _MODE_EXTREME_HIGH:
|
||||
out["tier"] = "Gen-2.5-Extreme-High"
|
||||
out["mesh_mode"] = mode_input["mesh_mode"]
|
||||
out["quality_override"] = int(mode_input["mesh_faces"])
|
||||
if mode_input.get("is_micro"):
|
||||
out["is_micro"] = True
|
||||
if mode_input.get("creative"):
|
||||
out["geometry_instruct_mode"] = "creative"
|
||||
return out
|
||||
|
||||
|
||||
def _build_request(
|
||||
*,
|
||||
mode_input: dict,
|
||||
material: str,
|
||||
geometry_file_format: str,
|
||||
texture_mode: str,
|
||||
seed: int,
|
||||
TAPose: bool,
|
||||
hd_texture: bool,
|
||||
texture_delight: bool,
|
||||
addon_highpack: bool,
|
||||
bbox_width: int,
|
||||
bbox_height: int,
|
||||
bbox_length: int,
|
||||
height_cm: int,
|
||||
prompt: str | None = None,
|
||||
use_original_alpha: bool = False,
|
||||
) -> Rodin3DGen25Request:
|
||||
mode_params = _resolve_mode_params(mode_input)
|
||||
|
||||
bbox = None
|
||||
if bbox_width and bbox_height and bbox_length:
|
||||
bbox = [bbox_width, bbox_height, bbox_length]
|
||||
|
||||
return Rodin3DGen25Request(
|
||||
tier=mode_params["tier"],
|
||||
prompt=prompt or None,
|
||||
seed=seed,
|
||||
material=material,
|
||||
geometry_file_format=geometry_file_format,
|
||||
texture_mode=None if texture_mode == "Default" else texture_mode,
|
||||
mesh_mode=mode_params.get("mesh_mode"),
|
||||
quality_override=mode_params.get("quality_override"),
|
||||
geometry_instruct_mode=mode_params.get("geometry_instruct_mode"),
|
||||
bbox_condition=bbox,
|
||||
height=height_cm or None,
|
||||
TAPose=TAPose or None,
|
||||
hd_texture=hd_texture or None,
|
||||
texture_delight=texture_delight or None,
|
||||
is_micro=mode_params.get("is_micro"),
|
||||
use_original_alpha=use_original_alpha or None,
|
||||
addons=["HighPack"] if addon_highpack else None,
|
||||
)
|
||||
|
||||
|
||||
class Rodin3D_Gen25_Image(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Gen25_Image",
|
||||
display_name="Rodin 3D Gen-2.5 - Image to 3D",
|
||||
category="api node/3d/Rodin",
|
||||
description=(
|
||||
"Generate a 3D model from 1-5 reference images via Rodin Gen-2.5. "
|
||||
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."
|
||||
),
|
||||
inputs=[
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplatePrefix(IO.Image.Input("image"), prefix="image", min=1, max=5),
|
||||
tooltip="1-5 images. The first image is used for materials when multi-view.",
|
||||
),
|
||||
_build_mode_input(),
|
||||
*_build_common_inputs(include_image_only=True),
|
||||
],
|
||||
outputs=[IO.File3DAny.Output(display_name="model_file")],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["mode", "addon_highpack"]),
|
||||
expr=_PRICE_EXPR,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
images: IO.Autogrow.Type,
|
||||
mode: dict,
|
||||
material: str,
|
||||
geometry_file_format: str,
|
||||
texture_mode: str,
|
||||
seed: int,
|
||||
TAPose: bool,
|
||||
hd_texture: bool,
|
||||
texture_delight: bool,
|
||||
use_original_alpha: bool,
|
||||
addon_highpack: bool,
|
||||
bbox_width: int,
|
||||
bbox_height: int,
|
||||
bbox_length: int,
|
||||
height_cm: int,
|
||||
) -> IO.NodeOutput:
|
||||
image_tensors = [img for img in images.values() if img is not None]
|
||||
if not image_tensors:
|
||||
raise ValueError("Rodin Gen-2.5 Image-to-3D requires at least one image.")
|
||||
|
||||
# Flatten multi-image tensors into individual frames; the API accepts each as a separate part.
|
||||
flat_images: list = []
|
||||
for tensor in image_tensors:
|
||||
if hasattr(tensor, "shape") and len(tensor.shape) == 4:
|
||||
for i in range(tensor.shape[0]):
|
||||
flat_images.append(tensor[i])
|
||||
else:
|
||||
flat_images.append(tensor)
|
||||
|
||||
if len(flat_images) > 5:
|
||||
raise ValueError(f"Rodin Gen-2.5 accepts at most 5 images; received {len(flat_images)}.")
|
||||
|
||||
request = _build_request(
|
||||
mode_input=mode,
|
||||
material=material,
|
||||
geometry_file_format=geometry_file_format,
|
||||
texture_mode=texture_mode,
|
||||
seed=seed,
|
||||
TAPose=TAPose,
|
||||
hd_texture=hd_texture,
|
||||
texture_delight=texture_delight,
|
||||
addon_highpack=addon_highpack,
|
||||
bbox_width=bbox_width,
|
||||
bbox_height=bbox_height,
|
||||
bbox_length=bbox_length,
|
||||
height_cm=height_cm,
|
||||
prompt=None,
|
||||
use_original_alpha=use_original_alpha,
|
||||
)
|
||||
|
||||
task_uuid, subscription_key = await _create_gen25_task(cls, request, flat_images)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
file_3d = await _download_gen25_files(download_list, task_uuid, geometry_file_format)
|
||||
return IO.NodeOutput(file_3d)
|
||||
|
||||
|
||||
class Rodin3D_Gen25_Text(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Gen25_Text",
|
||||
display_name="Rodin 3D Gen-2.5 - Text to 3D",
|
||||
category="api node/3d/Rodin",
|
||||
description=(
|
||||
"Generate a 3D model from a text prompt via Rodin Gen-2.5. "
|
||||
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."
|
||||
),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt for the 3D model.",
|
||||
),
|
||||
_build_mode_input(),
|
||||
*_build_common_inputs(include_image_only=False),
|
||||
],
|
||||
outputs=[IO.File3DAny.Output(display_name="model_file")],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["mode", "addon_highpack"]),
|
||||
expr=_PRICE_EXPR,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
mode: dict,
|
||||
material: str,
|
||||
geometry_file_format: str,
|
||||
texture_mode: str,
|
||||
seed: int,
|
||||
TAPose: bool,
|
||||
hd_texture: bool,
|
||||
texture_delight: bool,
|
||||
addon_highpack: bool,
|
||||
bbox_width: int,
|
||||
bbox_height: int,
|
||||
bbox_length: int,
|
||||
height_cm: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, field_name="prompt", min_length=1, max_length=2500)
|
||||
request = _build_request(
|
||||
mode_input=mode,
|
||||
material=material,
|
||||
geometry_file_format=geometry_file_format,
|
||||
texture_mode=texture_mode,
|
||||
seed=seed,
|
||||
TAPose=TAPose,
|
||||
hd_texture=hd_texture,
|
||||
texture_delight=texture_delight,
|
||||
addon_highpack=addon_highpack,
|
||||
bbox_width=bbox_width,
|
||||
bbox_height=bbox_height,
|
||||
bbox_length=bbox_length,
|
||||
height_cm=height_cm,
|
||||
prompt=prompt,
|
||||
)
|
||||
task_uuid, subscription_key = await _create_gen25_task(cls, request, images=None)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
file_3d = await _download_gen25_files(download_list, task_uuid, geometry_file_format)
|
||||
return IO.NodeOutput(file_3d)
|
||||
|
||||
|
||||
class Rodin3DExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -1114,8 +551,6 @@ class Rodin3DExtension(ComfyExtension):
|
||||
Rodin3D_Smooth,
|
||||
Rodin3D_Sketch,
|
||||
Rodin3D_Gen2,
|
||||
Rodin3D_Gen25_Image,
|
||||
Rodin3D_Gen25_Text,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -8,82 +8,6 @@ from comfy_api.latest import _io
|
||||
MISSING = object()
|
||||
|
||||
|
||||
class NotNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ComfyNotNode",
|
||||
display_name="Not",
|
||||
category="utils/logic",
|
||||
description="Logical NOT operation. Returns true if the value is falsy. Uses Python's rules for truthiness.",
|
||||
search_aliases=["invert", "toggle", "negate", "flip boolean"],
|
||||
inputs=[
|
||||
io.AnyType.Input("value"),
|
||||
],
|
||||
outputs=[
|
||||
io.Boolean.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, value) -> io.NodeOutput:
|
||||
return io.NodeOutput(not value)
|
||||
|
||||
|
||||
class AndNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
template = io.Autogrow.TemplatePrefix(
|
||||
input=io.AnyType.Input("value"),
|
||||
prefix="value",
|
||||
min=1,
|
||||
)
|
||||
return io.Schema(
|
||||
node_id="ComfyAndNode",
|
||||
display_name="And",
|
||||
category="utils/logic",
|
||||
description="Logical AND operation. Returns true if all of the values are truthy. Uses Python's rules for truthiness.",
|
||||
search_aliases=["all", "every"],
|
||||
inputs=[
|
||||
io.Autogrow.Input("values", template=template),
|
||||
],
|
||||
outputs=[
|
||||
io.Boolean.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, values: io.Autogrow.Type) -> io.NodeOutput:
|
||||
return io.NodeOutput(all(values.values()))
|
||||
|
||||
|
||||
class OrNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
template = io.Autogrow.TemplatePrefix(
|
||||
input=io.AnyType.Input("value"),
|
||||
prefix="value",
|
||||
min=1,
|
||||
)
|
||||
return io.Schema(
|
||||
node_id="ComfyOrNode",
|
||||
display_name="Or",
|
||||
category="utils/logic",
|
||||
description="Logical OR operation. Returns true if any of the values are truthy. Uses Python's rules for truthiness.",
|
||||
search_aliases=["any", "some"],
|
||||
inputs=[
|
||||
io.Autogrow.Input("values", template=template),
|
||||
],
|
||||
outputs=[
|
||||
io.Boolean.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, values: io.Autogrow.Type) -> io.NodeOutput:
|
||||
return io.NodeOutput(any(values.values()))
|
||||
|
||||
|
||||
class SwitchNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -91,7 +15,7 @@ class SwitchNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ComfySwitchNode",
|
||||
display_name="Switch",
|
||||
category="utils/logic",
|
||||
category="logic",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Boolean.Input("switch"),
|
||||
@ -122,7 +46,7 @@ class SoftSwitchNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ComfySoftSwitchNode",
|
||||
display_name="Soft Switch",
|
||||
category="utils/logic",
|
||||
category="logic",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Boolean.Input("switch"),
|
||||
@ -212,7 +136,7 @@ class DCTestNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="DCTestNode",
|
||||
display_name="DCTest",
|
||||
category="utils/logic",
|
||||
category="logic",
|
||||
is_output_node=True,
|
||||
inputs=[io.DynamicCombo.Input("combo", options=[
|
||||
io.DynamicCombo.Option("option1", [io.String.Input("string")]),
|
||||
@ -250,7 +174,7 @@ class AutogrowNamesTestNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="AutogrowNamesTestNode",
|
||||
display_name="AutogrowNamesTest",
|
||||
category="utils/logic",
|
||||
category="logic",
|
||||
inputs=[
|
||||
_io.Autogrow.Input("autogrow", template=template)
|
||||
],
|
||||
@ -270,7 +194,7 @@ class AutogrowPrefixTestNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="AutogrowPrefixTestNode",
|
||||
display_name="AutogrowPrefixTest",
|
||||
category="utils/logic",
|
||||
category="logic",
|
||||
inputs=[
|
||||
_io.Autogrow.Input("autogrow", template=template)
|
||||
],
|
||||
@ -289,7 +213,7 @@ class ComboOutputTestNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ComboOptionTestNode",
|
||||
display_name="ComboOptionTest",
|
||||
category="utils/logic",
|
||||
category="logic",
|
||||
inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]),
|
||||
io.Combo.Input("combo2", options=["option4", "option5", "option6"])],
|
||||
outputs=[io.Combo.Output(), io.Combo.Output()],
|
||||
@ -306,7 +230,7 @@ class ConvertStringToComboNode(io.ComfyNode):
|
||||
node_id="ConvertStringToComboNode",
|
||||
search_aliases=["string to dropdown", "text to combo"],
|
||||
display_name="Convert String to Combo",
|
||||
category="utils/logic",
|
||||
category="logic",
|
||||
inputs=[io.String.Input("string")],
|
||||
outputs=[io.Combo.Output()],
|
||||
)
|
||||
@ -322,7 +246,7 @@ class InvertBooleanNode(io.ComfyNode):
|
||||
node_id="InvertBooleanNode",
|
||||
search_aliases=["not", "toggle", "negate", "flip boolean"],
|
||||
display_name="Invert Boolean",
|
||||
category="utils/logic",
|
||||
category="logic",
|
||||
inputs=[io.Boolean.Input("boolean")],
|
||||
outputs=[io.Boolean.Output()],
|
||||
)
|
||||
@ -337,9 +261,6 @@ class LogicExtension(ComfyExtension):
|
||||
return [
|
||||
SwitchNode,
|
||||
CustomComboNode,
|
||||
NotNode,
|
||||
AndNode,
|
||||
OrNode,
|
||||
# SoftSwitchNode,
|
||||
# ConvertStringToComboNode,
|
||||
# DCTestNode,
|
||||
|
||||
@ -70,7 +70,7 @@ class MathExpressionNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ComfyMathExpression",
|
||||
display_name="Math Expression",
|
||||
category="utils",
|
||||
category="logic",
|
||||
search_aliases=[
|
||||
"expression", "formula", "calculate", "calculator",
|
||||
"eval", "math",
|
||||
|
||||
@ -14,7 +14,7 @@ class CreateList(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="CreateList",
|
||||
display_name="Create List",
|
||||
category="utils",
|
||||
category="logic",
|
||||
is_input_list=True,
|
||||
search_aliases=["Image Iterator", "Text Iterator", "Iterator"],
|
||||
inputs=[io.Autogrow.Input("inputs", template=template_autogrow)],
|
||||
|
||||
3276
openapi.yaml
3276
openapi.yaml
File diff suppressed because it is too large
Load Diff
112
tests-unit/assets_test/queries/test_asset_reference_keyset.py
Normal file
112
tests-unit/assets_test/queries/test_asset_reference_keyset.py
Normal file
@ -0,0 +1,112 @@
|
||||
"""Keyset-pagination tiebreaker tests for list_references_page.
|
||||
|
||||
When multiple rows share the same primary sort value (e.g. four assets
|
||||
created in the same microsecond), the secondary `ORDER BY id` is what keeps
|
||||
keyset pagination from losing or repeating rows. This file exercises that
|
||||
branch directly against an in-memory SQLite session — engineering identical
|
||||
timestamps via HTTP is unreliable enough that we work at the query layer.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference
|
||||
from app.assets.database.queries.asset_reference import list_references_page
|
||||
|
||||
|
||||
def _make_ref(session: Session, created_at: datetime, name: str, owner: str = "") -> AssetReference:
|
||||
asset = Asset(hash=f"blake3:{uuid.uuid4().hex}", size_bytes=1024)
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
ref = AssetReference(
|
||||
id=str(uuid.uuid4()),
|
||||
asset_id=asset.id,
|
||||
owner_id=owner,
|
||||
name=name,
|
||||
file_path=f"/tmp/{name}",
|
||||
created_at=created_at,
|
||||
updated_at=created_at,
|
||||
last_access_time=created_at,
|
||||
is_missing=False,
|
||||
)
|
||||
session.add(ref)
|
||||
return ref
|
||||
|
||||
|
||||
@pytest.mark.parametrize("order", ["desc", "asc"])
|
||||
def test_tiebreaker_walks_duplicate_sort_values(session: Session, order: str):
|
||||
"""Four rows with the SAME created_at must paginate cleanly under cursor
|
||||
mode — no row dropped, no row repeated, despite the primary sort column
|
||||
being non-discriminating.
|
||||
"""
|
||||
shared_ts = datetime(2024, 5, 20, 12, 0, 0) # naive UTC, like the DB stores
|
||||
refs = [_make_ref(session, shared_ts, f"tie_{i}.png") for i in range(4)]
|
||||
session.commit()
|
||||
|
||||
expected_ids = sorted([r.id for r in refs], reverse=(order == "desc"))
|
||||
|
||||
# Walk the cursor by hand: page size 2, take 3 pages (2 + 2 + 0).
|
||||
seen: list[str] = []
|
||||
after_value = None
|
||||
after_id = None
|
||||
for _ in range(4): # generous loop bound; ought to be 2 iterations
|
||||
page, _tag_map, _total = list_references_page(
|
||||
session,
|
||||
limit=2,
|
||||
sort="created_at",
|
||||
order=order,
|
||||
after_cursor_value=after_value,
|
||||
after_cursor_id=after_id,
|
||||
)
|
||||
if not page:
|
||||
break
|
||||
seen.extend(p.id for p in page)
|
||||
# Use the last row's (created_at, id) as the next cursor input.
|
||||
last = page[-1]
|
||||
after_value, after_id = last.created_at, last.id
|
||||
if len(page) < 2:
|
||||
break
|
||||
|
||||
assert seen == expected_ids, (
|
||||
f"keyset tiebreaker failed for order={order}: expected {expected_ids}, got {seen}"
|
||||
)
|
||||
|
||||
|
||||
def test_tiebreaker_no_duplicates_under_mixed_collisions(session: Session):
|
||||
"""Some rows share a timestamp, some don't. The cursor must still walk
|
||||
every row exactly once regardless of where ties sit relative to a
|
||||
page boundary."""
|
||||
t1 = datetime(2024, 5, 20, 12, 0, 0)
|
||||
t2 = datetime(2024, 5, 20, 12, 0, 1)
|
||||
layout = [t1, t1, t1, t2, t2] # three rows at t1, two at t2
|
||||
refs = [_make_ref(session, ts, f"mix_{i}.png") for i, ts in enumerate(layout)]
|
||||
session.commit()
|
||||
|
||||
all_ids = {r.id for r in refs}
|
||||
seen_set: set[str] = set()
|
||||
seen_list: list[str] = []
|
||||
after_value = None
|
||||
after_id = None
|
||||
for _ in range(6):
|
||||
page, _, _ = list_references_page(
|
||||
session,
|
||||
limit=2,
|
||||
sort="created_at",
|
||||
order="desc",
|
||||
after_cursor_value=after_value,
|
||||
after_cursor_id=after_id,
|
||||
)
|
||||
if not page:
|
||||
break
|
||||
for p in page:
|
||||
assert p.id not in seen_set, f"duplicate row {p.id} appeared in cursor walk"
|
||||
seen_set.add(p.id)
|
||||
seen_list.append(p.id)
|
||||
last = page[-1]
|
||||
after_value, after_id = last.created_at, last.id
|
||||
if len(page) < 2:
|
||||
break
|
||||
|
||||
assert seen_set == all_ids, f"missing rows: expected {all_ids}, got {seen_set}"
|
||||
354
tests-unit/assets_test/services/test_cursor.py
Normal file
354
tests-unit/assets_test/services/test_cursor.py
Normal file
@ -0,0 +1,354 @@
|
||||
"""Tests for app.assets.services.cursor.
|
||||
|
||||
The byte-identity fixtures below pin the wire format so a parallel
|
||||
implementation in another runtime can mint exchange-compatible cursors
|
||||
for the same payload. Drift here would break frontend pagination against
|
||||
any compatible backend.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from app.assets.services.cursor import (
|
||||
MAX_CURSOR_ID_LENGTH,
|
||||
MAX_CURSOR_VALUE_LENGTH,
|
||||
MAX_ENCODED_CURSOR_LENGTH,
|
||||
CursorPayload,
|
||||
InvalidCursorError,
|
||||
decode_cursor,
|
||||
decode_cursor_int,
|
||||
decode_cursor_time,
|
||||
encode_cursor,
|
||||
encode_cursor_from_time,
|
||||
)
|
||||
|
||||
|
||||
ALLOWED = ("created_at", "updated_at", "name", "size")
|
||||
|
||||
|
||||
class TestRoundTrip:
|
||||
@pytest.mark.parametrize(
|
||||
"sort_field, value, id",
|
||||
[
|
||||
("created_at", "1716200000000000", "a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7"),
|
||||
("size", "1024", "asset-123"),
|
||||
("name", "my-asset.png", "asset-abc"),
|
||||
("name", "résumé.txt", "asset-uni"),
|
||||
],
|
||||
)
|
||||
def test_encode_decode(self, sort_field, value, id):
|
||||
encoded = encode_cursor(sort_field, value, id)
|
||||
assert encoded != ""
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
assert payload.sort_field == sort_field
|
||||
assert payload.value == value
|
||||
assert payload.id == id
|
||||
|
||||
|
||||
class TestTimeCursor:
|
||||
def test_microsecond_precision_preserved(self):
|
||||
# Pick a time with non-zero microseconds — encoding at ms would lose the µs.
|
||||
ts = datetime(2024, 5, 20, 12, 53, 20, 123456, tzinfo=timezone.utc)
|
||||
encoded = encode_cursor_from_time("created_at", ts, "id-1")
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
# Value must be a microsecond integer string, not a millisecond one.
|
||||
assert payload.value == "1716209600123456"
|
||||
decoded = decode_cursor_time(payload)
|
||||
assert decoded == ts
|
||||
|
||||
def test_decode_returns_utc(self):
|
||||
payload = CursorPayload(sort_field="created_at", value="1716200000123456", id="id-1", order="desc")
|
||||
decoded = decode_cursor_time(payload)
|
||||
assert decoded.tzinfo == timezone.utc
|
||||
|
||||
def test_naive_datetime_rejected_on_encode(self):
|
||||
naive = datetime(2024, 5, 20, 12, 0, 0)
|
||||
with pytest.raises(ValueError):
|
||||
encode_cursor_from_time("created_at", naive, "id-1")
|
||||
|
||||
def test_non_integer_value_rejected_on_decode(self):
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor_time(CursorPayload("created_at", "not-a-number", "id-1", "desc"))
|
||||
|
||||
def test_none_payload_rejected(self):
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor_time(None)
|
||||
|
||||
def test_non_utc_aware_normalized(self):
|
||||
# Same instant, different timezone — must encode to the same micros.
|
||||
utc_ts = datetime(2024, 5, 20, 12, 0, 0, tzinfo=timezone.utc)
|
||||
offset_ts = utc_ts.astimezone(timezone(timedelta(hours=-5)))
|
||||
assert encode_cursor_from_time("created_at", utc_ts, "x") == encode_cursor_from_time(
|
||||
"created_at", offset_ts, "x"
|
||||
)
|
||||
|
||||
|
||||
class TestIntCursor:
|
||||
def test_decode_int(self):
|
||||
assert decode_cursor_int(CursorPayload("size", "1024", "id-1", "desc")) == 1024
|
||||
|
||||
def test_decode_int_rejects_non_int(self):
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor_int(CursorPayload("size", "abc", "id-1", "desc"))
|
||||
|
||||
def test_decode_int_rejects_none(self):
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor_int(None)
|
||||
|
||||
|
||||
class TestInvalidInputs:
|
||||
def test_oversized_cursor(self):
|
||||
oversized = "a" * (MAX_ENCODED_CURSOR_LENGTH + 1)
|
||||
with pytest.raises(InvalidCursorError, match="maximum length"):
|
||||
decode_cursor(oversized, ALLOWED)
|
||||
|
||||
def test_not_base64(self):
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor("not base64!!!", ALLOWED)
|
||||
|
||||
def test_not_json(self):
|
||||
encoded = base64.urlsafe_b64encode(b"definitely not json").rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_empty_id(self):
|
||||
# Encoder rejects empty id symmetrically with the decoder, so build the
|
||||
# payload manually to exercise the decoder's missing-id branch.
|
||||
raw = b'{"s":"created_at","v":"1","id":"","o":"desc"}'
|
||||
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="missing id"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_oversized_id(self):
|
||||
# Encoder enforces the cap symmetrically; hand-build to exercise decode.
|
||||
big_id = "a" * (MAX_CURSOR_ID_LENGTH + 1)
|
||||
raw = ('{"s":"created_at","v":"1","id":"' + big_id + '","o":"desc"}').encode("ascii")
|
||||
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="id exceeds maximum length"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_oversized_value(self):
|
||||
# Encoder enforces the cap symmetrically; hand-build to exercise decode.
|
||||
big_v = "v" * (MAX_CURSOR_VALUE_LENGTH + 1)
|
||||
raw = ('{"s":"created_at","v":"' + big_v + '","id":"id-1","o":"desc"}').encode("ascii")
|
||||
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="value exceeds maximum length"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_unsupported_sort_field(self):
|
||||
encoded = encode_cursor("execution_time", "1", "id-1")
|
||||
with pytest.raises(InvalidCursorError, match="unsupported sort field"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_no_allowed_fields_rejects_everything(self):
|
||||
encoded = encode_cursor("created_at", "1", "id-1")
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor(encoded, ())
|
||||
|
||||
def test_non_dict_payload_rejected(self):
|
||||
encoded = base64.urlsafe_b64encode(b'["array","not","dict"]').rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="expected object"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
|
||||
class TestEncodeAtCapsFits:
|
||||
def test_max_field_lengths_fit_wire_cap(self):
|
||||
# Worst-case payload: value and id at their per-field caps, with a long
|
||||
# sort field name. The encoded cursor must fit within MAX_ENCODED_CURSOR_LENGTH
|
||||
# so the wire cap cannot reject a cursor the encoder mints at the per-field caps.
|
||||
value = "v" * MAX_CURSOR_VALUE_LENGTH
|
||||
id = "i" * MAX_CURSOR_ID_LENGTH
|
||||
sort_field = "very_long_sort_field_name"
|
||||
|
||||
encoded = encode_cursor(sort_field, value, id)
|
||||
assert len(encoded) <= MAX_ENCODED_CURSOR_LENGTH
|
||||
payload = decode_cursor(encoded, (sort_field,))
|
||||
assert payload.value == value
|
||||
assert payload.id == id
|
||||
|
||||
|
||||
class TestDatetimeOverflow:
|
||||
"""Crafted cursors with extreme micros must map to InvalidCursorError,
|
||||
not OverflowError/OSError leaking as 500.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"micros_str",
|
||||
[
|
||||
"999999999999999999999", # 10^21 µs — past datetime.MAX_YEAR by ~14 orders
|
||||
"-999999999999999999999", # symmetric negative — pre-epoch overflow
|
||||
],
|
||||
)
|
||||
def test_out_of_range_micros_rejected(self, micros_str):
|
||||
encoded = encode_cursor("created_at", micros_str, "asset-x")
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor_time(payload)
|
||||
|
||||
|
||||
class TestEncoderDecoderSymmetry:
|
||||
"""The encoder must reject inputs the decoder rejects, or the same server
|
||||
will mint a cursor it then 400s on the next request.
|
||||
"""
|
||||
|
||||
def test_long_name_within_cap_round_trips(self):
|
||||
"""Assets allow names up to 512 chars (`String(512)`); the cursor
|
||||
encoder must round-trip a value at that cap so a freshly minted
|
||||
cursor never fails decode on the next request."""
|
||||
long_name = "n" * MAX_CURSOR_VALUE_LENGTH
|
||||
encoded = encode_cursor("name", long_name, "asset-x")
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
assert payload.value == long_name
|
||||
|
||||
def test_encoder_rejects_empty_id(self):
|
||||
with pytest.raises(InvalidCursorError, match="id must be non-empty"):
|
||||
encode_cursor("created_at", "1", "")
|
||||
|
||||
def test_encoder_rejects_oversized_id(self):
|
||||
with pytest.raises(InvalidCursorError, match="id exceeds maximum length"):
|
||||
encode_cursor("created_at", "1", "a" * (MAX_CURSOR_ID_LENGTH + 1))
|
||||
|
||||
def test_encoder_rejects_oversized_value(self):
|
||||
with pytest.raises(InvalidCursorError, match="value exceeds maximum length"):
|
||||
encode_cursor("name", "v" * (MAX_CURSOR_VALUE_LENGTH + 1), "id-1")
|
||||
|
||||
def test_encoder_rejects_multibyte_value_over_wire_cap(self):
|
||||
"""A value that passes the char-count cap can still inflate past the
|
||||
wire cap once UTF-8-encoded. Asset name made of 512 × multibyte
|
||||
characters (e.g. 'é' = 2 bytes) must be rejected at encode time, not
|
||||
minted into a cursor the next request will 400."""
|
||||
with pytest.raises(InvalidCursorError, match="encoded cursor exceeds maximum length"):
|
||||
encode_cursor("name", "é" * MAX_CURSOR_VALUE_LENGTH, "asset-multibyte")
|
||||
|
||||
def test_encoder_rejects_escape_heavy_value_over_wire_cap(self):
|
||||
"""Same wire-cap concern via escape expansion: each `<` serializes to
|
||||
the six-byte sequence `\\u003c`, so 512 of them blow past the encoded
|
||||
cap even though the raw char count is within the per-field limit."""
|
||||
with pytest.raises(InvalidCursorError, match="encoded cursor exceeds maximum length"):
|
||||
encode_cursor("name", "<" * MAX_CURSOR_VALUE_LENGTH, "asset-escape")
|
||||
|
||||
|
||||
class TestOrderBinding:
|
||||
def test_order_baked_into_payload(self):
|
||||
encoded = encode_cursor("created_at", "1", "id-1", order="asc")
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
assert payload.order == "asc"
|
||||
|
||||
def test_mismatched_order_rejected(self):
|
||||
encoded = encode_cursor("created_at", "1", "id-1", order="desc")
|
||||
with pytest.raises(InvalidCursorError, match="does not match request order"):
|
||||
decode_cursor(encoded, ALLOWED, expected_order="asc")
|
||||
|
||||
def test_matching_order_accepted(self):
|
||||
encoded = encode_cursor("created_at", "1", "id-1", order="desc")
|
||||
payload = decode_cursor(encoded, ALLOWED, expected_order="desc")
|
||||
assert payload.order == "desc"
|
||||
|
||||
def test_invalid_order_token_rejected_on_encode(self):
|
||||
with pytest.raises(ValueError):
|
||||
encode_cursor("created_at", "1", "id-1", order="sideways")
|
||||
|
||||
def test_invalid_order_token_rejected_on_decode(self):
|
||||
# Hand-craft a payload with an illegal `o` value.
|
||||
raw = b'{"s":"name","v":"x","id":"id-1","o":"sideways"}'
|
||||
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="unsupported order"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_cursor_without_order_rejected(self):
|
||||
"""`o` is mandatory. A cursor minted without it is rejected as
|
||||
malformed rather than silently walking the keyset in whatever
|
||||
direction the request happens to ask for."""
|
||||
raw = b'{"s":"name","v":"x","id":"id-1"}'
|
||||
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="missing or non-string o"):
|
||||
decode_cursor(encoded, ALLOWED, expected_order="desc")
|
||||
|
||||
|
||||
class TestHtmlSignificantCharEscaping:
|
||||
"""An asset name containing `<`, `>`, `&`, U+2028, or U+2029 must encode
|
||||
to the same escaped wire bytes as any compatible implementation of the
|
||||
same payload format. Drift here breaks cross-runtime byte-identity for
|
||||
those characters.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value, escaped_substring",
|
||||
[
|
||||
("foo<bar>.png", "\\u003c"), # `<` escaped
|
||||
("foo<bar>.png", "\\u003e"), # `>` escaped
|
||||
("foo&bar.png", "\\u0026"),
|
||||
("foo
bar.png", "\\u2028"), # JS line separator
|
||||
("foo
bar.png", "\\u2029"), # JS paragraph separator
|
||||
],
|
||||
)
|
||||
def test_html_significant_chars_escaped(self, value, escaped_substring):
|
||||
encoded = encode_cursor("name", value, "id-1")
|
||||
decoded_bytes = base64.urlsafe_b64decode(encoded + "=" * (-len(encoded) % 4))
|
||||
assert escaped_substring in decoded_bytes.decode("ascii"), (
|
||||
f"Expected {escaped_substring!r} in serialized payload, got: {decoded_bytes!r}"
|
||||
)
|
||||
|
||||
def test_value_round_trips_through_escape(self):
|
||||
"""Encoding then decoding a value with `<>&` should yield the original
|
||||
string — the escape only affects the wire form, not the decoded value."""
|
||||
original = "foo<&>bar.png"
|
||||
encoded = encode_cursor("name", original, "id-1")
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
assert payload.value == original
|
||||
|
||||
|
||||
class TestByteIdentityFixtures:
|
||||
"""Pin the wire format so it doesn't drift silently.
|
||||
|
||||
These fixtures assert exact byte equality of the encoded JSON payload —
|
||||
a change in key order, escape choice, separator whitespace, or anything
|
||||
else that shifts a byte fails the test loudly rather than diverging
|
||||
silently from any external consumer of the same payload format.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sort_field, value, id, order, expected_payload",
|
||||
[
|
||||
(
|
||||
"created_at",
|
||||
"1716200000000000",
|
||||
"a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7",
|
||||
"desc",
|
||||
'{"s":"created_at","v":"1716200000000000","id":"a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7","o":"desc"}',
|
||||
),
|
||||
(
|
||||
"size",
|
||||
"1024",
|
||||
"asset-123",
|
||||
"asc",
|
||||
'{"s":"size","v":"1024","id":"asset-123","o":"asc"}',
|
||||
),
|
||||
(
|
||||
"name",
|
||||
"my-asset.png",
|
||||
"asset-abc",
|
||||
"desc",
|
||||
'{"s":"name","v":"my-asset.png","id":"asset-abc","o":"desc"}',
|
||||
),
|
||||
(
|
||||
"name",
|
||||
"foo<bar>&baz.png",
|
||||
"asset-html",
|
||||
"desc",
|
||||
# `<`, `>`, `&` escape to <, >, & in the value.
|
||||
'{"s":"name","v":"foo\\u003cbar\\u003e\\u0026baz.png","id":"asset-html","o":"desc"}',
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_encoded_payload_shape_pinned(self, sort_field, value, id, order, expected_payload):
|
||||
encoded = encode_cursor(sort_field, value, id, order=order)
|
||||
decoded_bytes = base64.urlsafe_b64decode(encoded + "=" * (-len(encoded) % 4))
|
||||
assert decoded_bytes.decode("utf-8") == expected_payload, (
|
||||
f"wire format drifted for sort={sort_field!r}, value={value!r}:\n"
|
||||
f" expected: {expected_payload!r}\n"
|
||||
f" actual: {decoded_bytes.decode('utf-8')!r}"
|
||||
)
|
||||
349
tests-unit/assets_test/test_list_cursor.py
Normal file
349
tests-unit/assets_test/test_list_cursor.py
Normal file
@ -0,0 +1,349 @@
|
||||
"""Integration tests for cursor-based pagination on GET /api/assets.
|
||||
|
||||
These tests exercise the handler/service/query path end-to-end;
|
||||
cursor-encoding-level tests live in
|
||||
tests-unit/assets_test/services/test_cursor.py.
|
||||
"""
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
def _seed(asset_factory, make_asset_bytes, count: int, tag: str) -> list[str]:
|
||||
names = [f"cursor_{i:02d}.safetensors" for i in range(count)]
|
||||
for n in names:
|
||||
asset_factory(
|
||||
n,
|
||||
["models", "checkpoints", "unit-tests", tag],
|
||||
{},
|
||||
make_asset_bytes(n, size=2048),
|
||||
)
|
||||
return sorted(names)
|
||||
|
||||
|
||||
def test_cursor_pages_all_items_in_order(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
names = _seed(asset_factory, make_asset_bytes, count=5, tag="cursor-walk")
|
||||
|
||||
params = {
|
||||
"include_tags": "unit-tests,cursor-walk",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "2",
|
||||
}
|
||||
|
||||
seen: list[str] = []
|
||||
after: str | None = None
|
||||
pages = 0
|
||||
while True:
|
||||
page_params = dict(params)
|
||||
if after is not None:
|
||||
page_params["after"] = after
|
||||
r = http.get(api_base + "/api/assets", params=page_params, timeout=120)
|
||||
assert r.status_code == 200, r.text
|
||||
body = r.json()
|
||||
seen.extend(a["name"] for a in body["assets"])
|
||||
pages += 1
|
||||
after = body.get("next_cursor")
|
||||
if after is None:
|
||||
break
|
||||
assert body["has_more"] is True
|
||||
assert pages < 10, "guard against runaway cursor loop"
|
||||
|
||||
assert seen == names, f"expected {names}, got {seen}"
|
||||
# Last page should have has_more False
|
||||
assert body["has_more"] is False
|
||||
assert "next_cursor" not in body
|
||||
|
||||
|
||||
def test_cursor_invalid_returns_400(http: requests.Session, api_base: str):
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"after": "not-a-real-cursor", "sort": "created_at"},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 400, r.text
|
||||
body = r.json()
|
||||
assert body["error"]["code"] == "INVALID_CURSOR"
|
||||
|
||||
|
||||
def test_cursor_sort_mismatch_returns_400(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
_seed(asset_factory, make_asset_bytes, count=2, tag="cursor-mismatch")
|
||||
|
||||
# Take a real cursor minted for sort=name.
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-mismatch",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "1",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
cursor = r.json()["next_cursor"]
|
||||
assert cursor is not None
|
||||
|
||||
# Replay against sort=created_at — should fail with INVALID_CURSOR.
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"after": cursor, "sort": "created_at"},
|
||||
timeout=120,
|
||||
)
|
||||
assert r2.status_code == 400, r2.text
|
||||
assert r2.json()["error"]["code"] == "INVALID_CURSOR"
|
||||
|
||||
|
||||
def test_cursor_wins_over_offset(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
names = _seed(asset_factory, make_asset_bytes, count=4, tag="cursor-vs-offset")
|
||||
|
||||
# Take a cursor that points past the first item.
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-vs-offset",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "1",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
cursor = r.json()["next_cursor"]
|
||||
assert cursor is not None
|
||||
|
||||
# Pass both 'after' and a large offset. Cursor must win; offset is ignored.
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-vs-offset",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "1",
|
||||
"after": cursor,
|
||||
"offset": "999",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r2.status_code == 200
|
||||
body = r2.json()
|
||||
# Should land on the second name in sorted order — not skip ahead by 999.
|
||||
assert [a["name"] for a in body["assets"]] == [names[1]]
|
||||
|
||||
|
||||
def test_next_cursor_absent_when_no_more_results(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
_seed(asset_factory, make_asset_bytes, count=2, tag="cursor-exhaust")
|
||||
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-exhaust",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "50",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
body = r.json()
|
||||
assert body["has_more"] is False
|
||||
assert "next_cursor" not in body
|
||||
|
||||
|
||||
def test_cursor_pagination_first_page_mints_cursor(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
"""First-page request (no `after`) must still return `next_cursor` when
|
||||
more rows exist, or pagination is unreachable from a cold start.
|
||||
"""
|
||||
_seed(asset_factory, make_asset_bytes, count=3, tag="cursor-first-page")
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,cursor-first-page", "sort": "name", "order": "asc", "limit": "2"},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
body = r.json()
|
||||
assert body["has_more"] is True
|
||||
assert body.get("next_cursor"), "first page must mint a cursor when more rows exist"
|
||||
|
||||
|
||||
def test_cursor_no_spurious_cursor_when_page_size_equals_remainder(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
"""When `total` is an exact multiple of `limit`, the final page must
|
||||
NOT carry a next_cursor — there is nothing past it.
|
||||
"""
|
||||
_seed(asset_factory, make_asset_bytes, count=4, tag="cursor-exact-multiple")
|
||||
# Page 1
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,cursor-exact-multiple", "sort": "name", "order": "asc", "limit": "2"},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
cursor = r.json()["next_cursor"]
|
||||
assert cursor is not None
|
||||
# Page 2 — should exhaust the set with no cursor for a phantom page 3
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,cursor-exact-multiple", "sort": "name", "order": "asc", "limit": "2", "after": cursor},
|
||||
timeout=120,
|
||||
)
|
||||
assert r2.status_code == 200, r2.text
|
||||
body = r2.json()
|
||||
assert len(body["assets"]) == 2
|
||||
assert body["has_more"] is False
|
||||
assert "next_cursor" not in body
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sort_field", ["created_at", "updated_at", "size"])
|
||||
def test_cursor_walks_for_non_name_sorts(sort_field, http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
"""Cursor pagination must work for every sort field the contract claims.
|
||||
|
||||
Without this, the `created_at` / `updated_at` (time-encoded micros) and
|
||||
`size` (int-encoded) cursor paths go entirely unexercised end-to-end.
|
||||
"""
|
||||
# Sizes increase strictly by index, so `size desc` has a deterministic
|
||||
# expected order. Time-based sorts (created_at / updated_at) can tie when
|
||||
# rows are inserted faster than the DB's timestamp resolution; for those
|
||||
# we check coverage and no-duplicates and let the keyset tiebreaker do
|
||||
# the rest, instead of sleeping between inserts and asserting an order
|
||||
# that depends on clock granularity.
|
||||
names = []
|
||||
for i in range(4):
|
||||
n = f"cursor_{sort_field}_{i:02d}.safetensors"
|
||||
asset_factory(n, ["models", "checkpoints", "unit-tests", f"cursor-{sort_field}"], {}, make_asset_bytes(n, size=2048 + i))
|
||||
names.append(n)
|
||||
|
||||
params = {
|
||||
"include_tags": f"unit-tests,cursor-{sort_field}",
|
||||
"sort": sort_field,
|
||||
"order": "desc",
|
||||
"limit": "2",
|
||||
}
|
||||
seen: list[str] = []
|
||||
after: str | None = None
|
||||
pages = 0
|
||||
while True:
|
||||
page_params = dict(params)
|
||||
if after is not None:
|
||||
page_params["after"] = after
|
||||
r = http.get(api_base + "/api/assets", params=page_params, timeout=120)
|
||||
assert r.status_code == 200, r.text
|
||||
body = r.json()
|
||||
seen.extend(a["name"] for a in body["assets"])
|
||||
after = body.get("next_cursor")
|
||||
pages += 1
|
||||
if after is None:
|
||||
break
|
||||
assert pages < 10, "guard against runaway cursor loop"
|
||||
|
||||
# No duplicates: a faulty keyset boundary that returns the same row across
|
||||
# two pages must fail this check.
|
||||
assert len(seen) == len(set(seen)), (
|
||||
f"cursor walk repeated rows for sort={sort_field}: {seen}"
|
||||
)
|
||||
# Full coverage: every seeded asset reached exactly once.
|
||||
assert set(seen) == set(names), (
|
||||
f"missing items for sort={sort_field}: expected {set(names)}, got {set(seen)}"
|
||||
)
|
||||
# Strict order check for the only field with a clock-independent ordering.
|
||||
if sort_field == "size":
|
||||
assert seen == list(reversed(names)), (
|
||||
f"size cursor walked out of order: got {seen}, expected {list(reversed(names))}"
|
||||
)
|
||||
|
||||
|
||||
def test_cursor_order_mismatch_returns_400(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
"""A cursor minted under desc order replayed against asc must 400, not
|
||||
silently walk the wrong direction."""
|
||||
_seed(asset_factory, make_asset_bytes, count=3, tag="cursor-order-flip")
|
||||
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-order-flip",
|
||||
"sort": "name",
|
||||
"order": "desc",
|
||||
"limit": "1",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
cursor = r.json()["next_cursor"]
|
||||
assert cursor is not None
|
||||
|
||||
# Replay with order flipped to asc — server must reject the cursor.
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-order-flip",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "1",
|
||||
"after": cursor,
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r2.status_code == 400, r2.text
|
||||
assert r2.json()["error"]["code"] == "INVALID_CURSOR"
|
||||
|
||||
|
||||
def test_cursor_invalid_cursor_at_microsecond_boundary(http: requests.Session, api_base: str):
|
||||
"""A cursor carrying an out-of-range microsecond timestamp must map to
|
||||
400 INVALID_CURSOR, not 500."""
|
||||
import base64
|
||||
import json
|
||||
# 10^18 microseconds ≈ year 33658, well past datetime.MAX_YEAR.
|
||||
# `o` and `order=` must be set; otherwise decode fails earlier on the
|
||||
# missing-order branch and the µs-overflow path is never exercised.
|
||||
payload = {"s": "created_at", "o": "desc", "v": "999999999999999999999", "id": "asset-x"}
|
||||
raw = json.dumps(payload, separators=(",", ":")).encode("utf-8")
|
||||
cursor = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"after": cursor, "sort": "created_at", "order": "desc"},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 400, r.text
|
||||
assert r.json()["error"]["code"] == "INVALID_CURSOR"
|
||||
|
||||
|
||||
def test_cursor_pagination_stable_after_delete(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
names = _seed(asset_factory, make_asset_bytes, count=4, tag="cursor-delete")
|
||||
|
||||
# Page 1.
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-delete",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "2",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
page1_names = [a["name"] for a in body["assets"]]
|
||||
cursor = body["next_cursor"]
|
||||
assert cursor is not None
|
||||
assert page1_names == names[:2]
|
||||
|
||||
# Delete an item from page 1 (already returned) — cursor should still
|
||||
# locate the next page from where it was minted, not re-index.
|
||||
target_id = body["assets"][0]["id"]
|
||||
d = http.delete(api_base + f"/api/assets/{target_id}", timeout=120)
|
||||
assert d.status_code in (200, 204), d.text
|
||||
|
||||
# Page 2 via cursor.
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-delete",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "2",
|
||||
"after": cursor,
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r2.status_code == 200, r2.text
|
||||
body2 = r2.json()
|
||||
assert [a["name"] for a in body2["assets"]] == names[2:]
|
||||
Reference in New Issue
Block a user