mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-23 09:38:08 +08:00
Compare commits
26 Commits
matt/asset
...
matt/opena
| Author | SHA1 | Date | |
|---|---|---|---|
| 343db1d315 | |||
| 187442cca4 | |||
| c3c881f37b | |||
| 7984a6a38e | |||
| e75b739c1d | |||
| 112fcd5f3b | |||
| 1579bbb52d | |||
| 93888ae8e3 | |||
| 38ebc19037 | |||
| 9650570378 | |||
| f48c32871b | |||
| 8edff549e3 | |||
| 8fecef0686 | |||
| 5d681a5420 | |||
| 32e58393b8 | |||
| b293f8cefd | |||
| 2ca1480f91 | |||
| 6ecf5eca7a | |||
| 03e511862e | |||
| aab41a9ddb | |||
| 4259a0c7c3 | |||
| af3d9b60af | |||
| 7b7c5fed7c | |||
| 1668aaf037 | |||
| ea174d3f12 | |||
| 9f9b32ed97 |
519
.github/workflows/backport_release.yaml
vendored
Normal file
519
.github/workflows/backport_release.yaml
vendored
Normal file
@ -0,0 +1,519 @@
|
||||
name: Backport Release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
commit:
|
||||
description: 'Full 40-char SHA of the tip commit of the backport source branch (the PR head commit that passed tests). The branch is resolved from this SHA and must be unique.'
|
||||
required: true
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: read
|
||||
checks: read
|
||||
|
||||
jobs:
|
||||
backport-release:
|
||||
name: Create backport release
|
||||
runs-on: ubuntu-latest
|
||||
environment: backport release
|
||||
|
||||
steps:
|
||||
- name: Generate GitHub App token
|
||||
id: app-token
|
||||
uses: actions/create-github-app-token@bcd2ba49218906704ab6c1aa796996da409d3eb1
|
||||
with:
|
||||
app-id: ${{ secrets.FEN_RELEASE_APP_ID }}
|
||||
private-key: ${{ secrets.FEN_RELEASE_PRIVATE_KEY }}
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
|
||||
with:
|
||||
token: ${{ steps.app-token.outputs.token }}
|
||||
fetch-depth: 0
|
||||
fetch-tags: true
|
||||
|
||||
- name: Configure git
|
||||
run: |
|
||||
git config user.name "fen-release[bot]"
|
||||
git config user.email "fen-release[bot]@users.noreply.github.com"
|
||||
|
||||
- name: Resolve source branch from commit SHA
|
||||
id: resolve
|
||||
env:
|
||||
SOURCE_COMMIT: ${{ inputs.commit }}
|
||||
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Require a full 40-char lowercase-hex SHA. Short SHAs are ambiguous
|
||||
# and we will be comparing this value against API responses (PR head
|
||||
# SHA, ref tips) that always return the full form.
|
||||
if [[ ! "${SOURCE_COMMIT}" =~ ^[0-9a-f]{40}$ ]]; then
|
||||
echo "::error::Input commit '${SOURCE_COMMIT}' is not a full 40-char lowercase hex SHA."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Fetch all remote branches so we can search for which one(s) point
|
||||
# at this SHA. `actions/checkout` with fetch-depth: 0 fetches full
|
||||
# history of the checked-out ref but does not necessarily populate
|
||||
# every refs/remotes/origin/*, so do it explicitly.
|
||||
git fetch --prune origin '+refs/heads/*:refs/remotes/origin/*'
|
||||
|
||||
# Verify the commit actually exists in this repo's object DB.
|
||||
if ! git cat-file -e "${SOURCE_COMMIT}^{commit}" 2>/dev/null; then
|
||||
echo "::error::Commit ${SOURCE_COMMIT} was not found in the repository."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Find every remote branch whose tip == SOURCE_COMMIT. Exactly one
|
||||
# branch must point at it. If zero, the commit isn't anyone's tip
|
||||
# (likely stale, force-pushed past, or never the PR head). If more
|
||||
# than one, the (branch -> SHA) mapping is ambiguous and we refuse
|
||||
# to guess — the operator must give us a unique branch to release.
|
||||
mapfile -t matching_branches < <(
|
||||
git for-each-ref \
|
||||
--format='%(refname:strip=3)' \
|
||||
--points-at="${SOURCE_COMMIT}" \
|
||||
refs/remotes/origin/ \
|
||||
| grep -vx 'HEAD' || true
|
||||
)
|
||||
|
||||
if [[ "${#matching_branches[@]}" -eq 0 ]]; then
|
||||
echo "::error::No branch on origin has ${SOURCE_COMMIT} as its tip."
|
||||
echo "::error::Either the branch was updated after you copied this SHA, or this commit was never the head of a branch."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ "${#matching_branches[@]}" -gt 1 ]]; then
|
||||
echo "::error::More than one branch on origin has ${SOURCE_COMMIT} as its tip; cannot pick one:"
|
||||
for b in "${matching_branches[@]}"; do
|
||||
echo "::error:: - ${b}"
|
||||
done
|
||||
echo "::error::Refusing to proceed with an ambiguous source branch."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
source_branch="${matching_branches[0]}"
|
||||
|
||||
if [[ "${source_branch}" == "${DEFAULT_BRANCH}" ]]; then
|
||||
echo "::error::Source branch must not be the default branch ('${DEFAULT_BRANCH}')."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Resolved commit ${SOURCE_COMMIT} to branch '${source_branch}'."
|
||||
echo "source_branch=${source_branch}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Determine latest stable release
|
||||
id: latest
|
||||
env:
|
||||
GH_TOKEN: ${{ steps.app-token.outputs.token }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# List all tags matching vMAJOR.MINOR.PATCH and pick the highest by numeric
|
||||
# comparison of each component. We DO NOT use `sort -V` because it treats
|
||||
# v0.19.99 as higher than v0.20.1.
|
||||
latest_tag="$(
|
||||
git tag --list 'v[0-9]*.[0-9]*.[0-9]*' \
|
||||
| grep -E '^v[0-9]+\.[0-9]+\.[0-9]+$' \
|
||||
| awk -F'[v.]' '{ printf "%010d %010d %010d %s\n", $2, $3, $4, $0 }' \
|
||||
| sort -k1,1n -k2,2n -k3,3n \
|
||||
| tail -n1 \
|
||||
| awk '{print $4}'
|
||||
)"
|
||||
|
||||
if [[ -z "${latest_tag}" ]]; then
|
||||
echo "::error::No stable release tags (vMAJOR.MINOR.PATCH) were found."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Parse components
|
||||
ver="${latest_tag#v}"
|
||||
major="${ver%%.*}"
|
||||
rest="${ver#*.}"
|
||||
minor="${rest%%.*}"
|
||||
patch="${rest#*.}"
|
||||
|
||||
new_patch=$((patch + 1))
|
||||
new_version="v${major}.${minor}.${new_patch}"
|
||||
release_branch="release/v${major}.${minor}"
|
||||
|
||||
latest_sha="$(git rev-list -n 1 "refs/tags/${latest_tag}")"
|
||||
|
||||
echo "latest_tag=${latest_tag}" >> "$GITHUB_OUTPUT"
|
||||
echo "latest_sha=${latest_sha}" >> "$GITHUB_OUTPUT"
|
||||
echo "major=${major}" >> "$GITHUB_OUTPUT"
|
||||
echo "minor=${minor}" >> "$GITHUB_OUTPUT"
|
||||
echo "patch=${patch}" >> "$GITHUB_OUTPUT"
|
||||
echo "new_version=${new_version}" >> "$GITHUB_OUTPUT"
|
||||
echo "new_version_no_v=${major}.${minor}.${new_patch}" >> "$GITHUB_OUTPUT"
|
||||
echo "release_branch=${release_branch}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
echo "Latest stable release: ${latest_tag} (${latest_sha})"
|
||||
echo "New version will be: ${new_version}"
|
||||
echo "Release branch: ${release_branch}"
|
||||
|
||||
- name: Validate source branch is cut directly from the latest stable release
|
||||
env:
|
||||
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
|
||||
SOURCE_COMMIT: ${{ inputs.commit }}
|
||||
LATEST_TAG_SHA: ${{ steps.latest.outputs.latest_sha }}
|
||||
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Use the user-provided SHA directly rather than re-resolving the branch
|
||||
# tip — the resolve step already proved the branch tip equals SOURCE_COMMIT,
|
||||
# and pinning to the SHA here makes the rest of the job TOCTOU-safe against
|
||||
# someone pushing to the branch mid-run.
|
||||
source_sha="${SOURCE_COMMIT}"
|
||||
|
||||
# Walking first-parent from the source tip must reach LATEST_TAG_SHA.
|
||||
# We capture rev-list into a variable and grep against a here-string
|
||||
# rather than piping `rev-list | grep -q`: under `set -o pipefail`,
|
||||
# `grep -q` would exit on first match and SIGPIPE the still-streaming
|
||||
# `rev-list`, propagating exit 141 as a spurious "not found".
|
||||
first_parent_chain="$(git rev-list --first-parent "${source_sha}")"
|
||||
if ! grep -Fxq "${LATEST_TAG_SHA}" <<< "${first_parent_chain}"; then
|
||||
echo "::error::Source branch '${SOURCE_BRANCH}' is not cut from '${LATEST_TAG}'."
|
||||
echo "::error::Its first-parent history does not include ${LATEST_TAG_SHA}."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Additionally, every commit added on top of the tag (the set we are
|
||||
# about to publish) must itself be a descendant of the tag along
|
||||
# first-parent — i.e. no sibling commits from master sneak in via a
|
||||
# non-first-parent path. Enforce by requiring that the symmetric
|
||||
# difference is empty in one direction: commits in source that are
|
||||
# NOT first-parent-reachable from source starting at the tag.
|
||||
# We do this by intersecting:
|
||||
# A = commits reachable from source but not from tag (full DAG)
|
||||
# B = commits on the first-parent chain from source down to tag
|
||||
# and requiring A == B.
|
||||
all_added="$(git rev-list "${LATEST_TAG_SHA}..${source_sha}" | sort)"
|
||||
first_parent_added="$(
|
||||
git rev-list --first-parent "${LATEST_TAG_SHA}..${source_sha}" | sort
|
||||
)"
|
||||
|
||||
if [[ "${all_added}" != "${first_parent_added}" ]]; then
|
||||
echo "::error::Source branch '${SOURCE_BRANCH}' contains commits not on its first-parent chain from '${LATEST_TAG}'."
|
||||
echo "::error::This usually means the branch was cut from master (not from the tag) or contains a merge from master."
|
||||
echo "Commits reachable but not on first-parent chain:"
|
||||
comm -23 <(printf '%s\n' "${all_added}") <(printf '%s\n' "${first_parent_added}") \
|
||||
| while read -r sha; do
|
||||
echo " $(git log -1 --format='%h %s' "${sha}")"
|
||||
done
|
||||
exit 1
|
||||
fi
|
||||
|
||||
added_count="$(printf '%s\n' "${all_added}" | grep -c . || true)"
|
||||
echo "Source branch is cut directly from ${LATEST_TAG} with ${added_count} commit(s) on top."
|
||||
|
||||
- name: Validate PR exists, is open, named correctly, has latest commit, and checks pass
|
||||
env:
|
||||
GH_TOKEN: ${{ steps.app-token.outputs.token }}
|
||||
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
|
||||
SOURCE_COMMIT: ${{ inputs.commit }}
|
||||
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
|
||||
REPO: ${{ github.repository }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
expected_title="ComfyUI backport release ${NEW_VERSION}"
|
||||
|
||||
# Find open PRs from this branch into master. The --state open filter
|
||||
# is load-bearing: a closed/merged PR with passing checks must not be
|
||||
# accepted as authorization for a new release.
|
||||
pr_json="$(
|
||||
gh pr list \
|
||||
--repo "${REPO}" \
|
||||
--state open \
|
||||
--head "${SOURCE_BRANCH}" \
|
||||
--base master \
|
||||
--json number,title,headRefOid,state \
|
||||
--limit 10
|
||||
)"
|
||||
|
||||
pr_count="$(echo "${pr_json}" | jq 'length')"
|
||||
if [[ "${pr_count}" -eq 0 ]]; then
|
||||
echo "::error::No open PR found from '${SOURCE_BRANCH}' into 'master'. The PR must exist and be open."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Pick the PR matching the expected title
|
||||
pr_number="$(echo "${pr_json}" | jq -r --arg t "${expected_title}" '
|
||||
map(select(.title == $t)) | .[0].number // empty
|
||||
')"
|
||||
pr_head_sha="$(echo "${pr_json}" | jq -r --arg t "${expected_title}" '
|
||||
map(select(.title == $t)) | .[0].headRefOid // empty
|
||||
')"
|
||||
|
||||
if [[ -z "${pr_number}" ]]; then
|
||||
echo "::error::No open PR from '${SOURCE_BRANCH}' into 'master' is titled '${expected_title}'."
|
||||
echo "Found PRs:"
|
||||
echo "${pr_json}" | jq -r '.[] | " #\(.number): \(.title)"'
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# The PR's current head commit must equal the SHA the operator gave us.
|
||||
# This is what closes the door on releasing stale code: if anyone has
|
||||
# pushed to the branch since the operator validated tests passed, the
|
||||
# PR head will have advanced past SOURCE_COMMIT and we abort. (The
|
||||
# resolve step already proved the branch tip == SOURCE_COMMIT; this
|
||||
# ties that same SHA to the PR that authorizes the release.)
|
||||
if [[ "${pr_head_sha}" != "${SOURCE_COMMIT}" ]]; then
|
||||
echo "::error::PR #${pr_number} head commit is ${pr_head_sha}, but the operator-provided commit is ${SOURCE_COMMIT}."
|
||||
echo "::error::The PR has new commits since this release was authorized. Re-run with the new head SHA after verifying its checks."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Found open PR #${pr_number} titled '${expected_title}' at head ${pr_head_sha} (matches operator-provided commit)."
|
||||
|
||||
# Verify all check runs on the head commit have completed successfully.
|
||||
# A check is considered passing if conclusion is success, neutral, or skipped.
|
||||
checks_json="$(
|
||||
gh api \
|
||||
--paginate \
|
||||
"repos/${REPO}/commits/${pr_head_sha}/check-runs" \
|
||||
--jq '.check_runs[] | {name: .name, status: .status, conclusion: .conclusion}'
|
||||
)"
|
||||
|
||||
if [[ -z "${checks_json}" ]]; then
|
||||
echo "::error::No check runs found on PR head commit ${pr_head_sha}."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Check runs on ${pr_head_sha}:"
|
||||
echo "${checks_json}" | jq -s '.'
|
||||
|
||||
failing="$(echo "${checks_json}" | jq -s '
|
||||
map(select(
|
||||
.status != "completed"
|
||||
or (.conclusion as $c
|
||||
| ["success","neutral","skipped"]
|
||||
| index($c) | not)
|
||||
))
|
||||
')"
|
||||
|
||||
failing_count="$(echo "${failing}" | jq 'length')"
|
||||
if [[ "${failing_count}" -gt 0 ]]; then
|
||||
echo "::error::One or more checks have not passed on PR head commit ${pr_head_sha}:"
|
||||
echo "${failing}" | jq -r '.[] | " - \(.name): status=\(.status) conclusion=\(.conclusion)"'
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "All checks have passed on ${pr_head_sha}."
|
||||
|
||||
- name: Prepare release branch
|
||||
id: prepare
|
||||
env:
|
||||
GH_TOKEN: ${{ steps.app-token.outputs.token }}
|
||||
REPO: ${{ github.repository }}
|
||||
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
|
||||
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
|
||||
LATEST_TAG_SHA: ${{ steps.latest.outputs.latest_sha }}
|
||||
PATCH: ${{ steps.latest.outputs.patch }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Try to fetch the release branch. If patch == 0, it shouldn't exist yet
|
||||
# and we'll create it from the latest stable tag. If patch > 0, it must
|
||||
# already exist and its tip must equal the latest stable tag commit (i.e.
|
||||
# the previous patch release).
|
||||
if git ls-remote --exit-code --heads origin "${RELEASE_BRANCH}" >/dev/null 2>&1; then
|
||||
echo "Release branch '${RELEASE_BRANCH}' already exists on origin."
|
||||
git fetch origin "refs/heads/${RELEASE_BRANCH}:refs/remotes/origin/${RELEASE_BRANCH}"
|
||||
git checkout -B "${RELEASE_BRANCH}" "refs/remotes/origin/${RELEASE_BRANCH}"
|
||||
|
||||
current_tip="$(git rev-parse HEAD)"
|
||||
if [[ "${current_tip}" != "${LATEST_TAG_SHA}" ]]; then
|
||||
echo "::error::Release branch '${RELEASE_BRANCH}' tip (${current_tip}) is not at the latest stable release '${LATEST_TAG}' (${LATEST_TAG_SHA})."
|
||||
echo "::error::Refusing to release on top of a divergent branch."
|
||||
exit 1
|
||||
fi
|
||||
echo "branch_existed=true" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
if [[ "${PATCH}" != "0" ]]; then
|
||||
echo "::error::Release branch '${RELEASE_BRANCH}' does not exist on origin, but the latest stable release '${LATEST_TAG}' has patch=${PATCH} (>0). This is inconsistent."
|
||||
exit 1
|
||||
fi
|
||||
echo "Release branch '${RELEASE_BRANCH}' does not exist. Creating from ${LATEST_TAG}."
|
||||
git checkout -B "${RELEASE_BRANCH}" "refs/tags/${LATEST_TAG}"
|
||||
echo "branch_existed=false" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Fast-forward merge source branch into release branch
|
||||
env:
|
||||
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
|
||||
SOURCE_COMMIT: ${{ inputs.commit }}
|
||||
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# --ff-only guarantees no merge commit is created. If a fast-forward is
|
||||
# not possible (i.e. the release branch has commits the source branch
|
||||
# doesn't), the merge will fail and we abort. Because we already validated
|
||||
# that the source branch is rooted on the latest stable tag, and the
|
||||
# release branch tip equals that same tag, this fast-forward should
|
||||
# always succeed for a well-formed backport branch.
|
||||
#
|
||||
# We merge the operator-provided SHA, not the branch ref, so a push to
|
||||
# the branch in the window between resolve and now cannot smuggle new
|
||||
# commits into the release.
|
||||
if ! git merge --ff-only "${SOURCE_COMMIT}"; then
|
||||
echo "::error::Cannot fast-forward '${RELEASE_BRANCH}' to ${SOURCE_COMMIT} (tip of '${SOURCE_BRANCH}'). A merge commit would be required. Aborting."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Fast-forwarded '${RELEASE_BRANCH}' to ${SOURCE_COMMIT} (tip of '${SOURCE_BRANCH}')."
|
||||
|
||||
- name: Bump version files
|
||||
env:
|
||||
NEW_VERSION_NO_V: ${{ steps.latest.outputs.new_version_no_v }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if [[ ! -f comfyui_version.py ]]; then
|
||||
echo "::error::comfyui_version.py not found in repo root."
|
||||
exit 1
|
||||
fi
|
||||
if [[ ! -f pyproject.toml ]]; then
|
||||
echo "::error::pyproject.toml not found in repo root."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Replace the version string in comfyui_version.py.
|
||||
# Expected format: __version__ = "X.Y.Z"
|
||||
python3 - "$NEW_VERSION_NO_V" <<'PY'
|
||||
import re, sys, pathlib
|
||||
new = sys.argv[1]
|
||||
|
||||
p = pathlib.Path("comfyui_version.py")
|
||||
src = p.read_text()
|
||||
new_src, n = re.subn(
|
||||
r'(__version__\s*=\s*[\'"])[^\'"]+([\'"])',
|
||||
lambda m: f'{m.group(1)}{new}{m.group(2)}',
|
||||
src,
|
||||
count=1,
|
||||
)
|
||||
if n != 1:
|
||||
sys.exit("Could not find __version__ assignment in comfyui_version.py")
|
||||
p.write_text(new_src)
|
||||
|
||||
p = pathlib.Path("pyproject.toml")
|
||||
src = p.read_text()
|
||||
# Replace the first `version = "..."` inside [project] or [tool.poetry].
|
||||
new_src, n = re.subn(
|
||||
r'(?m)^(version\s*=\s*")[^"]+(")',
|
||||
lambda m: f'{m.group(1)}{new}{m.group(2)}',
|
||||
src,
|
||||
count=1,
|
||||
)
|
||||
if n != 1:
|
||||
sys.exit("Could not find version assignment in pyproject.toml")
|
||||
p.write_text(new_src)
|
||||
PY
|
||||
|
||||
echo "Updated version to ${NEW_VERSION_NO_V} in comfyui_version.py and pyproject.toml."
|
||||
git --no-pager diff -- comfyui_version.py pyproject.toml
|
||||
|
||||
- name: Commit version bump and tag release
|
||||
env:
|
||||
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
git add comfyui_version.py pyproject.toml
|
||||
git commit -m "ComfyUI ${NEW_VERSION}"
|
||||
|
||||
if git rev-parse -q --verify "refs/tags/${NEW_VERSION}" >/dev/null; then
|
||||
echo "::error::Tag ${NEW_VERSION} already exists locally."
|
||||
exit 1
|
||||
fi
|
||||
git tag "${NEW_VERSION}"
|
||||
|
||||
- name: Verify tag does not already exist on origin
|
||||
env:
|
||||
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
if git ls-remote --exit-code --tags origin "refs/tags/${NEW_VERSION}" >/dev/null 2>&1; then
|
||||
echo "::error::Tag ${NEW_VERSION} already exists on origin. Aborting."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Push release branch and tag
|
||||
env:
|
||||
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
|
||||
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Push the branch first, then the tag. Atomic-ish: if the branch push
|
||||
# fails we never publish the tag.
|
||||
git push origin "refs/heads/${RELEASE_BRANCH}:refs/heads/${RELEASE_BRANCH}"
|
||||
git push origin "refs/tags/${NEW_VERSION}"
|
||||
|
||||
echo "Released ${NEW_VERSION} on ${RELEASE_BRANCH}."
|
||||
|
||||
- name: Delete remote source branch
|
||||
env:
|
||||
GH_TOKEN: ${{ steps.app-token.outputs.token }}
|
||||
REPO: ${{ github.repository }}
|
||||
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
|
||||
SOURCE_COMMIT: ${{ inputs.commit }}
|
||||
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
|
||||
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Belt-and-braces: the resolve step already refuses the default branch,
|
||||
# but never delete the default or the release branch under any
|
||||
# circumstances.
|
||||
if [[ "${SOURCE_BRANCH}" == "${DEFAULT_BRANCH}" || "${SOURCE_BRANCH}" == "${RELEASE_BRANCH}" ]]; then
|
||||
echo "::error::Refusing to delete '${SOURCE_BRANCH}' (matches default or release branch)."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Delete the source branch on origin, but only if its tip is still the
|
||||
# SHA we released from. If someone pushed new commits to it after we
|
||||
# resolved it, leave it alone — those commits would be silently lost.
|
||||
current_tip="$(git ls-remote origin "refs/heads/${SOURCE_BRANCH}" | awk '{print $1}')"
|
||||
if [[ -z "${current_tip}" ]]; then
|
||||
echo "Source branch '${SOURCE_BRANCH}' no longer exists on origin; nothing to delete."
|
||||
exit 0
|
||||
fi
|
||||
if [[ "${current_tip}" != "${SOURCE_COMMIT}" ]]; then
|
||||
echo "::warning::Source branch '${SOURCE_BRANCH}' tip (${current_tip}) no longer matches released commit (${SOURCE_COMMIT}). Leaving it in place."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
git push origin --delete "refs/heads/${SOURCE_BRANCH}"
|
||||
echo "Deleted remote branch '${SOURCE_BRANCH}'."
|
||||
|
||||
- name: Summary
|
||||
if: always()
|
||||
env:
|
||||
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
|
||||
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
|
||||
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
|
||||
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
|
||||
SOURCE_COMMIT: ${{ inputs.commit }}
|
||||
run: |
|
||||
# SOURCE_BRANCH is empty if the resolve step never produced an output
|
||||
# (e.g. the workflow failed in or before that step). Show a placeholder
|
||||
# in that case so the summary table still renders cleanly.
|
||||
source_branch_display="${SOURCE_BRANCH:-(unresolved)}"
|
||||
{
|
||||
echo "## Backport release"
|
||||
echo ""
|
||||
echo "| Field | Value |"
|
||||
echo "|---|---|"
|
||||
echo "| Source commit | \`${SOURCE_COMMIT}\` |"
|
||||
echo "| Source branch | \`${source_branch_display}\` |"
|
||||
echo "| Previous stable | \`${LATEST_TAG}\` |"
|
||||
echo "| New version | \`${NEW_VERSION}\` |"
|
||||
echo "| Release branch | \`${RELEASE_BRANCH}\` |"
|
||||
} >> "$GITHUB_STEP_SUMMARY"
|
||||
@ -20,7 +20,7 @@
|
||||
[website-url]: https://www.comfy.org/
|
||||
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
|
||||
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
|
||||
[discord-url]: https://www.comfy.org/discord
|
||||
[discord-url]: https://discord.com/invite/comfyorg
|
||||
[twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI
|
||||
[twitter-url]: https://x.com/ComfyUI
|
||||
|
||||
|
||||
@ -401,16 +401,12 @@ async def upload_asset(request: web.Request) -> web.Response:
|
||||
)
|
||||
|
||||
if spec.tags and spec.tags[0] == "models":
|
||||
# tag[1] may be the standalone category ("checkpoints") or the
|
||||
# slash-joined shape ("checkpoints/flux/...") that
|
||||
# `get_name_and_tags_from_asset_path` and cloud both emit. Match
|
||||
# `resolve_destination_from_tags` by extracting the first segment.
|
||||
category = spec.tags[1].split("/", 1)[0] if len(spec.tags) >= 2 else ""
|
||||
if (
|
||||
len(spec.tags) < 2
|
||||
or category not in folder_paths.folder_names_and_paths
|
||||
or spec.tags[1] not in folder_paths.folder_names_and_paths
|
||||
):
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
category = spec.tags[1] if len(spec.tags) >= 2 else ""
|
||||
return _build_error_response(
|
||||
400, "INVALID_BODY", f"unknown models category '{category}'"
|
||||
)
|
||||
|
||||
@ -327,12 +327,7 @@ def list_references_page(
|
||||
select(AssetReferenceTag.asset_reference_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetReferenceTag.tag_name)
|
||||
.where(AssetReferenceTag.asset_reference_id.in_(id_list))
|
||||
# Preserve insertion order so the structural first tag (the root
|
||||
# category like "models") stays in position 0 and the path-derived
|
||||
# sub-path tag stays in position 1, matching cloud's behavior.
|
||||
# tag_name is a deterministic tiebreaker when multiple tags share
|
||||
# an added_at (same-batch insert via set_reference_tags).
|
||||
.order_by(AssetReferenceTag.added_at.asc(), AssetReferenceTag.tag_name.asc())
|
||||
.order_by(AssetReferenceTag.tag_name.asc())
|
||||
)
|
||||
for ref_id, tag_name in rows.all():
|
||||
tag_map[ref_id].append(tag_name)
|
||||
@ -360,8 +355,7 @@ def fetch_reference_asset_and_tags(
|
||||
build_visible_owner_clause(owner_id),
|
||||
)
|
||||
.options(noload(AssetReference.tags))
|
||||
# See list_references_page for the rationale behind ordering by added_at.
|
||||
.order_by(AssetReferenceTag.added_at.asc(), Tag.name.asc())
|
||||
.order_by(Tag.name.asc())
|
||||
)
|
||||
|
||||
rows = session.execute(stmt).all()
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Iterable, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
@ -21,12 +20,7 @@ from app.assets.database.queries.common import (
|
||||
build_visible_owner_clause,
|
||||
iter_row_chunks,
|
||||
)
|
||||
from app.assets.helpers import (
|
||||
escape_sql_like_string,
|
||||
expand_bucket_prefixes,
|
||||
get_utc_now,
|
||||
normalize_tags,
|
||||
)
|
||||
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -50,26 +44,6 @@ class SetTagsResult:
|
||||
total: list[str]
|
||||
|
||||
|
||||
def _next_added_at_base(session: Session, reference_id: str) -> datetime:
|
||||
"""Return a timestamp strictly greater than any existing
|
||||
`added_at` for this reference. On platforms where the wall clock
|
||||
has insufficient resolution between back-to-back commits (notably
|
||||
Windows), two write batches on the same reference can otherwise
|
||||
share a microsecond — the `ORDER BY added_at, tag_name` retrieval
|
||||
then falls back to the alphabetic tiebreaker and user-tier tags
|
||||
sort ahead of path-tier tags they were meant to follow.
|
||||
"""
|
||||
existing_max = session.execute(
|
||||
sa.select(sa.func.max(AssetReferenceTag.added_at)).where(
|
||||
AssetReferenceTag.asset_reference_id == reference_id
|
||||
)
|
||||
).scalar()
|
||||
now = get_utc_now()
|
||||
if existing_max is None:
|
||||
return now
|
||||
return max(existing_max + timedelta(microseconds=1), now)
|
||||
|
||||
|
||||
def validate_tags_exist(session: Session, tags: list[str]) -> None:
|
||||
"""Raise ValueError if any of the given tag names do not exist."""
|
||||
existing_tag_names = set(
|
||||
@ -103,13 +77,7 @@ def get_reference_tags(session: Session, reference_id: str) -> list[str]:
|
||||
session.execute(
|
||||
select(AssetReferenceTag.tag_name)
|
||||
.where(AssetReferenceTag.asset_reference_id == reference_id)
|
||||
# Match the response-path ordering used by
|
||||
# list_references_page / fetch_reference_asset_and_tags so
|
||||
# upload responses and subsequent GETs agree on tag order.
|
||||
.order_by(
|
||||
AssetReferenceTag.added_at.asc(),
|
||||
AssetReferenceTag.tag_name.asc(),
|
||||
)
|
||||
.order_by(AssetReferenceTag.tag_name.asc())
|
||||
)
|
||||
).all()
|
||||
]
|
||||
@ -121,7 +89,7 @@ def set_reference_tags(
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
) -> SetTagsResult:
|
||||
desired = expand_bucket_prefixes(normalize_tags(tags))
|
||||
desired = normalize_tags(tags)
|
||||
|
||||
current = set(get_reference_tags(session, reference_id))
|
||||
|
||||
@ -130,22 +98,15 @@ def set_reference_tags(
|
||||
|
||||
if to_add:
|
||||
ensure_tags_exist(session, to_add, tag_type="user")
|
||||
# Stagger added_at by microsecond per tag so the retrieval ORDER BY
|
||||
# added_at preserves input order. Per-tag get_utc_now() calls can
|
||||
# collide at microsecond resolution on fast machines, dropping the
|
||||
# query to the tag_name alphabetical tiebreaker — same fix as in
|
||||
# batch_insert_seed_assets. Read max(existing) so this batch sorts
|
||||
# strictly after any prior batch on the same reference.
|
||||
base_ts = _next_added_at_base(session, reference_id)
|
||||
session.add_all(
|
||||
[
|
||||
AssetReferenceTag(
|
||||
asset_reference_id=reference_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=base_ts + timedelta(microseconds=i),
|
||||
added_at=get_utc_now(),
|
||||
)
|
||||
for i, t in enumerate(to_add)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
@ -175,7 +136,7 @@ def add_tags_to_reference(
|
||||
if not ref:
|
||||
raise ValueError(f"AssetReference {reference_id} not found")
|
||||
|
||||
norm = expand_bucket_prefixes(normalize_tags(tags))
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_reference_tags(session, reference_id=reference_id)
|
||||
return AddTagsResult(added=[], already_present=[], total_tags=total)
|
||||
@ -185,17 +146,10 @@ def add_tags_to_reference(
|
||||
|
||||
current = set(get_reference_tags(session, reference_id))
|
||||
|
||||
# Preserve the caller's insertion order rather than alphabetizing —
|
||||
# the retrieval ORDER BY added_at + microsecond stagger only meaningfully
|
||||
# preserves insertion order if "the order we insert in" actually matches
|
||||
# the caller's intent.
|
||||
want = set(norm)
|
||||
to_add = [t for t in norm if t not in current]
|
||||
to_add = sorted(want - current)
|
||||
|
||||
if to_add:
|
||||
# See set_reference_tags for the rationale behind the per-tag stagger
|
||||
# and the max(existing) seed.
|
||||
base_ts = _next_added_at_base(session, reference_id)
|
||||
with session.begin_nested() as nested:
|
||||
try:
|
||||
session.add_all(
|
||||
@ -204,9 +158,9 @@ def add_tags_to_reference(
|
||||
asset_reference_id=reference_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=base_ts + timedelta(microseconds=i),
|
||||
added_at=get_utc_now(),
|
||||
)
|
||||
for i, t in enumerate(to_add)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
|
||||
@ -47,50 +47,6 @@ def normalize_tags(tags: list[str] | None) -> list[str]:
|
||||
return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip()))
|
||||
|
||||
|
||||
def _known_bucket_prefixes() -> set[str]:
|
||||
"""Lowercased model-category names eligible for standalone-prefix
|
||||
expansion. Tags whose first slash segment matches one of these get
|
||||
the bucket inserted as a separate token, so FE filters like
|
||||
``include_tags=models,checkpoints`` keep matching even when the
|
||||
asset lives in a nested subfolder (`models/checkpoints/flux/foo`).
|
||||
|
||||
Bare user labels with slashes whose first segment is not a registered
|
||||
bucket (e.g. ``my-org/team-a``) pass through unchanged.
|
||||
"""
|
||||
try:
|
||||
import folder_paths
|
||||
|
||||
return {
|
||||
name.lower()
|
||||
for name in folder_paths.folder_names_and_paths.keys()
|
||||
if name != "custom_nodes"
|
||||
}
|
||||
except Exception:
|
||||
return set()
|
||||
|
||||
|
||||
def expand_bucket_prefixes(tags: list[str]) -> list[str]:
|
||||
"""Insert standalone bucket tokens after any slash-joined tag whose
|
||||
first segment is a registered model category. Preserves caller order
|
||||
and is idempotent (existing bucket tokens are not duplicated).
|
||||
"""
|
||||
if not tags:
|
||||
return list(tags)
|
||||
buckets = _known_bucket_prefixes()
|
||||
if not buckets:
|
||||
return list(tags)
|
||||
seen = set(tags)
|
||||
result: list[str] = []
|
||||
for t in tags:
|
||||
result.append(t)
|
||||
if "/" in t:
|
||||
prefix = t.split("/", 1)[0]
|
||||
if prefix.lower() in buckets and prefix not in seen:
|
||||
result.append(prefix)
|
||||
seen.add(prefix)
|
||||
return result
|
||||
|
||||
|
||||
def validate_blake3_hash(s: str) -> str:
|
||||
"""Validate and normalize a blake3 hash string.
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@ -13,14 +13,13 @@ from app.assets.database.queries import (
|
||||
bulk_insert_references_ignore_conflicts,
|
||||
bulk_insert_tags_and_meta,
|
||||
delete_assets_by_ids,
|
||||
ensure_tags_exist,
|
||||
get_existing_asset_ids,
|
||||
get_reference_ids_by_ids,
|
||||
get_references_by_paths_and_asset_ids,
|
||||
get_unreferenced_unhashed_asset_ids,
|
||||
restore_references_by_paths,
|
||||
)
|
||||
from app.assets.helpers import expand_bucket_prefixes, get_utc_now
|
||||
from app.assets.helpers import get_utc_now
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.assets.services.metadata_extract import ExtractedMetadata
|
||||
@ -234,20 +233,13 @@ def batch_insert_seed_assets(
|
||||
if ref_id not in inserted_ref_ids:
|
||||
continue
|
||||
|
||||
# Stagger added_at by microsecond per tag within a reference so
|
||||
# the retrieval ORDER BY added_at preserves the input list order
|
||||
# (the path-derived root category stays at position 0). Without
|
||||
# this, every tag in a bulk-insert batch shares current_time and
|
||||
# the tag_name tiebreaker sorts them alphabetically — putting the
|
||||
# subpath tag ahead of "models" since "c"/"d"/"l" < "m".
|
||||
ref_tags = expand_bucket_prefixes(ref_data["tags"])
|
||||
for tag_idx, tag in enumerate(ref_tags):
|
||||
for tag in ref_data["tags"]:
|
||||
tag_rows.append(
|
||||
{
|
||||
"asset_reference_id": ref_id,
|
||||
"tag_name": tag,
|
||||
"origin": "automatic",
|
||||
"added_at": current_time + timedelta(microseconds=tag_idx),
|
||||
"added_at": current_time,
|
||||
}
|
||||
)
|
||||
|
||||
@ -269,16 +261,6 @@ def batch_insert_seed_assets(
|
||||
}
|
||||
)
|
||||
|
||||
if tag_rows:
|
||||
# Bucket-prefix expansion may have introduced tags the caller did
|
||||
# not register via the upstream tag_pool (e.g. `checkpoints` for a
|
||||
# nested `checkpoints/flux/foo` path). Pre-register the full set so
|
||||
# the AssetReferenceTag.tag_name FK is satisfied; the underlying
|
||||
# insert is ON CONFLICT DO NOTHING so re-registration is idempotent.
|
||||
ensure_tags_exist(
|
||||
session, {row["tag_name"] for row in tag_rows}, tag_type="user"
|
||||
)
|
||||
|
||||
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=metadata_rows)
|
||||
|
||||
return BulkInsertResult(
|
||||
|
||||
@ -3,6 +3,7 @@ from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import folder_paths
|
||||
from app.assets.helpers import normalize_tags
|
||||
|
||||
|
||||
_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"})
|
||||
@ -26,51 +27,27 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs).
|
||||
|
||||
Accepts both the legacy one-tag-per-directory shape
|
||||
(``["models", "diffusers", "Kolors", "text_encoder"]``) and the
|
||||
slash-joined shape emitted by :func:`get_name_and_tags_from_asset_path`
|
||||
(``["models", "diffusers/Kolors/text_encoder"]``). Hybrid shapes that
|
||||
mix the two within a single call (e.g.
|
||||
``["models", "diffusers", "Kolors/text_encoder"]``) are also
|
||||
accepted: each entry after ``tags[0]`` is split on ``/`` and
|
||||
concatenated, so the two shapes — and any mix of them — resolve to
|
||||
the same destination. The same safety checks are applied to each
|
||||
component after expansion.
|
||||
"""
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||
if not tags:
|
||||
raise ValueError("tags must not be empty")
|
||||
root = tags[0].lower()
|
||||
|
||||
# Expand any slash-joined entries into individual path components so
|
||||
# the rest of the function can treat both tag shapes uniformly. Each
|
||||
# component is also stripped, so " a / b " behaves like ["a", "b"].
|
||||
expanded: list[str] = []
|
||||
for t in tags[1:]:
|
||||
for part in str(t).split("/"):
|
||||
part = part.strip()
|
||||
if part:
|
||||
expanded.append(part)
|
||||
|
||||
if root == "models":
|
||||
if not expanded:
|
||||
if len(tags) < 2:
|
||||
raise ValueError("at least two tags required for model asset")
|
||||
category = expanded[0]
|
||||
try:
|
||||
bases = folder_paths.folder_names_and_paths[category][0]
|
||||
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown model category '{category}'")
|
||||
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||
if not bases:
|
||||
raise ValueError(f"no base path configured for category '{category}'")
|
||||
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||
base_dir = os.path.abspath(bases[0])
|
||||
raw_subdirs = expanded[1:]
|
||||
raw_subdirs = tags[2:]
|
||||
elif root == "input":
|
||||
base_dir = os.path.abspath(folder_paths.get_input_directory())
|
||||
raw_subdirs = expanded
|
||||
raw_subdirs = tags[1:]
|
||||
elif root == "output":
|
||||
base_dir = os.path.abspath(folder_paths.get_output_directory())
|
||||
raw_subdirs = expanded
|
||||
raw_subdirs = tags[1:]
|
||||
else:
|
||||
raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'")
|
||||
_sep_chars = frozenset(("/", "\\", os.sep))
|
||||
@ -183,21 +160,7 @@ def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
"""Return (name, tags) derived from a filesystem path.
|
||||
|
||||
- name: base filename with extension
|
||||
- tags: [root_category] for paths with no parent subdirectories,
|
||||
[root_category, slash_joined_subpath] otherwise. The parent subpath
|
||||
(everything between the root category and the filename) is collapsed
|
||||
into a single tag rather than emitted as one tag per directory, so
|
||||
consumers can use ``tags[1]`` as a stable category identifier that
|
||||
survives nested directory layouts (e.g. diffusers components).
|
||||
|
||||
The subpath is lowercased to match the canonicalization applied by
|
||||
:func:`ensure_tags_exist`; without that, the
|
||||
``asset_reference_tags.tag_name`` FK to the lowercased ``tags.name``
|
||||
would fail for any path containing uppercase letters. The root
|
||||
category is lowercase by construction in
|
||||
:func:`get_asset_category_and_relative_path`, so no separate cast
|
||||
is applied here. Consumers that need to look up providers keyed on
|
||||
original-case paths should normalize their lookup key to lowercase.
|
||||
- tags: [root_category] + parent folder names in order
|
||||
|
||||
Raises:
|
||||
ValueError: path does not belong to any known root.
|
||||
@ -207,7 +170,4 @@ def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
parent_parts = [
|
||||
part for part in p.parent.parts if part not in (".", "..", p.anchor)
|
||||
]
|
||||
tags = [root_category]
|
||||
if parent_parts:
|
||||
tags.append("/".join(parent_parts).lower())
|
||||
return p.name, list(dict.fromkeys(t.strip() for t in tags if t.strip()))
|
||||
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
|
||||
|
||||
@ -62,6 +62,8 @@ def get_comfy_package_versions():
|
||||
def check_comfy_packages_versions():
|
||||
"""Warn for every comfy* package whose installed version is below requirements.txt."""
|
||||
from packaging.version import InvalidVersion, parse as parse_pep440
|
||||
outdated_packages = []
|
||||
|
||||
for pkg in get_comfy_package_versions():
|
||||
installed_str = pkg["installed"]
|
||||
required_str = pkg["required"]
|
||||
@ -73,19 +75,26 @@ def check_comfy_packages_versions():
|
||||
logging.error(f"Failed to check {pkg['name']} version: {e}")
|
||||
continue
|
||||
if outdated:
|
||||
app.logger.log_startup_warning(
|
||||
f"""
|
||||
outdated_packages.append((pkg["name"], installed_str, required_str))
|
||||
else:
|
||||
logging.info("{} version: {}".format(pkg["name"], installed_str))
|
||||
|
||||
if outdated_packages:
|
||||
package_warnings = "\n".join(
|
||||
f"Installed {name} version {installed} is lower than the recommended version {required}."
|
||||
for name, installed, required in outdated_packages
|
||||
)
|
||||
app.logger.log_startup_warning(
|
||||
f"""
|
||||
________________________________________________________________________
|
||||
WARNING WARNING WARNING WARNING WARNING
|
||||
|
||||
Installed {pkg["name"]} version {installed_str} is lower than the recommended version {required_str}.
|
||||
{package_warnings}
|
||||
|
||||
{get_missing_requirements_message()}
|
||||
________________________________________________________________________
|
||||
""".strip()
|
||||
)
|
||||
else:
|
||||
logging.info("{} version: {}".format(pkg["name"], installed_str))
|
||||
)
|
||||
|
||||
|
||||
REQUEST_TIMEOUT = 10 # seconds
|
||||
|
||||
@ -1613,6 +1613,16 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
#use all ModelPatcherDynamic this is ignored and its all done dynamically.
|
||||
return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3)
|
||||
|
||||
def restore_loaded_backups(self):
|
||||
restored = self.model.model_loaded_weight_memory
|
||||
for key in list(self.backup.keys()):
|
||||
bk = self.backup.pop(key)
|
||||
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||
for key in list(self.backup_buffers.keys()):
|
||||
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
return restored
|
||||
|
||||
|
||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False):
|
||||
|
||||
@ -1629,7 +1639,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
num_patches = 0
|
||||
allocated_size = 0
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
self.restore_loaded_backups()
|
||||
|
||||
with self.use_ejected():
|
||||
self.unpatch_hooks()
|
||||
@ -1716,6 +1726,9 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
force_load=True
|
||||
|
||||
if force_load:
|
||||
if hasattr(m, "_v"):
|
||||
comfy_aimdo.model_vbar.vbar_unpin(m._v)
|
||||
delattr(m, "_v")
|
||||
force_load_param(self, "weight", device_to)
|
||||
force_load_param(self, "bias", device_to)
|
||||
else:
|
||||
@ -1773,13 +1786,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
freed = 0 if vbar is None else vbar.free_memory(memory_to_free)
|
||||
|
||||
if freed < memory_to_free:
|
||||
for key in list(self.backup.keys()):
|
||||
bk = self.backup.pop(key)
|
||||
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||
for key in list(self.backup_buffers.keys()):
|
||||
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
|
||||
freed += self.model.model_loaded_weight_memory
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
freed += self.restore_loaded_backups()
|
||||
|
||||
return freed
|
||||
|
||||
|
||||
@ -1019,10 +1019,11 @@ def bislerp(samples, width, height):
|
||||
|
||||
def lanczos(samples, width, height):
|
||||
#the below API is strict and expects grayscale to be squeezed
|
||||
samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1)
|
||||
if samples.ndim == 4:
|
||||
samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1)
|
||||
images = [Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
|
||||
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
|
||||
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
|
||||
images = [torch.from_numpy(t).movedim(-1, 0) if (t := np.array(image).astype(np.float32) / 255.0).ndim == 3 else torch.from_numpy(t) for image in images]
|
||||
result = torch.stack(images)
|
||||
return result.to(samples.device, samples.dtype)
|
||||
|
||||
|
||||
@ -35,6 +35,19 @@ class AnthropicMessage(BaseModel):
|
||||
content: list[AnthropicTextContent | AnthropicImageContent] = Field(...)
|
||||
|
||||
|
||||
class AnthropicThinkingConfig(BaseModel):
|
||||
type: Literal["enabled", "disabled", "adaptive"] = Field(...)
|
||||
budget_tokens: int | None = Field(
|
||||
None, ge=1024,
|
||||
description="Reasoning budget in tokens. Used when type is 'enabled'. Must be less than max_tokens.",
|
||||
)
|
||||
|
||||
|
||||
class AnthropicOutputConfig(BaseModel):
|
||||
"""Used with `thinking.type='adaptive'` on models like Opus 4.7."""
|
||||
effort: Literal["low", "medium", "high"] | None = Field(None)
|
||||
|
||||
|
||||
class AnthropicMessagesRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
messages: list[AnthropicMessage] = Field(...)
|
||||
@ -44,6 +57,8 @@ class AnthropicMessagesRequest(BaseModel):
|
||||
top_p: float | None = Field(None, ge=0.0, le=1.0)
|
||||
top_k: int | None = Field(None, ge=0)
|
||||
stop_sequences: list[str] | None = Field(None)
|
||||
thinking: AnthropicThinkingConfig | None = Field(None)
|
||||
output_config: AnthropicOutputConfig | None = Field(None)
|
||||
|
||||
|
||||
class AnthropicResponseTextBlock(BaseModel):
|
||||
@ -51,6 +66,14 @@ class AnthropicResponseTextBlock(BaseModel):
|
||||
text: str = Field(...)
|
||||
|
||||
|
||||
class AnthropicResponseThinkingBlock(BaseModel):
|
||||
type: Literal["thinking"] = "thinking"
|
||||
thinking: str = Field(...)
|
||||
|
||||
|
||||
AnthropicResponseBlock = AnthropicResponseTextBlock | AnthropicResponseThinkingBlock
|
||||
|
||||
|
||||
class AnthropicCacheCreationUsage(BaseModel):
|
||||
ephemeral_5m_input_tokens: int | None = Field(None)
|
||||
ephemeral_1h_input_tokens: int | None = Field(None)
|
||||
@ -69,7 +92,7 @@ class AnthropicMessagesResponse(BaseModel):
|
||||
type: str | None = Field(None)
|
||||
role: str | None = Field(None)
|
||||
model: str | None = Field(None)
|
||||
content: list[AnthropicResponseTextBlock] | None = Field(None)
|
||||
content: list[AnthropicResponseBlock] | None = Field(None)
|
||||
stop_reason: str | None = Field(None)
|
||||
stop_sequence: str | None = Field(None)
|
||||
usage: AnthropicMessagesUsage | None = Field(None)
|
||||
|
||||
93
comfy_api_nodes/apis/openrouter.py
Normal file
93
comfy_api_nodes/apis/openrouter.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""Pydantic models for the OpenRouter chat completions API.
|
||||
|
||||
See: https://openrouter.ai/docs/api/api-reference/chat/send-chat-completion-request
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class OpenRouterTextContent(BaseModel):
|
||||
type: Literal["text"] = "text"
|
||||
text: str = Field(...)
|
||||
|
||||
|
||||
class OpenRouterImageUrl(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class OpenRouterImageContent(BaseModel):
|
||||
type: Literal["image_url"] = "image_url"
|
||||
image_url: OpenRouterImageUrl = Field(...)
|
||||
|
||||
|
||||
class OpenRouterVideoUrl(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class OpenRouterVideoContent(BaseModel):
|
||||
type: Literal["video_url"] = "video_url"
|
||||
video_url: OpenRouterVideoUrl = Field(...)
|
||||
|
||||
|
||||
OpenRouterContentBlock = OpenRouterTextContent | OpenRouterImageContent | OpenRouterVideoContent
|
||||
|
||||
|
||||
class OpenRouterMessage(BaseModel):
|
||||
role: Literal["system", "user", "assistant"] = Field(...)
|
||||
content: str | list[OpenRouterContentBlock] = Field(...)
|
||||
|
||||
|
||||
class OpenRouterReasoningConfig(BaseModel):
|
||||
effort: str | None = Field(None)
|
||||
exclude: bool | None = Field(None, description="If true, model reasons but reasoning is excluded from response.")
|
||||
|
||||
|
||||
class OpenRouterWebSearchOptions(BaseModel):
|
||||
search_context_size: str | None = Field(None)
|
||||
|
||||
|
||||
class OpenRouterChatRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
messages: list[OpenRouterMessage] = Field(...)
|
||||
seed: int | None = Field(None)
|
||||
reasoning: OpenRouterReasoningConfig | None = Field(None)
|
||||
web_search_options: OpenRouterWebSearchOptions | None = Field(None)
|
||||
stream: bool = Field(False)
|
||||
|
||||
|
||||
class OpenRouterUsage(BaseModel):
|
||||
prompt_tokens: int | None = Field(None)
|
||||
completion_tokens: int | None = Field(None)
|
||||
total_tokens: int | None = Field(None)
|
||||
cost: float | None = Field(None, description="Server-side authoritative USD cost of the call.")
|
||||
|
||||
|
||||
class OpenRouterResponseMessage(BaseModel):
|
||||
role: str | None = Field(None)
|
||||
content: str | None = Field(None)
|
||||
reasoning: str | None = Field(None)
|
||||
refusal: str | None = Field(None)
|
||||
|
||||
|
||||
class OpenRouterChoice(BaseModel):
|
||||
index: int | None = Field(None)
|
||||
message: OpenRouterResponseMessage | None = Field(None)
|
||||
finish_reason: str | None = Field(None)
|
||||
|
||||
|
||||
class OpenRouterError(BaseModel):
|
||||
code: int | str | None = Field(None)
|
||||
message: str | None = Field(None)
|
||||
metadata: dict | None = Field(None)
|
||||
|
||||
|
||||
class OpenRouterChatResponse(BaseModel):
|
||||
id: str | None = Field(None)
|
||||
model: str | None = Field(None)
|
||||
object: str | None = Field(None)
|
||||
provider: str | None = Field(None)
|
||||
choices: list[OpenRouterChoice] | None = Field(None)
|
||||
usage: OpenRouterUsage | None = Field(None)
|
||||
error: OpenRouterError | None = Field(None)
|
||||
@ -1,7 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@ -11,44 +9,76 @@ class Rodin3DGenerateRequest(BaseModel):
|
||||
material: str = Field(..., description="The material type.")
|
||||
quality_override: int = Field(..., description="The poly count of the mesh.")
|
||||
mesh_mode: str = Field(..., description="It controls the type of faces of generated models.")
|
||||
TAPose: Optional[bool] = Field(None, description="")
|
||||
TAPose: bool | None = Field(None, description="")
|
||||
|
||||
|
||||
class Rodin3DGen25Request(BaseModel):
|
||||
|
||||
tier: str = Field(..., description="Gen-2.5 tier (e.g. Gen-2.5-High).")
|
||||
prompt: str | None = Field(None, description="Required for Text-to-3D; ignored otherwise.")
|
||||
seed: int | None = Field(None, description="0-65535.")
|
||||
material: str | None = Field(None, description="PBR | Shaded | All | None.")
|
||||
geometry_file_format: str | None = Field(None, description="glb | usdz | fbx | obj | stl.")
|
||||
texture_mode: str | None = Field(None, description="legacy | extreme-low | low | medium | high.")
|
||||
mesh_mode: str | None = Field(None, description="Raw (triangular) | Quad.")
|
||||
quality_override: int | None = Field(None, description="Mesh face count override.")
|
||||
geometry_instruct_mode: str | None = Field(None, description="faithful | creative.")
|
||||
bbox_condition: list[int] | None = Field(None, description="Bounding box [Width(Y), Height(Z), Length(X)] in cm.")
|
||||
height: int | None = Field(None, description="Approximate model height in cm.")
|
||||
TAPose: bool | None = Field(None, description="T/A pose for human-like models.")
|
||||
hd_texture: bool | None = Field(None, description="Enhanced texture quality.")
|
||||
texture_delight: bool | None = Field(None, description="Remove baked lighting from textures.")
|
||||
is_micro: bool | None = Field(None, description="Micro detail (Extreme-High only).")
|
||||
use_original_alpha: bool | None = Field(None, description="Preserve image transparency.")
|
||||
preview_render: bool | None = Field(None, description="Generate high-quality preview render.")
|
||||
addons: list[str] | None = Field(None, description='Optional addons, e.g. ["HighPack"].')
|
||||
|
||||
|
||||
class GenerateJobsData(BaseModel):
|
||||
uuids: List[str] = Field(..., description="str LIST")
|
||||
uuids: list[str] = Field(..., description="str LIST")
|
||||
subscription_key: str = Field(..., description="subscription key")
|
||||
|
||||
|
||||
class Rodin3DGenerateResponse(BaseModel):
|
||||
message: Optional[str] = Field(None, description="Return message.")
|
||||
prompt: Optional[str] = Field(None, description="Generated Prompt from image.")
|
||||
submit_time: Optional[str] = Field(None, description="Submit Time")
|
||||
uuid: Optional[str] = Field(None, description="Task str")
|
||||
jobs: Optional[GenerateJobsData] = Field(None, description="Details of jobs")
|
||||
message: str | None = Field(None, description="Return message.")
|
||||
prompt: str | None = Field(None, description="Generated Prompt from image.")
|
||||
submit_time: str | None = Field(None, description="Submit Time")
|
||||
uuid: str | None = Field(None, description="Task str")
|
||||
jobs: GenerateJobsData | None = Field(None, description="Details of jobs")
|
||||
|
||||
|
||||
class JobStatus(str, Enum):
|
||||
"""
|
||||
Status for jobs
|
||||
"""
|
||||
|
||||
Done = "Done"
|
||||
Failed = "Failed"
|
||||
Generating = "Generating"
|
||||
Waiting = "Waiting"
|
||||
|
||||
|
||||
class Rodin3DCheckStatusRequest(BaseModel):
|
||||
subscription_key: str = Field(..., description="subscription from generate endpoint")
|
||||
|
||||
|
||||
class JobItem(BaseModel):
|
||||
uuid: str = Field(..., description="uuid")
|
||||
status: JobStatus = Field(...,description="Status Currently")
|
||||
status: JobStatus = Field(..., description="Status Currently")
|
||||
|
||||
|
||||
class Rodin3DCheckStatusResponse(BaseModel):
|
||||
jobs: List[JobItem] = Field(..., description="Job status List")
|
||||
jobs: list[JobItem] = Field(..., description="Job status List")
|
||||
|
||||
|
||||
class Rodin3DDownloadRequest(BaseModel):
|
||||
task_uuid: str = Field(..., description="Task str")
|
||||
|
||||
|
||||
class RodinResourceItem(BaseModel):
|
||||
url: str = Field(..., description="Download Url")
|
||||
name: str = Field(..., description="File name with ext")
|
||||
|
||||
|
||||
class Rodin3DDownloadResponse(BaseModel):
|
||||
list: List[RodinResourceItem] = Field(..., description="Source List")
|
||||
items: list[RodinResourceItem] = Field(..., alias="list", description="Source List")
|
||||
|
||||
@ -9,8 +9,11 @@ from comfy_api_nodes.apis.anthropic import (
|
||||
AnthropicMessage,
|
||||
AnthropicMessagesRequest,
|
||||
AnthropicMessagesResponse,
|
||||
AnthropicOutputConfig,
|
||||
AnthropicResponseTextBlock,
|
||||
AnthropicRole,
|
||||
AnthropicTextContent,
|
||||
AnthropicThinkingConfig,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
@ -32,15 +35,29 @@ CLAUDE_MODELS: dict[str, str] = {
|
||||
"Haiku 4.5": "claude-haiku-4-5-20251001",
|
||||
}
|
||||
|
||||
_THINKING_UNSUPPORTED = {"Haiku 4.5"}
|
||||
# Models that use the newer "adaptive" thinking mode (Opus 4.7 requires it; older models keep the explicit budget API).
|
||||
# Anthropic decides the actual budget when adaptive is used, based on the `output_config.effort` hint.
|
||||
_ADAPTIVE_THINKING_MODELS = {"Opus 4.7", "Opus 4.6", "Sonnet 4.6"}
|
||||
|
||||
def _claude_model_inputs():
|
||||
return [
|
||||
# Budget mode (Sonnet 4.5): effort -> reasoning budget in tokens. Must be < max_tokens.
|
||||
# Sized so even the "high" budget fits comfortably under the default max_tokens=32768.
|
||||
_REASONING_BUDGET: dict[str, int] = {
|
||||
"low": 2048,
|
||||
"medium": 8192,
|
||||
"high": 16384,
|
||||
}
|
||||
_REASONING_EFFORTS = ["off", "low", "medium", "high"]
|
||||
|
||||
|
||||
def _claude_model_inputs(model_label: str):
|
||||
inputs: list = [
|
||||
IO.Int.Input(
|
||||
"max_tokens",
|
||||
default=16000,
|
||||
min=32,
|
||||
max=32000,
|
||||
tooltip="Maximum number of tokens to generate before stopping.",
|
||||
default=32768,
|
||||
min=4096,
|
||||
max=64000,
|
||||
tooltip="Maximum number of tokens to generate (includes reasoning tokens when enabled).",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
@ -49,10 +66,24 @@ def _claude_model_inputs():
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip="Controls randomness. 0.0 is deterministic, 1.0 is most random. Ignored for Opus 4.7.",
|
||||
tooltip=(
|
||||
"Controls randomness. 0.0 is deterministic, 1.0 is most random. "
|
||||
"Ignored for Opus 4.7 and any model when reasoning_effort is set."
|
||||
),
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
if model_label not in _THINKING_UNSUPPORTED:
|
||||
inputs.append(
|
||||
IO.Combo.Input(
|
||||
"reasoning_effort",
|
||||
options=_REASONING_EFFORTS,
|
||||
default="off",
|
||||
tooltip="Extended thinking effort. 'off' disables reasoning.",
|
||||
advanced=True,
|
||||
)
|
||||
)
|
||||
return inputs
|
||||
|
||||
|
||||
def _model_price_per_million(model: str) -> tuple[float, float] | None:
|
||||
@ -95,7 +126,11 @@ def calculate_tokens_price(response: AnthropicMessagesResponse) -> float | None:
|
||||
def _get_text_from_response(response: AnthropicMessagesResponse) -> str:
|
||||
if not response.content:
|
||||
return ""
|
||||
return "\n".join(block.text for block in response.content if block.text)
|
||||
# Thinking blocks are silently dropped — we never want reasoning in the output.
|
||||
return "\n".join(
|
||||
block.text for block in response.content
|
||||
if isinstance(block, AnthropicResponseTextBlock) and block.text
|
||||
)
|
||||
|
||||
|
||||
async def _build_image_content_blocks(
|
||||
@ -133,7 +168,10 @@ class ClaudeNode(IO.ComfyNode):
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[IO.DynamicCombo.Option(label, _claude_model_inputs()) for label in CLAUDE_MODELS],
|
||||
options=[
|
||||
IO.DynamicCombo.Option(label, _claude_model_inputs(label))
|
||||
for label in CLAUDE_MODELS
|
||||
],
|
||||
tooltip="The Claude model used to generate the response.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
@ -207,8 +245,29 @@ class ClaudeNode(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
model_label = model["model"]
|
||||
max_tokens = model["max_tokens"]
|
||||
temperature = None if model_label == "Opus 4.7" else model["temperature"]
|
||||
max_tokens = model.get("max_tokens", 32768)
|
||||
reasoning_effort = model.get("reasoning_effort", "off")
|
||||
thinking_enabled = reasoning_effort not in ("off", None) and model_label not in _THINKING_UNSUPPORTED
|
||||
|
||||
# Anthropic requires temperature to be unset (defaults to 1.0) when thinking is enabled.
|
||||
# Opus 4.7 also rejects user-supplied temperature.
|
||||
if thinking_enabled or model_label == "Opus 4.7":
|
||||
temperature = None
|
||||
else:
|
||||
temperature = model.get("temperature", 1.0)
|
||||
|
||||
thinking_cfg: AnthropicThinkingConfig | None = None
|
||||
output_cfg: AnthropicOutputConfig | None = None
|
||||
if thinking_enabled:
|
||||
if model_label in _ADAPTIVE_THINKING_MODELS:
|
||||
# Adaptive mode - Anthropic chooses the budget based on effort hint
|
||||
thinking_cfg = AnthropicThinkingConfig(type="adaptive")
|
||||
output_cfg = AnthropicOutputConfig(effort=reasoning_effort)
|
||||
else:
|
||||
# Budget mode (Sonnet 4.5). Leave at least 1024 tokens for the actual response
|
||||
budget = _REASONING_BUDGET[reasoning_effort]
|
||||
budget = min(budget, max(1024, max_tokens - 1024))
|
||||
thinking_cfg = AnthropicThinkingConfig(type="enabled", budget_tokens=budget)
|
||||
|
||||
image_tensors: list[Input.Image] = [t for t in (images or {}).values() if t is not None]
|
||||
if sum(get_number_of_images(t) for t in image_tensors) > CLAUDE_MAX_IMAGES:
|
||||
@ -229,6 +288,8 @@ class ClaudeNode(IO.ComfyNode):
|
||||
messages=[AnthropicMessage(role=AnthropicRole.user, content=content)],
|
||||
system=system_prompt or None,
|
||||
temperature=temperature,
|
||||
thinking=thinking_cfg,
|
||||
output_config=output_cfg,
|
||||
),
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
|
||||
@ -43,15 +43,16 @@ from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
downscale_video_to_max_pixels,
|
||||
get_number_of_images,
|
||||
image_tensor_pair_to_batch,
|
||||
poll_op,
|
||||
resize_video_to_pixel_budget,
|
||||
sync_op,
|
||||
upload_audio_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
upscale_video_to_min_pixels,
|
||||
validate_image_aspect_ratio,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
@ -110,12 +111,13 @@ def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: st
|
||||
max_px = limits.get("max")
|
||||
if min_px and pixels < min_px:
|
||||
raise ValueError(
|
||||
f"Reference video {index} is too small: {w}x{h} = {pixels:,}px. " f"Minimum is {min_px:,}px for this model."
|
||||
f"Reference video {index} is too small: {w}x{h} = {pixels:,} total pixels. "
|
||||
f"Minimum for this model is {min_px:,} total pixels."
|
||||
)
|
||||
if max_px and pixels > max_px:
|
||||
raise ValueError(
|
||||
f"Reference video {index} is too large: {w}x{h} = {pixels:,}px. "
|
||||
f"Maximum is {max_px:,}px for this model. Try downscaling the video."
|
||||
f"Reference video {index} is too large: {w}x{h} = {pixels:,} total pixels. "
|
||||
f"Maximum for this model is {max_px:,} total pixels. Try downscaling the video."
|
||||
)
|
||||
|
||||
|
||||
@ -1676,14 +1678,14 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
"first_frame_asset_id",
|
||||
default="",
|
||||
tooltip="Seedance asset_id to use as the first frame. "
|
||||
"Mutually exclusive with the first_frame image input.",
|
||||
"Mutually exclusive with the first_frame image input.",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"last_frame_asset_id",
|
||||
default="",
|
||||
tooltip="Seedance asset_id to use as the last frame. "
|
||||
"Mutually exclusive with the last_frame image input.",
|
||||
"Mutually exclusive with the last_frame image input.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
@ -1865,11 +1867,20 @@ def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16
|
||||
IO.Boolean.Input(
|
||||
"auto_downscale",
|
||||
default=False,
|
||||
advanced=True,
|
||||
optional=True,
|
||||
tooltip="Automatically downscale reference videos that exceed the model's pixel budget "
|
||||
"for the selected resolution. Aspect ratio is preserved; videos already within limits are untouched.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"auto_upscale",
|
||||
default=False,
|
||||
advanced=True,
|
||||
optional=True,
|
||||
tooltip="Automatically upscale reference videos that are below the model's minimum pixel count "
|
||||
"for the selected resolution. Aspect ratio is preserved; videos already meeting the minimum are "
|
||||
"untouched. Note: upscaling a low-resolution source does not add real detail and may produce "
|
||||
"lower-quality generations.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"reference_assets",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
@ -2030,7 +2041,13 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
max_px = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {}).get(model["resolution"], {}).get("max")
|
||||
if max_px:
|
||||
for key in reference_videos:
|
||||
reference_videos[key] = resize_video_to_pixel_budget(reference_videos[key], max_px)
|
||||
reference_videos[key] = downscale_video_to_max_pixels(reference_videos[key], max_px)
|
||||
|
||||
if model.get("auto_upscale") and reference_videos:
|
||||
min_px = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {}).get(model["resolution"], {}).get("min")
|
||||
if min_px:
|
||||
for key in reference_videos:
|
||||
reference_videos[key] = upscale_video_to_min_pixels(reference_videos[key], min_px)
|
||||
|
||||
total_video_duration = 0.0
|
||||
for i, key in enumerate(reference_videos, 1):
|
||||
|
||||
374
comfy_api_nodes/nodes_openrouter.py
Normal file
374
comfy_api_nodes/nodes_openrouter.py
Normal file
@ -0,0 +1,374 @@
|
||||
"""API Nodes for OpenRouter LLM chat completions."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.openrouter import (
|
||||
OpenRouterChatRequest,
|
||||
OpenRouterChatResponse,
|
||||
OpenRouterContentBlock,
|
||||
OpenRouterImageContent,
|
||||
OpenRouterImageUrl,
|
||||
OpenRouterMessage,
|
||||
OpenRouterReasoningConfig,
|
||||
OpenRouterTextContent,
|
||||
OpenRouterVideoContent,
|
||||
OpenRouterVideoUrl,
|
||||
OpenRouterWebSearchOptions,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
get_number_of_images,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
OPENROUTER_CHAT_ENDPOINT = "/proxy/openrouter/api/v1/chat/completions"
|
||||
|
||||
|
||||
Profile = Literal["standard", "reasoning", "frontier_reasoning", "perplexity", "perplexity_reasoning"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _ModelSpec:
|
||||
slug: str # exact OpenRouter model id
|
||||
profile: Profile
|
||||
price_in: float # USD per token (prompt)
|
||||
price_out: float # USD per token (completion)
|
||||
max_images: int = 0 # 0 = no image input; otherwise max URL-passed images supported
|
||||
max_videos: int = 0 # 0 = no video input; otherwise max URL-passed videos supported
|
||||
|
||||
|
||||
MODELS: list[_ModelSpec] = [
|
||||
_ModelSpec("anthropic/claude-opus-4.7", "frontier_reasoning", 0.000005, 0.000025, max_images=20),
|
||||
_ModelSpec("openai/gpt-5.5-pro", "frontier_reasoning", 0.00003, 0.00018, max_images=20),
|
||||
_ModelSpec("openai/gpt-5.5", "frontier_reasoning", 0.000005, 0.00003, max_images=20),
|
||||
_ModelSpec("google/gemini-3.5-flash", "reasoning", 0.0000015, 0.000009, max_images=20, max_videos=4),
|
||||
_ModelSpec("x-ai/grok-4.20", "reasoning", 0.00000125, 0.0000025, max_images=20),
|
||||
_ModelSpec("x-ai/grok-4.3", "reasoning", 0.00000125, 0.0000025, max_images=20),
|
||||
_ModelSpec("deepseek/deepseek-v4-pro", "reasoning", 0.000000435, 0.00000087),
|
||||
_ModelSpec("deepseek/deepseek-v4-flash", "reasoning", 0.000000112, 0.000000224),
|
||||
_ModelSpec("deepseek/deepseek-v3.2", "reasoning", 0.000000252, 0.000000378),
|
||||
_ModelSpec("qwen/qwen3.6-max-preview", "reasoning", 0.00000104, 0.00000624),
|
||||
_ModelSpec("qwen/qwen3.6-plus", "reasoning", 0.000000325, 0.00000195, max_images=10, max_videos=4),
|
||||
_ModelSpec("qwen/qwen3.6-flash", "reasoning", 0.0000001875, 0.000001125, max_images=10, max_videos=4),
|
||||
_ModelSpec("mistralai/mistral-large-2512", "standard", 0.0000005, 0.0000015, max_images=8),
|
||||
_ModelSpec("mistralai/mistral-medium-3-5", "reasoning", 0.0000015, 0.0000075, max_images=8),
|
||||
_ModelSpec("z-ai/glm-4.6", "reasoning", 0.00000043, 0.00000174),
|
||||
_ModelSpec("z-ai/glm-5", "reasoning", 0.0000006, 0.00000192),
|
||||
_ModelSpec("moonshotai/kimi-k2.6", "reasoning", 0.00000073, 0.00000349, max_images=10),
|
||||
_ModelSpec("moonshotai/kimi-k2-thinking", "reasoning", 0.0000006, 0.0000025),
|
||||
_ModelSpec("perplexity/sonar-pro", "perplexity", 0.000003, 0.000015),
|
||||
_ModelSpec("perplexity/sonar-reasoning-pro", "perplexity_reasoning", 0.000002, 0.000008),
|
||||
_ModelSpec("perplexity/sonar-deep-research", "perplexity_reasoning", 0.000002, 0.000008),
|
||||
]
|
||||
|
||||
_MODELS_BY_SLUG: dict[str, _ModelSpec] = {m.slug: m for m in MODELS}
|
||||
_REASONING_EFFORTS = ["off", "low", "medium", "high"]
|
||||
_SEARCH_CONTEXT_SIZES = ["low", "medium", "high"]
|
||||
|
||||
|
||||
def _reasoning_extra_inputs() -> list:
|
||||
return [
|
||||
IO.Combo.Input(
|
||||
"reasoning_effort",
|
||||
options=_REASONING_EFFORTS,
|
||||
default="off",
|
||||
tooltip="Reasoning effort. 'off' disables reasoning entirely.",
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _perplexity_extra_inputs() -> list:
|
||||
return [
|
||||
IO.Combo.Input(
|
||||
"search_context_size",
|
||||
options=_SEARCH_CONTEXT_SIZES,
|
||||
default="medium",
|
||||
tooltip="How much web search context to retrieve. Larger = more grounded but slower/pricier.",
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _profile_inputs(profile: Profile) -> list:
|
||||
if profile == "standard":
|
||||
return []
|
||||
if profile in ("reasoning", "frontier_reasoning"):
|
||||
return _reasoning_extra_inputs()
|
||||
if profile == "perplexity":
|
||||
return _perplexity_extra_inputs()
|
||||
if profile == "perplexity_reasoning":
|
||||
return _perplexity_extra_inputs() + _reasoning_extra_inputs()
|
||||
raise ValueError(f"Unknown profile: {profile}")
|
||||
|
||||
|
||||
def _media_inputs(spec: _ModelSpec) -> list:
|
||||
extras: list = []
|
||||
if spec.max_images > 0:
|
||||
extras.append(
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("image"),
|
||||
names=[f"image_{i}" for i in range(1, spec.max_images + 1)],
|
||||
min=0,
|
||||
),
|
||||
tooltip=f"Optional reference image(s) — up to {spec.max_images}. Sent as URLs.",
|
||||
)
|
||||
)
|
||||
if spec.max_videos > 0:
|
||||
extras.append(
|
||||
IO.Autogrow.Input(
|
||||
"videos",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Video.Input("video"),
|
||||
names=[f"video_{i}" for i in range(1, spec.max_videos + 1)],
|
||||
min=0,
|
||||
),
|
||||
tooltip=f"Optional reference video(s) — up to {spec.max_videos}. Sent as URLs.",
|
||||
)
|
||||
)
|
||||
return extras
|
||||
|
||||
|
||||
def _inputs_for_model(spec: _ModelSpec) -> list:
|
||||
return _profile_inputs(spec.profile) + _media_inputs(spec)
|
||||
|
||||
|
||||
def _build_model_options() -> list[IO.DynamicCombo.Option]:
|
||||
return [IO.DynamicCombo.Option(spec.slug, _inputs_for_model(spec)) for spec in MODELS]
|
||||
|
||||
|
||||
def _calculate_price(response: OpenRouterChatResponse) -> float | None:
|
||||
if response.usage and response.usage.cost is not None:
|
||||
return float(response.usage.cost)
|
||||
return None
|
||||
|
||||
|
||||
def _price_badge_jsonata() -> str:
|
||||
rates_pairs = []
|
||||
for spec in MODELS:
|
||||
prompt_per_1k = spec.price_in * 1000
|
||||
completion_per_1k = spec.price_out * 1000
|
||||
rates_pairs.append(f' "{spec.slug}": [{prompt_per_1k:.8g}, {completion_per_1k:.8g}]')
|
||||
rates_block = ",\n".join(rates_pairs)
|
||||
return (
|
||||
"(\n"
|
||||
" $rates := {\n"
|
||||
f"{rates_block}\n"
|
||||
" };\n"
|
||||
" $r := $lookup($rates, widgets.model);\n"
|
||||
" $r ? {\n"
|
||||
' "type": "list_usd",\n'
|
||||
' "usd": $r,\n'
|
||||
' "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }\n'
|
||||
' } : {"type": "text", "text": "Token-based"}\n'
|
||||
")"
|
||||
)
|
||||
|
||||
|
||||
async def _build_image_blocks(
|
||||
cls: type[IO.ComfyNode], spec: _ModelSpec, images: list[Input.Image]
|
||||
) -> list[OpenRouterImageContent]:
|
||||
urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
images,
|
||||
max_images=spec.max_images,
|
||||
total_pixels=2048 * 2048,
|
||||
mime_type="image/png",
|
||||
wait_label="Uploading reference images",
|
||||
)
|
||||
return [OpenRouterImageContent(image_url=OpenRouterImageUrl(url=url)) for url in urls]
|
||||
|
||||
|
||||
async def _build_video_blocks(cls: type[IO.ComfyNode], videos: list[Input.Video]) -> list[OpenRouterVideoContent]:
|
||||
blocks: list[OpenRouterVideoContent] = []
|
||||
total = len(videos)
|
||||
for idx, video in enumerate(videos):
|
||||
label = "Uploading reference video"
|
||||
if total > 1:
|
||||
label = f"{label} ({idx + 1}/{total})"
|
||||
url = await upload_video_to_comfyapi(cls, video, wait_label=label)
|
||||
blocks.append(OpenRouterVideoContent(video_url=OpenRouterVideoUrl(url=url)))
|
||||
return blocks
|
||||
|
||||
|
||||
def _user_message(prompt: str, media_blocks: list[OpenRouterContentBlock]) -> OpenRouterMessage:
|
||||
if not media_blocks:
|
||||
return OpenRouterMessage(role="user", content=prompt)
|
||||
blocks: list[OpenRouterContentBlock] = list(media_blocks)
|
||||
blocks.append(OpenRouterTextContent(text=prompt))
|
||||
return OpenRouterMessage(role="user", content=blocks)
|
||||
|
||||
|
||||
def _build_messages(
|
||||
system_prompt: str, prompt: str, media_blocks: list[OpenRouterContentBlock]
|
||||
) -> list[OpenRouterMessage]:
|
||||
messages: list[OpenRouterMessage] = []
|
||||
if system_prompt:
|
||||
messages.append(OpenRouterMessage(role="system", content=system_prompt))
|
||||
messages.append(_user_message(prompt, media_blocks))
|
||||
return messages
|
||||
|
||||
|
||||
def _build_request(
|
||||
slug: str,
|
||||
system_prompt: str,
|
||||
prompt: str,
|
||||
media_blocks: list[OpenRouterContentBlock],
|
||||
*,
|
||||
seed: int,
|
||||
reasoning_effort: str | None,
|
||||
search_context_size: str | None,
|
||||
) -> OpenRouterChatRequest:
|
||||
reasoning_cfg: OpenRouterReasoningConfig | None = None
|
||||
if reasoning_effort and reasoning_effort != "off":
|
||||
# exclude=True asks providers to reason internally but not return the trace
|
||||
reasoning_cfg = OpenRouterReasoningConfig(effort=reasoning_effort, exclude=True)
|
||||
web_search_cfg: OpenRouterWebSearchOptions | None = None
|
||||
if search_context_size:
|
||||
web_search_cfg = OpenRouterWebSearchOptions(search_context_size=search_context_size)
|
||||
return OpenRouterChatRequest(
|
||||
model=slug,
|
||||
messages=_build_messages(system_prompt, prompt, media_blocks),
|
||||
seed=seed if seed > 0 else None,
|
||||
reasoning=reasoning_cfg,
|
||||
web_search_options=web_search_cfg,
|
||||
)
|
||||
|
||||
|
||||
def _extract_text(response: OpenRouterChatResponse) -> str:
|
||||
if response.error:
|
||||
code = response.error.code if response.error.code is not None else "unknown"
|
||||
raise ValueError(f"OpenRouter error ({code}): {response.error.message or 'no message'}")
|
||||
if not response.choices:
|
||||
raise ValueError("Empty response from OpenRouter (no choices).")
|
||||
message = response.choices[0].message
|
||||
if not message:
|
||||
raise ValueError("Empty response from OpenRouter (no message).")
|
||||
if message.refusal:
|
||||
raise ValueError(f"Model refused to respond: {message.refusal}")
|
||||
return message.content or ""
|
||||
|
||||
|
||||
class OpenRouterLLMNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="OpenRouterLLMNode",
|
||||
display_name="OpenRouter LLM",
|
||||
category="api node/text/OpenRouter",
|
||||
essentials_category="Text Generation",
|
||||
description=(
|
||||
"Generate text responses through OpenRouter. Routes to a curated set of popular "
|
||||
"models from xAI, DeepSeek, Qwen, Mistral, Z.AI (GLM), Moonshot (Kimi), and "
|
||||
"Perplexity Sonar."
|
||||
),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text input to the model.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=_build_model_options(),
|
||||
tooltip="The OpenRouter model used to generate the response.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for sampling. Set to 0 to omit. Most models treat this as a hint only.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Foundational instructions that dictate the model's behavior.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.String.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||
expr=_price_badge_jsonata(),
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
slug: str = model["model"]
|
||||
spec = _MODELS_BY_SLUG.get(slug)
|
||||
if spec is None:
|
||||
raise ValueError(f"Unknown OpenRouter model: {slug}")
|
||||
|
||||
reasoning_effort: str | None = model.get("reasoning_effort")
|
||||
search_context_size: str | None = model.get("search_context_size")
|
||||
|
||||
image_tensors: list[Input.Image] = [t for t in (model.get("images") or {}).values() if t is not None]
|
||||
if image_tensors and sum(get_number_of_images(t) for t in image_tensors) > spec.max_images:
|
||||
raise ValueError(f"Up to {spec.max_images} images are supported for {slug}.")
|
||||
video_inputs: list[Input.Video] = [v for v in (model.get("videos") or {}).values() if v is not None]
|
||||
if video_inputs and len(video_inputs) > spec.max_videos:
|
||||
raise ValueError(f"Up to {spec.max_videos} videos are supported for {slug}.")
|
||||
|
||||
media_blocks: list[OpenRouterContentBlock] = []
|
||||
if image_tensors:
|
||||
media_blocks.extend(await _build_image_blocks(cls, spec, image_tensors))
|
||||
if video_inputs:
|
||||
media_blocks.extend(await _build_video_blocks(cls, video_inputs))
|
||||
|
||||
request = _build_request(
|
||||
slug,
|
||||
system_prompt,
|
||||
prompt,
|
||||
media_blocks,
|
||||
seed=seed,
|
||||
reasoning_effort=reasoning_effort,
|
||||
search_context_size=search_context_size,
|
||||
)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=OPENROUTER_CHAT_ENDPOINT, method="POST"),
|
||||
response_model=OpenRouterChatResponse,
|
||||
data=request,
|
||||
price_extractor=_calculate_price,
|
||||
)
|
||||
return IO.NodeOutput(_extract_text(response))
|
||||
|
||||
|
||||
class OpenRouterExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [OpenRouterLLMNode]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> OpenRouterExtension:
|
||||
return OpenRouterExtension()
|
||||
@ -5,32 +5,37 @@ Rodin API docs: https://developer.hyper3d.ai/
|
||||
|
||||
"""
|
||||
|
||||
from inspect import cleandoc
|
||||
import folder_paths as comfy_paths
|
||||
import os
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from inspect import cleandoc
|
||||
from io import BytesIO
|
||||
from typing_extensions import override
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from PIL import Image
|
||||
from typing_extensions import override
|
||||
|
||||
import folder_paths as comfy_paths
|
||||
from comfy_api.latest import IO, ComfyExtension, Types
|
||||
from comfy_api_nodes.apis.rodin import (
|
||||
Rodin3DGenerateRequest,
|
||||
Rodin3DGenerateResponse,
|
||||
JobStatus,
|
||||
Rodin3DCheckStatusRequest,
|
||||
Rodin3DCheckStatusResponse,
|
||||
Rodin3DDownloadRequest,
|
||||
Rodin3DDownloadResponse,
|
||||
JobStatus,
|
||||
Rodin3DGen25Request,
|
||||
Rodin3DGenerateRequest,
|
||||
Rodin3DGenerateResponse,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
sync_op,
|
||||
poll_op,
|
||||
ApiEndpoint,
|
||||
download_url_to_bytesio,
|
||||
download_url_to_file_3d,
|
||||
poll_op,
|
||||
sync_op,
|
||||
validate_string,
|
||||
)
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
|
||||
|
||||
COMMON_PARAMETERS = [
|
||||
IO.Int.Input(
|
||||
@ -51,40 +56,30 @@ COMMON_PARAMETERS = [
|
||||
]
|
||||
|
||||
|
||||
def get_quality_mode(poly_count):
|
||||
polycount = poly_count.split("-")
|
||||
poly = polycount[1]
|
||||
count = polycount[0]
|
||||
if poly == "Triangle":
|
||||
mesh_mode = "Raw"
|
||||
elif poly == "Quad":
|
||||
mesh_mode = "Quad"
|
||||
else:
|
||||
mesh_mode = "Quad"
|
||||
|
||||
if count == "4K":
|
||||
quality_override = 4000
|
||||
elif count == "8K":
|
||||
quality_override = 8000
|
||||
elif count == "18K":
|
||||
quality_override = 18000
|
||||
elif count == "50K":
|
||||
quality_override = 50000
|
||||
elif count == "2K":
|
||||
quality_override = 2000
|
||||
elif count == "20K":
|
||||
quality_override = 20000
|
||||
elif count == "150K":
|
||||
quality_override = 150000
|
||||
elif count == "500K":
|
||||
quality_override = 500000
|
||||
else:
|
||||
quality_override = 18000
|
||||
|
||||
return mesh_mode, quality_override
|
||||
_QUALITY_MESH_OPTIONS: dict[str, tuple[str, int]] = {
|
||||
"4K-Quad": ("Quad", 4000),
|
||||
"8K-Quad": ("Quad", 8000),
|
||||
"18K-Quad": ("Quad", 18000),
|
||||
"50K-Quad": ("Quad", 50000),
|
||||
"200K-Quad": ("Quad", 200000),
|
||||
"2K-Triangle": ("Raw", 2000),
|
||||
"20K-Triangle": ("Raw", 20000),
|
||||
"150K-Triangle": ("Raw", 150000),
|
||||
"200K-Triangle": ("Raw", 200000),
|
||||
"500K-Triangle": ("Raw", 500000),
|
||||
"1M-Triangle": ("Raw", 1000000),
|
||||
}
|
||||
|
||||
|
||||
def tensor_to_filelike(tensor, max_pixels: int = 2048*2048):
|
||||
def get_quality_mode(poly_count: str) -> tuple[str, int]:
|
||||
"""Map a polygon-count preset like '18K-Quad' to (mesh_mode, quality_override).
|
||||
|
||||
Falls back to ('Quad', 18000) for unknown labels; legacy parity.
|
||||
"""
|
||||
return _QUALITY_MESH_OPTIONS.get(poly_count, ("Quad", 18000))
|
||||
|
||||
|
||||
def tensor_to_filelike(tensor, max_pixels: int = 2048 * 2048):
|
||||
"""
|
||||
Converts a PyTorch tensor to a file-like object.
|
||||
|
||||
@ -96,8 +91,8 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048*2048):
|
||||
- io.BytesIO: A file-like object containing the image data.
|
||||
"""
|
||||
array = tensor.cpu().numpy()
|
||||
array = (array * 255).astype('uint8')
|
||||
image = Image.fromarray(array, 'RGB')
|
||||
array = (array * 255).astype("uint8")
|
||||
image = Image.fromarray(array, "RGB")
|
||||
|
||||
original_width, original_height = image.size
|
||||
original_pixels = original_width * original_height
|
||||
@ -112,7 +107,7 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048*2048):
|
||||
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
|
||||
img_byte_arr = BytesIO()
|
||||
image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression
|
||||
image.save(img_byte_arr, format="PNG") # PNG is used for lossless compression
|
||||
img_byte_arr.seek(0)
|
||||
return img_byte_arr
|
||||
|
||||
@ -145,11 +140,9 @@ async def create_generate_task(
|
||||
TAPose=ta_pose,
|
||||
),
|
||||
files=[
|
||||
(
|
||||
"images",
|
||||
open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image)
|
||||
)
|
||||
for image in images if image is not None
|
||||
("images", open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image))
|
||||
for image in images
|
||||
if image is not None
|
||||
],
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
@ -177,6 +170,7 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
|
||||
return "DONE"
|
||||
return "Generating"
|
||||
|
||||
|
||||
def extract_progress(response: Rodin3DCheckStatusResponse) -> int | None:
|
||||
if not response.jobs:
|
||||
return None
|
||||
@ -214,7 +208,7 @@ async def download_files(url_list, task_uuid: str) -> tuple[str | None, Types.Fi
|
||||
model_file_path = None
|
||||
file_3d = None
|
||||
|
||||
for i in url_list.list:
|
||||
for i in url_list.items:
|
||||
file_path = os.path.join(save_path, i.name)
|
||||
if i.name.lower().endswith(".glb"):
|
||||
model_file_path = os.path.join(result_folder_name, i.name)
|
||||
@ -489,7 +483,16 @@ class Rodin3D_Gen2(IO.ComfyNode):
|
||||
IO.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True),
|
||||
IO.Combo.Input(
|
||||
"Polygon_count",
|
||||
options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"],
|
||||
options=[
|
||||
"4K-Quad",
|
||||
"8K-Quad",
|
||||
"18K-Quad",
|
||||
"50K-Quad",
|
||||
"2K-Triangle",
|
||||
"20K-Triangle",
|
||||
"150K-Triangle",
|
||||
"500K-Triangle",
|
||||
],
|
||||
default="500K-Triangle",
|
||||
optional=True,
|
||||
),
|
||||
@ -542,6 +545,566 @@ class Rodin3D_Gen2(IO.ComfyNode):
|
||||
return IO.NodeOutput(model_path, file_3d)
|
||||
|
||||
|
||||
def _rodin_multipart_parser(data: dict[str, Any]) -> aiohttp.FormData:
|
||||
"""Convert a Rodin request dict to an aiohttp form, fixing bool/list serialization.
|
||||
|
||||
Booleans --> "true"/"false". Lists --> one field per element.
|
||||
"""
|
||||
form = aiohttp.FormData(default_to_multipart=True)
|
||||
for key, value in data.items():
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, bool):
|
||||
form.add_field(key, "true" if value else "false")
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
form.add_field(key, str(item))
|
||||
elif isinstance(value, (bytes, bytearray)):
|
||||
form.add_field(key, value)
|
||||
else:
|
||||
form.add_field(key, str(value))
|
||||
return form
|
||||
|
||||
|
||||
async def _create_gen25_task(
|
||||
cls: type[IO.ComfyNode],
|
||||
request: Rodin3DGen25Request,
|
||||
images: list | None,
|
||||
) -> tuple[str, str]:
|
||||
"""Submit a Gen-2.5 generate job; returns (task_uuid, subscription_key)."""
|
||||
|
||||
if images is not None and len(images) > 5:
|
||||
raise ValueError("Rodin Gen-2.5 supports at most 5 input images.")
|
||||
|
||||
files = None
|
||||
if images:
|
||||
files = [
|
||||
(
|
||||
"images",
|
||||
open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image),
|
||||
)
|
||||
for image in images
|
||||
if image is not None
|
||||
]
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/rodin/api/v2/rodin", method="POST"),
|
||||
response_model=Rodin3DGenerateResponse,
|
||||
data=request,
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
multipart_parser=_rodin_multipart_parser,
|
||||
)
|
||||
|
||||
if not response.uuid or not response.jobs or not response.jobs.subscription_key:
|
||||
raise RuntimeError(f"Rodin Gen-2.5 submit failed: message={response.message!r}")
|
||||
return response.uuid, response.jobs.subscription_key
|
||||
|
||||
|
||||
_PREVIEWABLE_3D_EXTS = {".glb", ".obj", ".fbx", ".stl", ".gltf"}
|
||||
|
||||
|
||||
async def _download_gen25_files(
|
||||
download_list: Rodin3DDownloadResponse,
|
||||
task_uuid: str,
|
||||
geometry_file_format: str,
|
||||
) -> Types.File3D | None:
|
||||
"""Download every file in the list; return the File3D matching the chosen format."""
|
||||
|
||||
folder_name = f"Rodin3D_Gen25_{task_uuid}"
|
||||
save_dir = os.path.join(comfy_paths.get_output_directory(), folder_name)
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
target_ext = f".{geometry_file_format.lower().lstrip('.')}"
|
||||
file_3d: Types.File3D | None = None
|
||||
|
||||
for item in download_list.items:
|
||||
file_path = os.path.join(save_dir, item.name)
|
||||
ext = os.path.splitext(item.name.lower())[1]
|
||||
# Prefer the file matching the user's chosen format; fall back below.
|
||||
if file_3d is None and ext == target_ext and ext in _PREVIEWABLE_3D_EXTS:
|
||||
file_3d = await download_url_to_file_3d(item.url, target_ext.lstrip("."))
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_3d.get_bytes())
|
||||
continue
|
||||
await download_url_to_bytesio(item.url, file_path)
|
||||
|
||||
# If the chosen format wasn't found, surface any model file we did get.
|
||||
if file_3d is None:
|
||||
for item in download_list.items:
|
||||
ext = os.path.splitext(item.name.lower())[1]
|
||||
if ext in _PREVIEWABLE_3D_EXTS:
|
||||
file_3d = await download_url_to_file_3d(item.url, ext.lstrip("."))
|
||||
break
|
||||
return file_3d
|
||||
|
||||
|
||||
_MODE_REGULAR = "Regular"
|
||||
_MODE_FAST = "Fast"
|
||||
_MODE_EXTREME_HIGH = "Extreme-High"
|
||||
|
||||
_REGULAR_POLY_OPTIONS = [
|
||||
"Default",
|
||||
"4K-Quad",
|
||||
"8K-Quad",
|
||||
"18K-Quad",
|
||||
"50K-Quad",
|
||||
"2K-Triangle",
|
||||
"20K-Triangle",
|
||||
"150K-Triangle",
|
||||
"500K-Triangle",
|
||||
"1M-Triangle",
|
||||
]
|
||||
|
||||
_TEXTURE_MODE_OPTIONS = ["Default", "legacy", "extreme-low", "low", "medium", "high"]
|
||||
_GEOMETRY_FORMAT_OPTIONS = ["glb", "fbx", "obj", "stl"]
|
||||
_MATERIAL_OPTIONS = ["PBR", "Shaded", "All", "None"]
|
||||
|
||||
|
||||
def _build_mode_input(name: str = "mode") -> IO.DynamicCombo.Input:
|
||||
return IO.DynamicCombo.Input(
|
||||
name,
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
_MODE_REGULAR,
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"tier",
|
||||
options=["Gen-2.5-Low", "Gen-2.5-Medium", "Gen-2.5-High"],
|
||||
default="Gen-2.5-High",
|
||||
tooltip="Quality tier. Higher tiers produce higher-fidelity geometry.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"polygon_count",
|
||||
options=_REGULAR_POLY_OPTIONS,
|
||||
default="Default",
|
||||
tooltip="Preset face count. 'Default' uses the server's default for the selected tier.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"creative",
|
||||
default=False,
|
||||
tooltip="Creative mode (Medium/High only). Enhances generative robustness.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
_MODE_FAST,
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"tier",
|
||||
options=[
|
||||
"Gen-2.5-Extreme-Low",
|
||||
"Gen-2.5-Low",
|
||||
"Gen-2.5-Medium",
|
||||
"Gen-2.5-High",
|
||||
],
|
||||
default="Gen-2.5-Low",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"mesh_faces",
|
||||
default=20000,
|
||||
min=1000,
|
||||
max=20000,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Mesh face count (1K-20K in Fast mode).",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
_MODE_EXTREME_HIGH,
|
||||
[
|
||||
IO.Combo.Input("mesh_mode", options=["Raw", "Quad"], default="Raw"),
|
||||
IO.Int.Input(
|
||||
"mesh_faces",
|
||||
default=1000000,
|
||||
min=20000,
|
||||
max=2000000,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip=(
|
||||
"Mesh face count. Raw mode: 20K-2M. "
|
||||
"Quad mode: keep under 200K (upstream may reject higher values)."
|
||||
),
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"is_micro",
|
||||
default=False,
|
||||
tooltip="Enable micro detail (Extreme-High only).",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"creative",
|
||||
default=False,
|
||||
tooltip="Creative mode. Enhances generative robustness.",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip=(
|
||||
"Generation mode. Regular = balanced. Fast = 1K-20K faces for rapid prototyping. "
|
||||
"Extreme-High = 20K-2M faces with optional micro details."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _build_common_inputs(*, include_image_only: bool) -> list:
|
||||
inputs: list = [
|
||||
IO.Combo.Input("material", options=_MATERIAL_OPTIONS, default="Shaded"),
|
||||
IO.Combo.Input("geometry_file_format", options=_GEOMETRY_FORMAT_OPTIONS, default="glb"),
|
||||
IO.Combo.Input(
|
||||
"texture_mode",
|
||||
options=_TEXTURE_MODE_OPTIONS,
|
||||
default="Default",
|
||||
optional=True,
|
||||
tooltip="Texture quality preset. 'Default' uses the server's default for the selected tier.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=65535,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"TAPose", default=False, optional=True, advanced=True, tooltip="T/A pose for human-like models."
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"hd_texture", default=False, optional=True, advanced=True, tooltip="High-quality texture enhancement."
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"texture_delight",
|
||||
default=False,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Remove baked lighting from textures.",
|
||||
),
|
||||
]
|
||||
if include_image_only:
|
||||
inputs.append(
|
||||
IO.Boolean.Input(
|
||||
"use_original_alpha",
|
||||
default=False,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Preserve image transparency.",
|
||||
)
|
||||
)
|
||||
inputs.extend(
|
||||
[
|
||||
IO.Boolean.Input(
|
||||
"addon_highpack",
|
||||
default=False,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="HighPack addon: 4K textures and ~16x faces in Quad mode.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"bbox_width",
|
||||
default=0,
|
||||
min=0,
|
||||
max=300,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Bounding-box width (Y axis). Set to 0 with the others to skip bbox.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"bbox_height",
|
||||
default=0,
|
||||
min=0,
|
||||
max=300,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Bounding-box height (Z axis).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"bbox_length",
|
||||
default=0,
|
||||
min=0,
|
||||
max=300,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Bounding-box length (X axis).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"height_cm",
|
||||
default=0,
|
||||
min=0,
|
||||
max=10000,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Approximate model height in centimeters (0 to skip).",
|
||||
),
|
||||
]
|
||||
)
|
||||
return inputs
|
||||
|
||||
|
||||
_PRICE_EXPR = """
|
||||
(
|
||||
$baseCredits := widgets.mode = "extreme-high" ? 1.0 : 0.5;
|
||||
$addonCredits := widgets.addon_highpack ? 1.0 : 0.0;
|
||||
$total := ($baseCredits * 1.5) + ($addonCredits * 0.8);
|
||||
{"type":"usd","usd": $total}
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
def _resolve_mode_params(mode_input: dict) -> dict:
|
||||
"""Translate the DynamicCombo `mode` payload into Gen-2.5 request fields.
|
||||
|
||||
Returns a dict with: tier, quality_override, mesh_mode, geometry_instruct_mode, is_micro.
|
||||
Missing keys mean "do not send" (so we don't override server defaults).
|
||||
"""
|
||||
selected = mode_input["mode"]
|
||||
out: dict = {}
|
||||
|
||||
if selected == _MODE_REGULAR:
|
||||
out["tier"] = mode_input["tier"]
|
||||
polygon = mode_input.get("polygon_count", "Default")
|
||||
if polygon != "Default":
|
||||
mesh_mode, faces = get_quality_mode(polygon)
|
||||
out["mesh_mode"] = mesh_mode
|
||||
out["quality_override"] = faces
|
||||
if mode_input.get("creative"):
|
||||
out["geometry_instruct_mode"] = "creative"
|
||||
|
||||
elif selected == _MODE_FAST:
|
||||
out["tier"] = mode_input["tier"]
|
||||
out["mesh_mode"] = "Raw"
|
||||
out["quality_override"] = int(mode_input["mesh_faces"])
|
||||
|
||||
elif selected == _MODE_EXTREME_HIGH:
|
||||
out["tier"] = "Gen-2.5-Extreme-High"
|
||||
out["mesh_mode"] = mode_input["mesh_mode"]
|
||||
out["quality_override"] = int(mode_input["mesh_faces"])
|
||||
if mode_input.get("is_micro"):
|
||||
out["is_micro"] = True
|
||||
if mode_input.get("creative"):
|
||||
out["geometry_instruct_mode"] = "creative"
|
||||
return out
|
||||
|
||||
|
||||
def _build_request(
|
||||
*,
|
||||
mode_input: dict,
|
||||
material: str,
|
||||
geometry_file_format: str,
|
||||
texture_mode: str,
|
||||
seed: int,
|
||||
TAPose: bool,
|
||||
hd_texture: bool,
|
||||
texture_delight: bool,
|
||||
addon_highpack: bool,
|
||||
bbox_width: int,
|
||||
bbox_height: int,
|
||||
bbox_length: int,
|
||||
height_cm: int,
|
||||
prompt: str | None = None,
|
||||
use_original_alpha: bool = False,
|
||||
) -> Rodin3DGen25Request:
|
||||
mode_params = _resolve_mode_params(mode_input)
|
||||
|
||||
bbox = None
|
||||
if bbox_width and bbox_height and bbox_length:
|
||||
bbox = [bbox_width, bbox_height, bbox_length]
|
||||
|
||||
return Rodin3DGen25Request(
|
||||
tier=mode_params["tier"],
|
||||
prompt=prompt or None,
|
||||
seed=seed,
|
||||
material=material,
|
||||
geometry_file_format=geometry_file_format,
|
||||
texture_mode=None if texture_mode == "Default" else texture_mode,
|
||||
mesh_mode=mode_params.get("mesh_mode"),
|
||||
quality_override=mode_params.get("quality_override"),
|
||||
geometry_instruct_mode=mode_params.get("geometry_instruct_mode"),
|
||||
bbox_condition=bbox,
|
||||
height=height_cm or None,
|
||||
TAPose=TAPose or None,
|
||||
hd_texture=hd_texture or None,
|
||||
texture_delight=texture_delight or None,
|
||||
is_micro=mode_params.get("is_micro"),
|
||||
use_original_alpha=use_original_alpha or None,
|
||||
addons=["HighPack"] if addon_highpack else None,
|
||||
)
|
||||
|
||||
|
||||
class Rodin3D_Gen25_Image(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Gen25_Image",
|
||||
display_name="Rodin 3D Gen-2.5 - Image to 3D",
|
||||
category="api node/3d/Rodin",
|
||||
description=(
|
||||
"Generate a 3D model from 1-5 reference images via Rodin Gen-2.5. "
|
||||
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."
|
||||
),
|
||||
inputs=[
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplatePrefix(IO.Image.Input("image"), prefix="image", min=1, max=5),
|
||||
tooltip="1-5 images. The first image is used for materials when multi-view.",
|
||||
),
|
||||
_build_mode_input(),
|
||||
*_build_common_inputs(include_image_only=True),
|
||||
],
|
||||
outputs=[IO.File3DAny.Output(display_name="model_file")],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["mode", "addon_highpack"]),
|
||||
expr=_PRICE_EXPR,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
images: IO.Autogrow.Type,
|
||||
mode: dict,
|
||||
material: str,
|
||||
geometry_file_format: str,
|
||||
texture_mode: str,
|
||||
seed: int,
|
||||
TAPose: bool,
|
||||
hd_texture: bool,
|
||||
texture_delight: bool,
|
||||
use_original_alpha: bool,
|
||||
addon_highpack: bool,
|
||||
bbox_width: int,
|
||||
bbox_height: int,
|
||||
bbox_length: int,
|
||||
height_cm: int,
|
||||
) -> IO.NodeOutput:
|
||||
image_tensors = [img for img in images.values() if img is not None]
|
||||
if not image_tensors:
|
||||
raise ValueError("Rodin Gen-2.5 Image-to-3D requires at least one image.")
|
||||
|
||||
# Flatten multi-image tensors into individual frames; the API accepts each as a separate part.
|
||||
flat_images: list = []
|
||||
for tensor in image_tensors:
|
||||
if hasattr(tensor, "shape") and len(tensor.shape) == 4:
|
||||
for i in range(tensor.shape[0]):
|
||||
flat_images.append(tensor[i])
|
||||
else:
|
||||
flat_images.append(tensor)
|
||||
|
||||
if len(flat_images) > 5:
|
||||
raise ValueError(f"Rodin Gen-2.5 accepts at most 5 images; received {len(flat_images)}.")
|
||||
|
||||
request = _build_request(
|
||||
mode_input=mode,
|
||||
material=material,
|
||||
geometry_file_format=geometry_file_format,
|
||||
texture_mode=texture_mode,
|
||||
seed=seed,
|
||||
TAPose=TAPose,
|
||||
hd_texture=hd_texture,
|
||||
texture_delight=texture_delight,
|
||||
addon_highpack=addon_highpack,
|
||||
bbox_width=bbox_width,
|
||||
bbox_height=bbox_height,
|
||||
bbox_length=bbox_length,
|
||||
height_cm=height_cm,
|
||||
prompt=None,
|
||||
use_original_alpha=use_original_alpha,
|
||||
)
|
||||
|
||||
task_uuid, subscription_key = await _create_gen25_task(cls, request, flat_images)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
file_3d = await _download_gen25_files(download_list, task_uuid, geometry_file_format)
|
||||
return IO.NodeOutput(file_3d)
|
||||
|
||||
|
||||
class Rodin3D_Gen25_Text(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Gen25_Text",
|
||||
display_name="Rodin 3D Gen-2.5 - Text to 3D",
|
||||
category="api node/3d/Rodin",
|
||||
description=(
|
||||
"Generate a 3D model from a text prompt via Rodin Gen-2.5. "
|
||||
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."
|
||||
),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt for the 3D model.",
|
||||
),
|
||||
_build_mode_input(),
|
||||
*_build_common_inputs(include_image_only=False),
|
||||
],
|
||||
outputs=[IO.File3DAny.Output(display_name="model_file")],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["mode", "addon_highpack"]),
|
||||
expr=_PRICE_EXPR,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
mode: dict,
|
||||
material: str,
|
||||
geometry_file_format: str,
|
||||
texture_mode: str,
|
||||
seed: int,
|
||||
TAPose: bool,
|
||||
hd_texture: bool,
|
||||
texture_delight: bool,
|
||||
addon_highpack: bool,
|
||||
bbox_width: int,
|
||||
bbox_height: int,
|
||||
bbox_length: int,
|
||||
height_cm: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, field_name="prompt", min_length=1, max_length=2500)
|
||||
request = _build_request(
|
||||
mode_input=mode,
|
||||
material=material,
|
||||
geometry_file_format=geometry_file_format,
|
||||
texture_mode=texture_mode,
|
||||
seed=seed,
|
||||
TAPose=TAPose,
|
||||
hd_texture=hd_texture,
|
||||
texture_delight=texture_delight,
|
||||
addon_highpack=addon_highpack,
|
||||
bbox_width=bbox_width,
|
||||
bbox_height=bbox_height,
|
||||
bbox_length=bbox_length,
|
||||
height_cm=height_cm,
|
||||
prompt=prompt,
|
||||
)
|
||||
task_uuid, subscription_key = await _create_gen25_task(cls, request, images=None)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
file_3d = await _download_gen25_files(download_list, task_uuid, geometry_file_format)
|
||||
return IO.NodeOutput(file_3d)
|
||||
|
||||
|
||||
class Rodin3DExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -551,6 +1114,8 @@ class Rodin3DExtension(ComfyExtension):
|
||||
Rodin3D_Smooth,
|
||||
Rodin3D_Sketch,
|
||||
Rodin3D_Gen2,
|
||||
Rodin3D_Gen25_Image,
|
||||
Rodin3D_Gen25_Text,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -16,16 +16,17 @@ from .conversions import (
|
||||
convert_mask_to_image,
|
||||
downscale_image_tensor,
|
||||
downscale_image_tensor_by_max_side,
|
||||
downscale_video_to_max_pixels,
|
||||
image_tensor_pair_to_batch,
|
||||
pil_to_bytesio,
|
||||
resize_mask_to_image,
|
||||
resize_video_to_pixel_budget,
|
||||
tensor_to_base64_string,
|
||||
tensor_to_bytesio,
|
||||
tensor_to_pil,
|
||||
text_filepath_to_base64_string,
|
||||
text_filepath_to_data_uri,
|
||||
trim_video,
|
||||
upscale_video_to_min_pixels,
|
||||
video_to_base64_string,
|
||||
)
|
||||
from .download_helpers import (
|
||||
@ -88,16 +89,17 @@ __all__ = [
|
||||
"convert_mask_to_image",
|
||||
"downscale_image_tensor",
|
||||
"downscale_image_tensor_by_max_side",
|
||||
"downscale_video_to_max_pixels",
|
||||
"image_tensor_pair_to_batch",
|
||||
"pil_to_bytesio",
|
||||
"resize_mask_to_image",
|
||||
"resize_video_to_pixel_budget",
|
||||
"tensor_to_base64_string",
|
||||
"tensor_to_bytesio",
|
||||
"tensor_to_pil",
|
||||
"text_filepath_to_base64_string",
|
||||
"text_filepath_to_data_uri",
|
||||
"trim_video",
|
||||
"upscale_video_to_min_pixels",
|
||||
"video_to_base64_string",
|
||||
# Validation utilities
|
||||
"get_image_dimensions",
|
||||
|
||||
@ -415,14 +415,48 @@ def trim_video(video: Input.Video, duration_sec: float) -> Input.Video:
|
||||
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
|
||||
|
||||
|
||||
def resize_video_to_pixel_budget(video: Input.Video, total_pixels: int) -> Input.Video:
|
||||
"""Downscale a video to fit within ``total_pixels`` (w * h), preserving aspect ratio.
|
||||
def downscale_video_to_max_pixels(video: Input.Video, max_pixels: int) -> Input.Video:
|
||||
"""Downscale a video to fit within ``max_pixels`` (w * h), preserving aspect ratio.
|
||||
|
||||
Returns the original video object untouched when it already fits. Preserves frame rate, duration, and audio.
|
||||
Aspect ratio is preserved up to a fraction of a percent (even-dim rounding).
|
||||
"""
|
||||
src_w, src_h = video.get_dimensions()
|
||||
scale_dims = _compute_downscale_dims(src_w, src_h, total_pixels)
|
||||
scale_dims = _compute_downscale_dims(src_w, src_h, max_pixels)
|
||||
if scale_dims is None:
|
||||
return video
|
||||
return _apply_video_scale(video, scale_dims)
|
||||
|
||||
|
||||
def _compute_upscale_dims(src_w: int, src_h: int, total_pixels: int) -> tuple[int, int] | None:
|
||||
"""Return upscaled (w, h) with even dims meeting at least ``total_pixels``, or None if already large enough.
|
||||
|
||||
Source aspect ratio is preserved; output may drift by a fraction of a percent because both dimensions
|
||||
are rounded up to even values (many codecs require divisible-by-2). The result is guaranteed to be at
|
||||
least ``total_pixels``.
|
||||
"""
|
||||
pixels = src_w * src_h
|
||||
if pixels >= total_pixels:
|
||||
return None
|
||||
scale = math.sqrt(total_pixels / pixels)
|
||||
new_w = math.ceil(src_w * scale)
|
||||
new_h = math.ceil(src_h * scale)
|
||||
if new_w % 2:
|
||||
new_w += 1
|
||||
if new_h % 2:
|
||||
new_h += 1
|
||||
return new_w, new_h
|
||||
|
||||
|
||||
def upscale_video_to_min_pixels(video: Input.Video, min_pixels: int) -> Input.Video:
|
||||
"""Upscale a video to meet at least ``min_pixels`` (w * h), preserving aspect ratio.
|
||||
|
||||
Returns the original video object untouched when it already meets the minimum. Preserves frame rate,
|
||||
duration, and audio. Aspect ratio is preserved up to a fraction of a percent (even-dim rounding).
|
||||
Note: upscaling a low-resolution source does not add real detail; downstream model quality may suffer.
|
||||
"""
|
||||
src_w, src_h = video.get_dimensions()
|
||||
scale_dims = _compute_upscale_dims(src_w, src_h, min_pixels)
|
||||
if scale_dims is None:
|
||||
return video
|
||||
return _apply_video_scale(video, scale_dims)
|
||||
|
||||
@ -543,7 +543,7 @@ class AudioConcat(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="AudioConcat",
|
||||
search_aliases=["join audio", "combine audio", "append audio"],
|
||||
display_name="Audio Concat",
|
||||
display_name="Concatenate Audio",
|
||||
description="Concatenates the audio1 to audio2 in the specified direction.",
|
||||
category="audio",
|
||||
inputs=[
|
||||
@ -597,7 +597,7 @@ class AudioMerge(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="AudioMerge",
|
||||
search_aliases=["mix audio", "overlay audio", "layer audio"],
|
||||
display_name="Audio Merge",
|
||||
display_name="Merge Audio",
|
||||
description="Combine two audio tracks by overlaying their waveforms.",
|
||||
category="audio",
|
||||
inputs=[
|
||||
@ -667,8 +667,9 @@ class AudioAdjustVolume(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="AudioAdjustVolume",
|
||||
search_aliases=["audio gain", "loudness", "audio level"],
|
||||
display_name="Audio Adjust Volume",
|
||||
display_name="Adjust Audio Volume",
|
||||
category="audio",
|
||||
description="Adjust the volume of the audio by a specified amount in decibels (dB).",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.Int.Input(
|
||||
|
||||
@ -47,8 +47,10 @@ class LoadImageDataSetFromFolderNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadImageDataSetFromFolder",
|
||||
display_name="Load Image Dataset from Folder",
|
||||
category="dataset",
|
||||
search_aliases=["load folder", "load from folder", "load dataset", "load images", "import dataset"],
|
||||
display_name="Load Image (from Folder)",
|
||||
category="image",
|
||||
description="Load a dataset of images from a specified folder and return a list of images. Supported formats: PNG, JPG, JPEG, WEBP.",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
@ -84,14 +86,16 @@ class LoadImageTextDataSetFromFolderNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadImageTextDataSetFromFolder",
|
||||
display_name="Load Image and Text Dataset from Folder",
|
||||
category="dataset",
|
||||
search_aliases=["load folder", "load from folder", "load dataset", "load images", "import dataset"],
|
||||
display_name="Load Image-Text (from Folder)",
|
||||
category="image",
|
||||
description="Load a dataset of pairs of images and text captions from a specified folder and return them as a list. Supported formats: PNG, JPG, JPEG, WEBP.",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"folder",
|
||||
options=folder_paths.get_input_subfolders(),
|
||||
tooltip="The folder to load images from.",
|
||||
tooltip="The folder to load images and text captions from.",
|
||||
)
|
||||
],
|
||||
outputs=[
|
||||
@ -206,8 +210,10 @@ class SaveImageDataSetToFolderNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveImageDataSetToFolder",
|
||||
display_name="Save Image Dataset to Folder",
|
||||
category="dataset",
|
||||
search_aliases=["save folder", "save to folder", "save dataset", "save images", "export dataset"],
|
||||
display_name="Save Image (to Folder) (DEPRECATED)",
|
||||
category="image",
|
||||
description="Save a dataset of images to a specified folder. Supported formats: PNG.",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
is_input_list=True, # Receive images as list
|
||||
@ -226,6 +232,7 @@ class SaveImageDataSetToFolderNode(io.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[],
|
||||
is_deprecated=True, # This node is redundant and superseded by existing Save Image nodes where the target folder can be specified in the filename_prefix
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -246,14 +253,20 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveImageTextDataSetToFolder",
|
||||
display_name="Save Image and Text Dataset to Folder",
|
||||
category="dataset",
|
||||
search_aliases=["save folder", "save to folder", "save dataset", "save images", "save text", "export dataset"],
|
||||
display_name="Save Image-Text (to Folder)",
|
||||
category="image",
|
||||
description="Save a dataset of pairs of images and text captions to a specified folder. Images are saved as PNG files and captions are saved as TXT files with the same filename_prefix.",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
is_input_list=True, # Receive both images and texts as lists
|
||||
inputs=[
|
||||
io.Image.Input("images", tooltip="List of images to save."),
|
||||
io.String.Input("texts", tooltip="List of text captions to save."),
|
||||
io.String.Input("texts",
|
||||
optional=True,
|
||||
force_input=True,
|
||||
tooltip="List of text captions to save."
|
||||
),
|
||||
io.String.Input(
|
||||
"folder_name",
|
||||
default="dataset",
|
||||
@ -270,7 +283,7 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, texts, folder_name, filename_prefix):
|
||||
def execute(cls, images, folder_name, filename_prefix, texts=None):
|
||||
# Extract scalar values
|
||||
folder_name = folder_name[0]
|
||||
filename_prefix = filename_prefix[0]
|
||||
@ -279,11 +292,12 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
|
||||
saved_files = save_images_to_folder(images, output_dir, filename_prefix)
|
||||
|
||||
# Save captions
|
||||
for idx, (filename, caption) in enumerate(zip(saved_files, texts)):
|
||||
caption_filename = filename.replace(".png", ".txt")
|
||||
caption_path = os.path.join(output_dir, caption_filename)
|
||||
with open(caption_path, "w", encoding="utf-8") as f:
|
||||
f.write(caption)
|
||||
if texts:
|
||||
for idx, (filename, caption) in enumerate(zip(saved_files, texts)):
|
||||
caption_filename = filename.replace(".png", ".txt")
|
||||
caption_path = os.path.join(output_dir, caption_filename)
|
||||
with open(caption_path, "w", encoding="utf-8") as f:
|
||||
f.write(caption)
|
||||
|
||||
logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.")
|
||||
return io.NodeOutput()
|
||||
@ -314,11 +328,13 @@ class ImageProcessingNode(io.ComfyNode):
|
||||
|
||||
Child classes should set:
|
||||
node_id: Unique node identifier (required)
|
||||
search_aliases: List of search aliases (optional)
|
||||
display_name: Display name (optional, defaults to node_id)
|
||||
description: Node description (optional)
|
||||
extra_inputs: List of additional io.Input objects beyond "images" (optional)
|
||||
is_group_process: None (auto-detect), True (group), or False (individual) (optional)
|
||||
is_output_list: True (list output) or False (single output) (optional, default True)
|
||||
is_deprecated: True if the node is deprecated (optional, default False)
|
||||
|
||||
Child classes must implement ONE of:
|
||||
_process(cls, image, **kwargs) -> tensor (for single-item processing)
|
||||
@ -326,12 +342,13 @@ class ImageProcessingNode(io.ComfyNode):
|
||||
"""
|
||||
|
||||
node_id = None
|
||||
search_aliases = []
|
||||
display_name = None
|
||||
description = None
|
||||
extra_inputs = []
|
||||
is_group_process = None # None = auto-detect, True/False = explicit
|
||||
is_output_list = None # None = auto-detect based on processing mode
|
||||
|
||||
is_deprecated = False
|
||||
@classmethod
|
||||
def _detect_processing_mode(cls):
|
||||
"""Detect whether this node uses group or individual processing.
|
||||
@ -402,8 +419,10 @@ class ImageProcessingNode(io.ComfyNode):
|
||||
|
||||
return io.Schema(
|
||||
node_id=cls.node_id,
|
||||
search_aliases=cls.search_aliases,
|
||||
display_name=cls.display_name or cls.node_id,
|
||||
category="dataset/image",
|
||||
category=cls.category,
|
||||
description=cls.description,
|
||||
is_experimental=True,
|
||||
is_input_list=is_group, # True for group, False for individual
|
||||
inputs=inputs,
|
||||
@ -472,11 +491,13 @@ class TextProcessingNode(io.ComfyNode):
|
||||
|
||||
Child classes should set:
|
||||
node_id: Unique node identifier (required)
|
||||
search_aliases: List of search aliases (optional)
|
||||
display_name: Display name (optional, defaults to node_id)
|
||||
description: Node description (optional)
|
||||
extra_inputs: List of additional io.Input objects beyond "texts" (optional)
|
||||
is_group_process: None (auto-detect), True (group), or False (individual) (optional)
|
||||
is_output_list: True (list output) or False (single output) (optional, default True)
|
||||
is_deprecated: True if the node is deprecated (optional, default False)
|
||||
|
||||
Child classes must implement ONE of:
|
||||
_process(cls, text, **kwargs) -> str (for single-item processing)
|
||||
@ -484,12 +505,13 @@ class TextProcessingNode(io.ComfyNode):
|
||||
"""
|
||||
|
||||
node_id = None
|
||||
search_aliases = []
|
||||
display_name = None
|
||||
description = None
|
||||
extra_inputs = []
|
||||
is_group_process = None # None = auto-detect, True/False = explicit
|
||||
is_output_list = None # None = auto-detect based on processing mode
|
||||
|
||||
is_deprecated = False
|
||||
@classmethod
|
||||
def _detect_processing_mode(cls):
|
||||
"""Detect whether this node uses group or individual processing.
|
||||
@ -627,15 +649,17 @@ class TextProcessingNode(io.ComfyNode):
|
||||
|
||||
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
|
||||
node_id = "ResizeImagesByShorterEdge"
|
||||
display_name = "Resize Images by Shorter Edge"
|
||||
description = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio."
|
||||
display_name = "Resize Images by Shorter Edge (DEPRECATED)"
|
||||
category = "image/transform"
|
||||
description = "Resize images so that the shorter edge matches the specified dimension while preserving aspect ratio."
|
||||
is_deprecated = True # This node is superseded by Resize Image/Mask with resize_type = scale shorter dimension
|
||||
extra_inputs = [
|
||||
io.Int.Input(
|
||||
"shorter_edge",
|
||||
default=512,
|
||||
min=1,
|
||||
max=8192,
|
||||
tooltip="Target length for the shorter edge.",
|
||||
tooltip="Target dimension for the shorter edge.",
|
||||
),
|
||||
]
|
||||
|
||||
@ -655,15 +679,17 @@ class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
|
||||
|
||||
class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
|
||||
node_id = "ResizeImagesByLongerEdge"
|
||||
display_name = "Resize Images by Longer Edge"
|
||||
description = "Resize images so that the longer edge matches the specified length while preserving aspect ratio."
|
||||
display_name = "Resize Images by Longer Edge (DEPRECATED)"
|
||||
category = "image/transform"
|
||||
description = "Resize images so that the longer edge matches the specified dimension while preserving aspect ratio."
|
||||
is_deprecated = True # This node is superseded by Resize Image/Mask with resize_type = scale longer dimension
|
||||
extra_inputs = [
|
||||
io.Int.Input(
|
||||
"longer_edge",
|
||||
default=1024,
|
||||
min=1,
|
||||
max=8192,
|
||||
tooltip="Target length for the longer edge.",
|
||||
tooltip="Target dimension for the longer edge.",
|
||||
),
|
||||
]
|
||||
|
||||
@ -686,8 +712,10 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
|
||||
|
||||
class CenterCropImagesNode(ImageProcessingNode):
|
||||
node_id = "CenterCropImages"
|
||||
display_name = "Center Crop Images"
|
||||
description = "Center crop all images to the specified dimensions."
|
||||
search_aliases=["crop", "cut", "trim"]
|
||||
display_name="Crop Image (Center)"
|
||||
category="image/transform"
|
||||
description = "Center crop an image to the specified dimensions."
|
||||
extra_inputs = [
|
||||
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
|
||||
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
|
||||
@ -706,10 +734,11 @@ class CenterCropImagesNode(ImageProcessingNode):
|
||||
|
||||
class RandomCropImagesNode(ImageProcessingNode):
|
||||
node_id = "RandomCropImages"
|
||||
display_name = "Random Crop Images"
|
||||
description = (
|
||||
"Randomly crop all images to the specified dimensions (for data augmentation)."
|
||||
)
|
||||
search_aliases=["crop", "cut", "trim"]
|
||||
display_name = "Crop Image (Random)"
|
||||
category="image/transform"
|
||||
description = "Randomly crop an image to the specified dimensions."
|
||||
|
||||
extra_inputs = [
|
||||
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
|
||||
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
|
||||
@ -734,7 +763,9 @@ class RandomCropImagesNode(ImageProcessingNode):
|
||||
|
||||
class NormalizeImagesNode(ImageProcessingNode):
|
||||
node_id = "NormalizeImages"
|
||||
display_name = "Normalize Images"
|
||||
search_aliases=["normalize", "normalize colors"]
|
||||
display_name = "Normalize Image Colors"
|
||||
category = "image/color"
|
||||
description = "Normalize images using mean and standard deviation."
|
||||
extra_inputs = [
|
||||
io.Float.Input(
|
||||
@ -762,8 +793,10 @@ class NormalizeImagesNode(ImageProcessingNode):
|
||||
|
||||
class AdjustBrightnessNode(ImageProcessingNode):
|
||||
node_id = "AdjustBrightness"
|
||||
search_aliases=["brightness"]
|
||||
display_name = "Adjust Brightness"
|
||||
description = "Adjust brightness of all images."
|
||||
category="image/adjustments"
|
||||
description = "Adjust the brightness of an image."
|
||||
extra_inputs = [
|
||||
io.Float.Input(
|
||||
"factor",
|
||||
@ -781,8 +814,10 @@ class AdjustBrightnessNode(ImageProcessingNode):
|
||||
|
||||
class AdjustContrastNode(ImageProcessingNode):
|
||||
node_id = "AdjustContrast"
|
||||
search_aliases=["contrast"]
|
||||
display_name = "Adjust Contrast"
|
||||
description = "Adjust contrast of all images."
|
||||
category="image/adjustments"
|
||||
description = "Adjust the contrast of an image."
|
||||
extra_inputs = [
|
||||
io.Float.Input(
|
||||
"factor",
|
||||
@ -800,8 +835,10 @@ class AdjustContrastNode(ImageProcessingNode):
|
||||
|
||||
class ShuffleDatasetNode(ImageProcessingNode):
|
||||
node_id = "ShuffleDataset"
|
||||
display_name = "Shuffle Image Dataset"
|
||||
description = "Randomly shuffle the order of images in the dataset."
|
||||
search_aliases=["shuffle", "randomize", "mix"]
|
||||
display_name = "Shuffle Images List"
|
||||
category = "image/batch"
|
||||
description = "Randomly shuffle the order of images in a list."
|
||||
is_group_process = True # Requires full list to shuffle
|
||||
extra_inputs = [
|
||||
io.Int.Input(
|
||||
@ -823,13 +860,15 @@ class ShuffleImageTextDatasetNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ShuffleImageTextDataset",
|
||||
display_name="Shuffle Image-Text Dataset",
|
||||
category="dataset/image",
|
||||
search_aliases=["shuffle", "randomize", "mix"],
|
||||
display_name = "Shuffle Pairs of Image-Text",
|
||||
category = "image/batch",
|
||||
description = "Randomly shuffle the order of pairs of image-text in a list.",
|
||||
is_experimental=True,
|
||||
is_input_list=True,
|
||||
inputs=[
|
||||
io.Image.Input("images", tooltip="List of images to shuffle."),
|
||||
io.String.Input("texts", tooltip="List of texts to shuffle."),
|
||||
io.String.Input("texts", tooltip="List of texts to shuffle.", force_input=True),
|
||||
io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
@ -865,8 +904,11 @@ class ShuffleImageTextDatasetNode(io.ComfyNode):
|
||||
|
||||
class TextToLowercaseNode(TextProcessingNode):
|
||||
node_id = "TextToLowercase"
|
||||
display_name = "Text to Lowercase"
|
||||
description = "Convert all texts to lowercase."
|
||||
search_aliases=["lowercase"]
|
||||
display_name = "Convert Text to Lowercase (DEPRECATED)"
|
||||
category = "text"
|
||||
description = "Convert text to lowercase."
|
||||
is_deprecated = True # This node is superseded by the Convert Text Case node
|
||||
|
||||
@classmethod
|
||||
def _process(cls, text):
|
||||
@ -875,8 +917,11 @@ class TextToLowercaseNode(TextProcessingNode):
|
||||
|
||||
class TextToUppercaseNode(TextProcessingNode):
|
||||
node_id = "TextToUppercase"
|
||||
display_name = "Text to Uppercase"
|
||||
description = "Convert all texts to uppercase."
|
||||
search_aliases=["uppercase"]
|
||||
display_name = "Convert Text to Uppercase (DEPRECATED)"
|
||||
category = "text"
|
||||
description = "Convert text to uppercase."
|
||||
is_deprecated = True # This node is superseded by the Convert Text Case node
|
||||
|
||||
@classmethod
|
||||
def _process(cls, text):
|
||||
@ -885,8 +930,10 @@ class TextToUppercaseNode(TextProcessingNode):
|
||||
|
||||
class TruncateTextNode(TextProcessingNode):
|
||||
node_id = "TruncateText"
|
||||
search_aliases=["truncate", "cut", "shorten"]
|
||||
display_name = "Truncate Text"
|
||||
description = "Truncate all texts to a maximum length."
|
||||
category = "text"
|
||||
description = "Truncate text to a maximum length."
|
||||
extra_inputs = [
|
||||
io.Int.Input(
|
||||
"max_length", default=77, min=1, max=10000, tooltip="Maximum text length."
|
||||
@ -900,8 +947,10 @@ class TruncateTextNode(TextProcessingNode):
|
||||
|
||||
class AddTextPrefixNode(TextProcessingNode):
|
||||
node_id = "AddTextPrefix"
|
||||
display_name = "Add Text Prefix"
|
||||
display_name = "Add Text Prefix (DEPRECATED)"
|
||||
category = "text"
|
||||
description = "Add a prefix to all texts."
|
||||
is_deprecated = True # This node is superseded by the Concatenate Text node
|
||||
extra_inputs = [
|
||||
io.String.Input("prefix", default="", tooltip="Prefix to add."),
|
||||
]
|
||||
@ -913,8 +962,10 @@ class AddTextPrefixNode(TextProcessingNode):
|
||||
|
||||
class AddTextSuffixNode(TextProcessingNode):
|
||||
node_id = "AddTextSuffix"
|
||||
display_name = "Add Text Suffix"
|
||||
display_name = "Add Text Suffix (DEPRECATED)"
|
||||
category = "text"
|
||||
description = "Add a suffix to all texts."
|
||||
is_deprecated = True # This node is superseded by the Concatenate Text node
|
||||
extra_inputs = [
|
||||
io.String.Input("suffix", default="", tooltip="Suffix to add."),
|
||||
]
|
||||
@ -926,8 +977,10 @@ class AddTextSuffixNode(TextProcessingNode):
|
||||
|
||||
class ReplaceTextNode(TextProcessingNode):
|
||||
node_id = "ReplaceText"
|
||||
display_name = "Replace Text"
|
||||
display_name = "Replace Text (DEPRECATED)"
|
||||
category = "text"
|
||||
description = "Replace text in all texts."
|
||||
is_deprecated = True # This node is superseded by the other Replace Text node
|
||||
extra_inputs = [
|
||||
io.String.Input("find", default="", tooltip="Text to find."),
|
||||
io.String.Input("replace", default="", tooltip="Text to replace with."),
|
||||
@ -940,8 +993,10 @@ class ReplaceTextNode(TextProcessingNode):
|
||||
|
||||
class StripWhitespaceNode(TextProcessingNode):
|
||||
node_id = "StripWhitespace"
|
||||
display_name = "Strip Whitespace"
|
||||
display_name = "Strip Whitespace (DEPRECATED)"
|
||||
category = "text"
|
||||
description = "Strip leading and trailing whitespace from all texts."
|
||||
is_deprecated = True # This node is superseded by the Trim Text node
|
||||
|
||||
@classmethod
|
||||
def _process(cls, text):
|
||||
@ -952,11 +1007,13 @@ class StripWhitespaceNode(TextProcessingNode):
|
||||
|
||||
|
||||
class ImageDeduplicationNode(ImageProcessingNode):
|
||||
"""Remove duplicate or very similar images from the dataset using perceptual hashing."""
|
||||
"""Remove duplicate or very similar images from a list using perceptual hashing."""
|
||||
|
||||
node_id = "ImageDeduplication"
|
||||
display_name = "Image Deduplication"
|
||||
description = "Remove duplicate or very similar images from the dataset."
|
||||
search_aliases=["deduplicate", "remove duplicates", "similarity filter"]
|
||||
display_name = "Deduplicate Images"
|
||||
category = "image/batch"
|
||||
description = "Remove duplicate or very similar images from a list."
|
||||
is_group_process = True # Requires full list to compare images
|
||||
extra_inputs = [
|
||||
io.Float.Input(
|
||||
@ -1026,7 +1083,9 @@ class ImageGridNode(ImageProcessingNode):
|
||||
"""Combine multiple images into a single grid/collage."""
|
||||
|
||||
node_id = "ImageGrid"
|
||||
display_name = "Image Grid"
|
||||
search_aliases=["grid", "collage", "combine"]
|
||||
display_name = "Make Image Grid"
|
||||
category="image/batch"
|
||||
description = "Arrange multiple images into a grid layout."
|
||||
is_group_process = True # Requires full list to create grid
|
||||
is_output_list = False # Outputs single grid image
|
||||
@ -1102,9 +1161,12 @@ class MergeImageListsNode(ImageProcessingNode):
|
||||
"""Merge multiple image lists into a single list."""
|
||||
|
||||
node_id = "MergeImageLists"
|
||||
display_name = "Merge Image Lists"
|
||||
search_aliases=["list", "merge list", "make list"]
|
||||
display_name = "Merge Image Lists (DEPRECATED)"
|
||||
category = "image/batch"
|
||||
description = "Concatenate multiple image lists into one."
|
||||
is_group_process = True # Receives images as list
|
||||
is_deprecated = True # This node is superseded by the Create List node
|
||||
|
||||
@classmethod
|
||||
def _group_process(cls, images):
|
||||
@ -1119,9 +1181,11 @@ class MergeTextListsNode(TextProcessingNode):
|
||||
"""Merge multiple text lists into a single list."""
|
||||
|
||||
node_id = "MergeTextLists"
|
||||
display_name = "Merge Text Lists"
|
||||
display_name = "Merge Text Lists (DEPRECATED)"
|
||||
category = "text"
|
||||
description = "Concatenate multiple text lists into one."
|
||||
is_group_process = True # Receives texts as list
|
||||
is_deprecated = True # This node is superseded by the Create List node
|
||||
|
||||
@classmethod
|
||||
def _group_process(cls, texts):
|
||||
@ -1142,8 +1206,10 @@ class ResolutionBucket(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ResolutionBucket",
|
||||
search_aliases=["bucket by resolution", "group by resolution", "batch by resolution"],
|
||||
display_name="Resolution Bucket",
|
||||
category="dataset",
|
||||
category="training",
|
||||
description="Group latents and conditionings into buckets",
|
||||
is_experimental=True,
|
||||
is_input_list=True,
|
||||
inputs=[
|
||||
@ -1236,7 +1302,8 @@ class MakeTrainingDataset(io.ComfyNode):
|
||||
node_id="MakeTrainingDataset",
|
||||
search_aliases=["encode dataset"],
|
||||
display_name="Make Training Dataset",
|
||||
category="dataset",
|
||||
category="training",
|
||||
description="Encode images with VAE and texts with CLIP to create a training dataset of latents and conditionings.",
|
||||
is_experimental=True,
|
||||
is_input_list=True, # images and texts as lists
|
||||
inputs=[
|
||||
@ -1251,6 +1318,7 @@ class MakeTrainingDataset(io.ComfyNode):
|
||||
"texts",
|
||||
optional=True,
|
||||
tooltip="List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).",
|
||||
force_input=True
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
@ -1320,9 +1388,10 @@ class SaveTrainingDataset(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveTrainingDataset",
|
||||
search_aliases=["export training data"],
|
||||
search_aliases=["export dataset", "save dataset"],
|
||||
display_name="Save Training Dataset",
|
||||
category="dataset",
|
||||
category="training",
|
||||
description="Save encoded training dataset (latents + conditioning) to disk for efficient loading during training.",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
is_input_list=True, # Receive lists
|
||||
@ -1424,7 +1493,8 @@ class LoadTrainingDataset(io.ComfyNode):
|
||||
node_id="LoadTrainingDataset",
|
||||
search_aliases=["import dataset", "training data"],
|
||||
display_name="Load Training Dataset",
|
||||
category="dataset",
|
||||
category="training",
|
||||
description="Load encoded training dataset (latents + conditioning) from disk for use in training.",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.String.Input(
|
||||
|
||||
@ -419,15 +419,17 @@ class VoxelToMeshBasic(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VoxelToMeshBasic",
|
||||
display_name="Voxel to Mesh (Basic)",
|
||||
display_name="Voxel to Mesh (Basic) (DEPRECATED)",
|
||||
category="3d",
|
||||
description="Converts a voxel grid to a mesh.",
|
||||
is_deprecated=True, # This node is superseded by the Voxel To Mesh node
|
||||
inputs=[
|
||||
IO.Voxel.Input("voxel"),
|
||||
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
IO.Mesh.Output(),
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -453,9 +455,10 @@ class VoxelToMesh(IO.ComfyNode):
|
||||
node_id="VoxelToMesh",
|
||||
display_name="Voxel to Mesh",
|
||||
category="3d",
|
||||
description="Converts a voxel grid to a mesh.",
|
||||
inputs=[
|
||||
IO.Voxel.Input("voxel"),
|
||||
IO.Combo.Input("algorithm", options=["surface net", "basic"], advanced=True),
|
||||
IO.Combo.Input("algorithm", options=["surface net", "basic"]),
|
||||
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
|
||||
@ -55,9 +55,10 @@ class ImageCropV2(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageCropV2",
|
||||
search_aliases=["trim"],
|
||||
search_aliases=["crop", "cut", "trim"],
|
||||
display_name="Crop Image",
|
||||
category="image/transform",
|
||||
description = "Crop an image to the specified dimensions.",
|
||||
essentials_category="Image Tools",
|
||||
has_intermediate_output=True,
|
||||
inputs=[
|
||||
|
||||
@ -8,6 +8,82 @@ from comfy_api.latest import _io
|
||||
MISSING = object()
|
||||
|
||||
|
||||
class NotNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ComfyNotNode",
|
||||
display_name="Not",
|
||||
category="utils/logic",
|
||||
description="Logical NOT operation. Returns true if the value is falsy. Uses Python's rules for truthiness.",
|
||||
search_aliases=["invert", "toggle", "negate", "flip boolean"],
|
||||
inputs=[
|
||||
io.AnyType.Input("value"),
|
||||
],
|
||||
outputs=[
|
||||
io.Boolean.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, value) -> io.NodeOutput:
|
||||
return io.NodeOutput(not value)
|
||||
|
||||
|
||||
class AndNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
template = io.Autogrow.TemplatePrefix(
|
||||
input=io.AnyType.Input("value"),
|
||||
prefix="value",
|
||||
min=1,
|
||||
)
|
||||
return io.Schema(
|
||||
node_id="ComfyAndNode",
|
||||
display_name="And",
|
||||
category="utils/logic",
|
||||
description="Logical AND operation. Returns true if all of the values are truthy. Uses Python's rules for truthiness.",
|
||||
search_aliases=["all", "every"],
|
||||
inputs=[
|
||||
io.Autogrow.Input("values", template=template),
|
||||
],
|
||||
outputs=[
|
||||
io.Boolean.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, values: io.Autogrow.Type) -> io.NodeOutput:
|
||||
return io.NodeOutput(all(values.values()))
|
||||
|
||||
|
||||
class OrNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
template = io.Autogrow.TemplatePrefix(
|
||||
input=io.AnyType.Input("value"),
|
||||
prefix="value",
|
||||
min=1,
|
||||
)
|
||||
return io.Schema(
|
||||
node_id="ComfyOrNode",
|
||||
display_name="Or",
|
||||
category="utils/logic",
|
||||
description="Logical OR operation. Returns true if any of the values are truthy. Uses Python's rules for truthiness.",
|
||||
search_aliases=["any", "some"],
|
||||
inputs=[
|
||||
io.Autogrow.Input("values", template=template),
|
||||
],
|
||||
outputs=[
|
||||
io.Boolean.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, values: io.Autogrow.Type) -> io.NodeOutput:
|
||||
return io.NodeOutput(any(values.values()))
|
||||
|
||||
|
||||
class SwitchNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -15,7 +91,7 @@ class SwitchNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ComfySwitchNode",
|
||||
display_name="Switch",
|
||||
category="logic",
|
||||
category="utils/logic",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Boolean.Input("switch"),
|
||||
@ -46,7 +122,7 @@ class SoftSwitchNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ComfySoftSwitchNode",
|
||||
display_name="Soft Switch",
|
||||
category="logic",
|
||||
category="utils/logic",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Boolean.Input("switch"),
|
||||
@ -136,7 +212,7 @@ class DCTestNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="DCTestNode",
|
||||
display_name="DCTest",
|
||||
category="logic",
|
||||
category="utils/logic",
|
||||
is_output_node=True,
|
||||
inputs=[io.DynamicCombo.Input("combo", options=[
|
||||
io.DynamicCombo.Option("option1", [io.String.Input("string")]),
|
||||
@ -174,7 +250,7 @@ class AutogrowNamesTestNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="AutogrowNamesTestNode",
|
||||
display_name="AutogrowNamesTest",
|
||||
category="logic",
|
||||
category="utils/logic",
|
||||
inputs=[
|
||||
_io.Autogrow.Input("autogrow", template=template)
|
||||
],
|
||||
@ -194,7 +270,7 @@ class AutogrowPrefixTestNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="AutogrowPrefixTestNode",
|
||||
display_name="AutogrowPrefixTest",
|
||||
category="logic",
|
||||
category="utils/logic",
|
||||
inputs=[
|
||||
_io.Autogrow.Input("autogrow", template=template)
|
||||
],
|
||||
@ -213,7 +289,7 @@ class ComboOutputTestNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ComboOptionTestNode",
|
||||
display_name="ComboOptionTest",
|
||||
category="logic",
|
||||
category="utils/logic",
|
||||
inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]),
|
||||
io.Combo.Input("combo2", options=["option4", "option5", "option6"])],
|
||||
outputs=[io.Combo.Output(), io.Combo.Output()],
|
||||
@ -230,7 +306,7 @@ class ConvertStringToComboNode(io.ComfyNode):
|
||||
node_id="ConvertStringToComboNode",
|
||||
search_aliases=["string to dropdown", "text to combo"],
|
||||
display_name="Convert String to Combo",
|
||||
category="logic",
|
||||
category="utils/logic",
|
||||
inputs=[io.String.Input("string")],
|
||||
outputs=[io.Combo.Output()],
|
||||
)
|
||||
@ -246,7 +322,7 @@ class InvertBooleanNode(io.ComfyNode):
|
||||
node_id="InvertBooleanNode",
|
||||
search_aliases=["not", "toggle", "negate", "flip boolean"],
|
||||
display_name="Invert Boolean",
|
||||
category="logic",
|
||||
category="utils/logic",
|
||||
inputs=[io.Boolean.Input("boolean")],
|
||||
outputs=[io.Boolean.Output()],
|
||||
)
|
||||
@ -261,6 +337,9 @@ class LogicExtension(ComfyExtension):
|
||||
return [
|
||||
SwitchNode,
|
||||
CustomComboNode,
|
||||
NotNode,
|
||||
AndNode,
|
||||
OrNode,
|
||||
# SoftSwitchNode,
|
||||
# ConvertStringToComboNode,
|
||||
# DCTestNode,
|
||||
|
||||
@ -11,8 +11,8 @@ class LTXVAudioVAELoader(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAELoader",
|
||||
display_name="LTXV Audio VAE Loader",
|
||||
category="audio",
|
||||
display_name="Load LTXV Audio VAE",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"ckpt_name",
|
||||
@ -40,7 +40,7 @@ class LTXVAudioVAEEncode(VAEEncodeAudio):
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAEEncode",
|
||||
display_name="LTXV Audio VAE Encode",
|
||||
category="audio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
io.Audio.Input("audio", tooltip="The audio to be encoded."),
|
||||
io.Vae.Input(
|
||||
@ -63,7 +63,7 @@ class LTXVAudioVAEDecode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAEDecode",
|
||||
display_name="LTXV Audio VAE Decode",
|
||||
category="audio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
io.Latent.Input("samples", tooltip="The latent to be decoded."),
|
||||
io.Vae.Input(
|
||||
|
||||
@ -70,7 +70,7 @@ class MathExpressionNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ComfyMathExpression",
|
||||
display_name="Math Expression",
|
||||
category="logic",
|
||||
category="utils",
|
||||
search_aliases=[
|
||||
"expression", "formula", "calculate", "calculator",
|
||||
"eval", "math",
|
||||
|
||||
@ -28,7 +28,7 @@ from comfy_extras.mediapipe.face_landmarker import FaceLandmarker
|
||||
from comfy_extras.mediapipe.face_geometry import transformation_matrix_from_detection
|
||||
|
||||
|
||||
FaceLandmarkerType = io.Custom("FACE_LANDMARKER")
|
||||
FaceDetectionType = io.Custom("FACE_DETECTION_MODEL")
|
||||
FaceLandmarksType = io.Custom("FACE_LANDMARKS")
|
||||
|
||||
_CANONICAL_KEYS = ("canonical_vertices", "procrustes_indices", "procrustes_weights")
|
||||
@ -204,18 +204,19 @@ class LoadMediaPipeFaceLandmarker(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadMediaPipeFaceLandmarker",
|
||||
display_name="Load MediaPipe Face Landmarker",
|
||||
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"],
|
||||
display_name="Load Face Detection Model (MediaPipe)",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
io.Combo.Input("model_name", options=folder_paths.get_filename_list("mediapipe"),
|
||||
tooltip="Face Landmarker safetensors from models/mediapipe/."),
|
||||
io.Combo.Input("model_name", options=folder_paths.get_filename_list("detection"),
|
||||
tooltip="Face detection model from models/detection/."),
|
||||
],
|
||||
outputs=[FaceLandmarkerType.Output()],
|
||||
outputs=[FaceDetectionType.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_name) -> io.NodeOutput:
|
||||
sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("mediapipe", model_name), safe_load=True)
|
||||
sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("detection", model_name), safe_load=True)
|
||||
wrapper = FaceLandmarkerModel(sd)
|
||||
return io.NodeOutput(wrapper)
|
||||
|
||||
@ -234,10 +235,12 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MediaPipeFaceLandmarker",
|
||||
display_name="MediaPipe Face Landmarker",
|
||||
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"],
|
||||
display_name="Detect Face Landmarks (MediaPipe)",
|
||||
category="image/detection",
|
||||
description="Detects facial landmarks using MediaPipe model.",
|
||||
inputs=[
|
||||
FaceLandmarkerType.Input("face_landmarker"),
|
||||
FaceDetectionType.Input("face_detection_model"),
|
||||
io.Image.Input("image"),
|
||||
io.Combo.Input("detector_variant", options=["short", "full", "both"], default="short",
|
||||
tooltip="Face detector range. 'short' is tuned for close-up faces "
|
||||
@ -261,9 +264,9 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, face_landmarker, image, detector_variant, num_faces, min_confidence,
|
||||
def execute(cls, face_detection_model, image, detector_variant, num_faces, min_confidence,
|
||||
missing_frame_fallback) -> io.NodeOutput:
|
||||
canonical = face_landmarker.canonical_data
|
||||
canonical = face_detection_model.canonical_data
|
||||
img_np = _image_to_uint8(image)
|
||||
B, H, W = img_np.shape[:3]
|
||||
chunk = 16
|
||||
@ -276,7 +279,7 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
|
||||
with tqdm(total=B, desc=f"MediaPipe Face Landmarker ({variant})") as tq:
|
||||
for i in range(0, B, chunk):
|
||||
end = min(i + chunk, B)
|
||||
res.extend(face_landmarker.detect_batch(
|
||||
res.extend(face_detection_model.detect_batch(
|
||||
[img_np[bi] for bi in range(i, end)],
|
||||
num_faces=int(num_faces),
|
||||
score_thresh=float(min_confidence),
|
||||
@ -306,7 +309,7 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
|
||||
per_bb.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1, "label": "face", "score": float(f["score"])})
|
||||
bboxes.append(per_bb)
|
||||
return io.NodeOutput({"frames": frames, "image_size": (H, W),
|
||||
"connection_sets": face_landmarker.connection_sets}, bboxes)
|
||||
"connection_sets": face_detection_model.connection_sets}, bboxes)
|
||||
|
||||
|
||||
# Topology keys unioned by the 'all' connections preset (contour parts + irises + nose).
|
||||
@ -332,8 +335,10 @@ class MediaPipeFaceMeshVisualize(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MediaPipeFaceMeshVisualize",
|
||||
display_name="MediaPipe Face Mesh Visualize",
|
||||
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection", "visualize"],
|
||||
display_name="Visualize Face Landmarks (MediaPipe)",
|
||||
category="image/detection",
|
||||
description="Draws face landmarks mesh on the input image.",
|
||||
inputs=[
|
||||
FaceLandmarksType.Input("face_landmarks"),
|
||||
io.Image.Input("image", optional=True, tooltip="If not connected, a black canvas will be used."),
|
||||
@ -443,8 +448,10 @@ class MediaPipeFaceMask(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MediaPipeFaceMask",
|
||||
display_name="MediaPipe Face Mask",
|
||||
search_aliases=["face", "facial", "mediapipe", "face mask", "blazeface", "face detection", "visualize"],
|
||||
display_name="Draw Face Mask (MediaPipe)",
|
||||
category="image/detection",
|
||||
description="Draws a mask from face landmarks.",
|
||||
inputs=[
|
||||
FaceLandmarksType.Input("face_landmarks"),
|
||||
io.DynamicCombo.Input(
|
||||
|
||||
@ -103,8 +103,10 @@ class MoGePanoramaInference(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGePanoramaInference",
|
||||
display_name="MoGe Panorama Inference",
|
||||
search_aliases=["moge", "panorama", "depth", "geometry", "depth estimation", "geometry estimation"],
|
||||
display_name="Run MoGe Panorama Inference",
|
||||
category="image/geometry_estimation",
|
||||
description="Run MoGe on an equirectangular panorama by splitting it into 12 perspective views, running inference on each, and merging the results into a single depth map.",
|
||||
inputs=[
|
||||
MoGeModelType.Input("moge_model"),
|
||||
io.Image.Input("image", tooltip="Equirectangular panorama (any aspect)."),
|
||||
@ -222,7 +224,9 @@ class MoGeInference(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGeInference",
|
||||
display_name="MoGe Inference",
|
||||
search_aliases=["moge", "depth", "geometry", "depth estimation", "geometry estimation"],
|
||||
display_name="Run MoGe Inference",
|
||||
description="Run MoGe on a single image to estimate depth and geometry.",
|
||||
category="image/geometry_estimation",
|
||||
inputs=[
|
||||
MoGeModelType.Input("moge_model"),
|
||||
@ -277,7 +281,9 @@ class MoGeRender(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGeRender",
|
||||
display_name="MoGe Render",
|
||||
search_aliases=["moge", "render", "geometry", "depth", "normal"],
|
||||
display_name="Render MoGe Geometry",
|
||||
description="Render a depth map or normal map from geometry data",
|
||||
category="image/geometry_estimation",
|
||||
inputs=[
|
||||
MoGeGeometry.Input("moge_geometry"),
|
||||
@ -342,7 +348,9 @@ class MoGePointMapToMesh(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGePointMapToMesh",
|
||||
display_name="MoGe Point Map to Mesh",
|
||||
search_aliases=["moge", "mesh", "geometry", "point map"],
|
||||
display_name="Convert MoGe Point Map to Mesh",
|
||||
description="Convert a MoGe point map into a 3D mesh.",
|
||||
category="image/geometry_estimation",
|
||||
inputs=[
|
||||
MoGeGeometry.Input("moge_geometry"),
|
||||
|
||||
@ -14,7 +14,7 @@ class CreateList(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="CreateList",
|
||||
display_name="Create List",
|
||||
category="logic",
|
||||
category="utils",
|
||||
is_input_list=True,
|
||||
search_aliases=["Image Iterator", "Text Iterator", "Iterator"],
|
||||
inputs=[io.Autogrow.Input("inputs", template=template_autogrow)],
|
||||
|
||||
@ -60,7 +60,7 @@ folder_names_and_paths["geometry_estimation"] = ([os.path.join(models_dir, "geom
|
||||
|
||||
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["mediapipe"] = ([os.path.join(models_dir, "mediapipe")], supported_pt_extensions)
|
||||
folder_names_and_paths["detection"] = ([os.path.join(models_dir, "detection")], supported_pt_extensions)
|
||||
|
||||
output_directory = os.path.join(base_path, "output")
|
||||
temp_directory = os.path.join(base_path, "temp")
|
||||
|
||||
3908
openapi.yaml
3908
openapi.yaml
File diff suppressed because it is too large
Load Diff
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.43.18
|
||||
comfyui-workflow-templates==0.9.79
|
||||
comfyui-workflow-templates==0.9.82
|
||||
comfyui-embedded-docs==0.5.0
|
||||
torch
|
||||
torchsde
|
||||
|
||||
@ -21,7 +21,6 @@ from app.assets.database.queries import (
|
||||
get_reference_ids_by_ids,
|
||||
ensure_tags_exist,
|
||||
add_tags_to_reference,
|
||||
set_reference_tags,
|
||||
)
|
||||
from app.assets.helpers import get_utc_now
|
||||
|
||||
@ -160,153 +159,6 @@ class TestListReferencesPage:
|
||||
assert refs[0].name == "large"
|
||||
|
||||
|
||||
class TestTagRetrievalOrder:
|
||||
"""End-to-end check: tags written through the public write paths come
|
||||
back from the public read paths in insertion order rather than the
|
||||
composite-PK alphabetical order SQLite would otherwise impose.
|
||||
|
||||
Each test deliberately picks tag names that would sort differently
|
||||
under alphabetical vs insertion order, so an alphabetical regression
|
||||
fails loudly.
|
||||
"""
|
||||
|
||||
def _make_ref(self, session: Session) -> AssetReference:
|
||||
asset = _make_asset(session, "h1")
|
||||
return _make_reference(session, asset, name="x.bin")
|
||||
|
||||
def test_set_reference_tags_preserves_input_order_in_list(self, session: Session):
|
||||
ref = self._make_ref(session)
|
||||
# "checkpoints" < "models" alphabetically; if added_at stagger
|
||||
# works, list_references_page returns insertion order.
|
||||
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
|
||||
session.commit()
|
||||
|
||||
_, tag_map, _ = list_references_page(session)
|
||||
assert tag_map[ref.id] == ["models", "checkpoints"]
|
||||
|
||||
def test_set_reference_tags_preserves_input_order_in_fetch(self, session: Session):
|
||||
ref = self._make_ref(session)
|
||||
# Subpath tag sorts before "models" alphabetically.
|
||||
set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["models", "diffusers/kolors/text_encoder"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
result = fetch_reference_asset_and_tags(session, ref.id)
|
||||
assert result is not None
|
||||
_, _, tags = result
|
||||
# Bucket-prefix expansion appends the standalone `diffusers` token
|
||||
# at path-tier (microsecond stagger) so FE set-membership filters
|
||||
# match nested category paths.
|
||||
assert tags == ["models", "diffusers/kolors/text_encoder", "diffusers"]
|
||||
|
||||
def test_add_tags_to_reference_lands_after_path_tags(self, session: Session):
|
||||
ref = self._make_ref(session)
|
||||
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
|
||||
session.commit()
|
||||
|
||||
# "aaa-..." sorts before both path tags alphabetically. If added_at
|
||||
# stagger is missing, alphabetic tiebreak would hoist it to tags[0].
|
||||
add_tags_to_reference(
|
||||
session, reference_id=ref.id, tags=["aaa-user-tag"], origin="manual"
|
||||
)
|
||||
session.commit()
|
||||
|
||||
_, tag_map, _ = list_references_page(session)
|
||||
assert tag_map[ref.id] == ["models", "checkpoints", "aaa-user-tag"]
|
||||
|
||||
def test_multi_tag_batch_lands_after_path_tags(self, session: Session):
|
||||
ref = self._make_ref(session)
|
||||
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
|
||||
session.commit()
|
||||
|
||||
# Three user tags inserted in non-alphabetical input order. Per-tag
|
||||
# microsecond stagger should preserve at least the "user batch is
|
||||
# after path tags" property; within the user batch insertion order
|
||||
# is also preserved.
|
||||
add_tags_to_reference(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["zzz-z", "favorite", "experiment-q4"],
|
||||
origin="manual",
|
||||
)
|
||||
session.commit()
|
||||
|
||||
_, tag_map, _ = list_references_page(session)
|
||||
tags = tag_map[ref.id]
|
||||
assert tags[0:2] == ["models", "checkpoints"]
|
||||
assert set(tags[2:]) == {"zzz-z", "favorite", "experiment-q4"}
|
||||
|
||||
def test_user_batch_lands_after_path_batch_under_clock_collision(
|
||||
self, session: Session, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""Windows-specific race: when two back-to-back commits share the
|
||||
same datetime.now() microsecond, the path-tier and user-tier
|
||||
added_at values used to collide and alphabetic tiebreak would
|
||||
hoist user tags ahead of path tags. The fix reads
|
||||
max(existing_added_at) for the reference and seeds the next batch
|
||||
past it, deterministically restoring insertion order.
|
||||
|
||||
This test simulates the collision by pinning get_utc_now() so the
|
||||
platform-dependent race becomes a platform-independent failure.
|
||||
"""
|
||||
ref = self._make_ref(session)
|
||||
|
||||
from datetime import datetime
|
||||
from app.assets.database import queries as queries_pkg
|
||||
from app.assets.database.queries import tags as tags_module
|
||||
|
||||
frozen = datetime(2026, 1, 1, 0, 0, 0)
|
||||
monkeypatch.setattr(tags_module, "get_utc_now", lambda: frozen)
|
||||
monkeypatch.setattr(queries_pkg, "get_utc_now", lambda: frozen, raising=False)
|
||||
|
||||
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
|
||||
session.commit()
|
||||
|
||||
# Same frozen timestamp — without the max(existing) seed, the
|
||||
# user batch would share added_at with the path batch and
|
||||
# `aaa-user-tag` would sort to position 0 via the alphabetic
|
||||
# tiebreaker.
|
||||
add_tags_to_reference(
|
||||
session, reference_id=ref.id, tags=["aaa-user-tag"], origin="manual"
|
||||
)
|
||||
session.commit()
|
||||
|
||||
_, tag_map, _ = list_references_page(session)
|
||||
assert tag_map[ref.id] == ["models", "checkpoints", "aaa-user-tag"]
|
||||
|
||||
def test_remove_then_add_does_not_disrupt_path_tag_positions(
|
||||
self, session: Session
|
||||
):
|
||||
ref = self._make_ref(session)
|
||||
set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["models", "loras/my/custom/path"],
|
||||
)
|
||||
session.commit()
|
||||
add_tags_to_reference(session, reference_id=ref.id, tags=["temp-tag"])
|
||||
session.commit()
|
||||
from app.assets.database.queries import remove_tags_from_reference
|
||||
|
||||
remove_tags_from_reference(session, reference_id=ref.id, tags=["temp-tag"])
|
||||
session.commit()
|
||||
add_tags_to_reference(session, reference_id=ref.id, tags=["second-tag"])
|
||||
session.commit()
|
||||
|
||||
_, tag_map, _ = list_references_page(session)
|
||||
# `loras` is expanded from the nested category path; user-added
|
||||
# tags trail behind it via the microsecond stagger.
|
||||
assert tag_map[ref.id] == [
|
||||
"models",
|
||||
"loras/my/custom/path",
|
||||
"loras",
|
||||
"second-tag",
|
||||
]
|
||||
|
||||
|
||||
class TestFetchReferenceAssetAndTags:
|
||||
def test_returns_none_for_nonexistent(self, session: Session):
|
||||
result = fetch_reference_asset_and_tags(session, "nonexistent")
|
||||
|
||||
@ -160,120 +160,6 @@ class TestAddTagsToReference:
|
||||
add_tags_to_reference(session, reference_id="nonexistent", tags=["x"])
|
||||
|
||||
|
||||
class TestBucketPrefixExpansion:
|
||||
"""The standalone bucket token must appear in the asset's tag set for
|
||||
nested category paths so FE filters like
|
||||
`include_tags=models,checkpoints` continue to match.
|
||||
"""
|
||||
|
||||
def test_set_reference_tags_inserts_bucket_for_nested_path(
|
||||
self, session: Session
|
||||
):
|
||||
asset = _make_asset(session, "hash-nested")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
result = set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["models", "checkpoints/flux"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert set(result.total) == {"models", "checkpoints/flux", "checkpoints"}
|
||||
stored = get_reference_tags(session, reference_id=ref.id)
|
||||
# tag[1] keeps the slash-joined positional contract; the standalone
|
||||
# bucket lands after it via path-tier microsecond stagger so user
|
||||
# tags remain at the tail.
|
||||
assert stored[:3] == ["models", "checkpoints/flux", "checkpoints"]
|
||||
|
||||
def test_set_reference_tags_idempotent_on_replay(self, session: Session):
|
||||
asset = _make_asset(session, "hash-replay")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["models", "checkpoints/flux"],
|
||||
)
|
||||
# Replay with the same caller-supplied set; expansion is already
|
||||
# baked in, so nothing should be added or removed.
|
||||
result = set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["models", "checkpoints/flux"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert result.added == []
|
||||
assert result.removed == []
|
||||
assert set(result.total) == {"models", "checkpoints/flux", "checkpoints"}
|
||||
|
||||
def test_add_tags_to_reference_expands_bucket(self, session: Session):
|
||||
asset = _make_asset(session, "hash-add")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
result = add_tags_to_reference(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["loras/style/v2"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert set(result.added) == {"loras/style/v2", "loras"}
|
||||
stored = get_reference_tags(session, reference_id=ref.id)
|
||||
assert "loras" in stored
|
||||
assert "loras/style/v2" in stored
|
||||
|
||||
def test_add_tags_does_not_duplicate_existing_bucket(self, session: Session):
|
||||
asset = _make_asset(session, "hash-dedupe")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
add_tags_to_reference(
|
||||
session, reference_id=ref.id, tags=["models", "checkpoints"]
|
||||
)
|
||||
result = add_tags_to_reference(
|
||||
session, reference_id=ref.id, tags=["checkpoints/flux"]
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# `checkpoints` was already there from the first add; only the
|
||||
# slash-joined token is genuinely new.
|
||||
assert result.added == ["checkpoints/flux"]
|
||||
assert "checkpoints" in result.already_present
|
||||
|
||||
def test_flat_category_is_unaffected(self, session: Session):
|
||||
asset = _make_asset(session, "hash-flat")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
result = set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["models", "checkpoints"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert set(result.total) == {"models", "checkpoints"}
|
||||
assert get_reference_tags(session, reference_id=ref.id) == [
|
||||
"models",
|
||||
"checkpoints",
|
||||
]
|
||||
|
||||
def test_unknown_prefix_passes_through(self, session: Session):
|
||||
asset = _make_asset(session, "hash-user")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
# `my-org` isn't a registered bucket — the slash-joined user tag
|
||||
# should not trigger bucket expansion.
|
||||
result = set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["my-org/team-a"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert result.total == ["my-org/team-a"]
|
||||
|
||||
|
||||
class TestRemoveTagsFromReference:
|
||||
def test_removes_tags(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
|
||||
@ -4,7 +4,7 @@ from pathlib import Path
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag
|
||||
from app.assets.database.models import Asset, AssetReference
|
||||
from app.assets.services.bulk_ingest import SeedAssetSpec, batch_insert_seed_assets
|
||||
|
||||
|
||||
@ -102,82 +102,6 @@ class TestBatchInsertSeedAssets:
|
||||
assert asset.mime_type == expected_mime, f"Expected {expected_mime} for {filename}, got {asset.mime_type}"
|
||||
|
||||
|
||||
class TestBucketPrefixExpansionOnIngest:
|
||||
"""Path-scanning ingest must persist the standalone bucket token for
|
||||
nested category paths so the FE set-membership filter
|
||||
(`include_tags=models,checkpoints`) matches assets organized into
|
||||
subfolders (`models/checkpoints/flux/foo.safetensors`).
|
||||
"""
|
||||
|
||||
def test_nested_path_inserts_standalone_bucket(
|
||||
self, session: Session, temp_dir: Path
|
||||
):
|
||||
file_path = temp_dir / "flux.safetensors"
|
||||
file_path.write_bytes(b"content")
|
||||
|
||||
specs: list[SeedAssetSpec] = [
|
||||
{
|
||||
"abs_path": str(file_path),
|
||||
"size_bytes": 7,
|
||||
"mtime_ns": 1234567890000000000,
|
||||
"info_name": "flux",
|
||||
# Shape emitted by get_name_and_tags_from_asset_path for a
|
||||
# nested model path.
|
||||
"tags": ["models", "checkpoints/flux"],
|
||||
"fname": "flux.safetensors",
|
||||
"metadata": None,
|
||||
"hash": None,
|
||||
"mime_type": "application/safetensors",
|
||||
}
|
||||
]
|
||||
|
||||
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
|
||||
|
||||
assert result.inserted_refs == 1
|
||||
ref = session.query(AssetReference).filter_by(name="flux").one()
|
||||
stored = [
|
||||
row.tag_name
|
||||
for row in session.query(AssetReferenceTag)
|
||||
.filter_by(asset_reference_id=ref.id)
|
||||
.order_by(AssetReferenceTag.added_at.asc())
|
||||
.all()
|
||||
]
|
||||
assert stored == ["models", "checkpoints/flux", "checkpoints"]
|
||||
|
||||
def test_flat_path_remains_two_tags(
|
||||
self, session: Session, temp_dir: Path
|
||||
):
|
||||
file_path = temp_dir / "vanilla.safetensors"
|
||||
file_path.write_bytes(b"content")
|
||||
|
||||
specs: list[SeedAssetSpec] = [
|
||||
{
|
||||
"abs_path": str(file_path),
|
||||
"size_bytes": 7,
|
||||
"mtime_ns": 1234567890000000000,
|
||||
"info_name": "vanilla",
|
||||
"tags": ["models", "checkpoints"],
|
||||
"fname": "vanilla.safetensors",
|
||||
"metadata": None,
|
||||
"hash": None,
|
||||
"mime_type": "application/safetensors",
|
||||
}
|
||||
]
|
||||
|
||||
batch_insert_seed_assets(session, specs=specs, owner_id="")
|
||||
|
||||
ref = session.query(AssetReference).filter_by(name="vanilla").one()
|
||||
stored = {
|
||||
row.tag_name
|
||||
for row in session.query(AssetReferenceTag)
|
||||
.filter_by(asset_reference_id=ref.id)
|
||||
.all()
|
||||
}
|
||||
# Dedupe means flat layouts don't pick up a redundant `checkpoints`
|
||||
# row — tag[1] already serves both positional and set-membership.
|
||||
assert stored == {"models", "checkpoints"}
|
||||
|
||||
|
||||
class TestMetadataExtraction:
|
||||
def test_extracts_mime_type_for_model_files(self, temp_dir: Path):
|
||||
"""Verify metadata extraction returns correct mime_type for model files."""
|
||||
|
||||
@ -6,11 +6,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.assets.services.path_utils import (
|
||||
get_asset_category_and_relative_path,
|
||||
get_name_and_tags_from_asset_path,
|
||||
resolve_destination_from_tags,
|
||||
)
|
||||
from app.assets.services.path_utils import get_asset_category_and_relative_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -42,50 +38,6 @@ def fake_dirs():
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_dirs_multi_bucket():
|
||||
"""Variant fixture with multiple model buckets (checkpoints + diffusers + loras)."""
|
||||
with tempfile.TemporaryDirectory() as root:
|
||||
root_path = Path(root)
|
||||
input_dir = root_path / "input"
|
||||
output_dir = root_path / "output"
|
||||
temp_dir = root_path / "temp"
|
||||
checkpoints_dir = root_path / "models" / "checkpoints"
|
||||
diffusers_dir = root_path / "models" / "diffusers"
|
||||
loras_dir = root_path / "models" / "loras"
|
||||
for d in (
|
||||
input_dir,
|
||||
output_dir,
|
||||
temp_dir,
|
||||
checkpoints_dir,
|
||||
diffusers_dir,
|
||||
loras_dir,
|
||||
):
|
||||
d.mkdir(parents=True)
|
||||
|
||||
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
|
||||
mock_fp.get_input_directory.return_value = str(input_dir)
|
||||
mock_fp.get_output_directory.return_value = str(output_dir)
|
||||
mock_fp.get_temp_directory.return_value = str(temp_dir)
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[
|
||||
("checkpoints", [str(checkpoints_dir)]),
|
||||
("diffusers", [str(diffusers_dir)]),
|
||||
("loras", [str(loras_dir)]),
|
||||
],
|
||||
):
|
||||
yield {
|
||||
"input": input_dir,
|
||||
"output": output_dir,
|
||||
"temp": temp_dir,
|
||||
"checkpoints": checkpoints_dir,
|
||||
"diffusers": diffusers_dir,
|
||||
"loras": loras_dir,
|
||||
}
|
||||
|
||||
|
||||
class TestGetAssetCategoryAndRelativePath:
|
||||
def test_input_file(self, fake_dirs):
|
||||
f = fake_dirs["input"] / "photo.png"
|
||||
@ -127,161 +79,3 @@ class TestGetAssetCategoryAndRelativePath:
|
||||
def test_unknown_path_raises(self, fake_dirs):
|
||||
with pytest.raises(ValueError, match="not within"):
|
||||
get_asset_category_and_relative_path("/some/random/path.png")
|
||||
|
||||
|
||||
class TestGetNameAndTagsFromAssetPath:
|
||||
"""tags collapse the parent subpath into a single slash-joined tag.
|
||||
|
||||
Consumers should be able to read ``tags[1]`` as a stable category
|
||||
identifier regardless of how deep the file lives in the bucket.
|
||||
"""
|
||||
|
||||
def test_flat_input(self, fake_dirs_multi_bucket):
|
||||
f = fake_dirs_multi_bucket["input"] / "photo.png"
|
||||
f.touch()
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "photo.png"
|
||||
assert tags == ["input"]
|
||||
|
||||
def test_flat_output(self, fake_dirs_multi_bucket):
|
||||
f = fake_dirs_multi_bucket["output"] / "result_00001.png"
|
||||
f.touch()
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "result_00001.png"
|
||||
assert tags == ["output"]
|
||||
|
||||
def test_flat_models_checkpoint(self, fake_dirs_multi_bucket):
|
||||
f = fake_dirs_multi_bucket["checkpoints"] / "flux.safetensors"
|
||||
f.touch()
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "flux.safetensors"
|
||||
assert tags == ["models", "checkpoints"]
|
||||
|
||||
def test_diffusers_nested_subpath_slash_joined(self, fake_dirs_multi_bucket):
|
||||
"""Diffusers components live in nested directories — the full subpath
|
||||
must collapse into one tag so consumers can look up the model category
|
||||
via tags[1] regardless of nesting depth.
|
||||
|
||||
The subpath is lowercased to match the canonicalization
|
||||
:func:`ensure_tags_exist` applies on the write side; without that,
|
||||
the asset_reference_tags.tag_name FK to tags.name would fail for
|
||||
any path containing uppercase letters.
|
||||
"""
|
||||
nested = (
|
||||
fake_dirs_multi_bucket["diffusers"]
|
||||
/ "Kolors"
|
||||
/ "text_encoder"
|
||||
)
|
||||
nested.mkdir(parents=True)
|
||||
f = nested / "model.safetensors"
|
||||
f.touch()
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "model.safetensors"
|
||||
assert tags == ["models", "diffusers/kolors/text_encoder"]
|
||||
|
||||
def test_deep_lora_user_subpath_slash_joined(self, fake_dirs_multi_bucket):
|
||||
"""User-created subdirectories under a model bucket also collapse to a
|
||||
single tag rather than one tag per directory."""
|
||||
nested = (
|
||||
fake_dirs_multi_bucket["loras"]
|
||||
/ "my"
|
||||
/ "custom"
|
||||
/ "path"
|
||||
)
|
||||
nested.mkdir(parents=True)
|
||||
f = nested / "v0001.safetensors"
|
||||
f.touch()
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "v0001.safetensors"
|
||||
assert tags == ["models", "loras/my/custom/path"]
|
||||
|
||||
|
||||
class TestResolveDestinationFromTags:
|
||||
"""resolve_destination_from_tags must accept both the legacy
|
||||
one-tag-per-directory shape and the new slash-joined shape so that an
|
||||
upload using the tags it just read back from /api/assets round-trips
|
||||
to the right on-disk destination.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def resolve_dirs(self):
|
||||
with tempfile.TemporaryDirectory() as root:
|
||||
root_path = Path(root)
|
||||
input_dir = root_path / "input"
|
||||
output_dir = root_path / "output"
|
||||
checkpoints_dir = root_path / "models" / "checkpoints"
|
||||
diffusers_dir = root_path / "models" / "diffusers"
|
||||
loras_dir = root_path / "models" / "loras"
|
||||
for d in (input_dir, output_dir, checkpoints_dir, diffusers_dir, loras_dir):
|
||||
d.mkdir(parents=True)
|
||||
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
|
||||
mock_fp.get_input_directory.return_value = str(input_dir)
|
||||
mock_fp.get_output_directory.return_value = str(output_dir)
|
||||
mock_fp.folder_names_and_paths = {
|
||||
"checkpoints": ([str(checkpoints_dir)], None),
|
||||
"diffusers": ([str(diffusers_dir)], None),
|
||||
"loras": ([str(loras_dir)], None),
|
||||
}
|
||||
yield {
|
||||
"input": input_dir,
|
||||
"output": output_dir,
|
||||
"checkpoints": checkpoints_dir,
|
||||
"diffusers": diffusers_dir,
|
||||
"loras": loras_dir,
|
||||
}
|
||||
|
||||
def test_models_flat_category(self, resolve_dirs):
|
||||
base, subdirs = resolve_destination_from_tags(["models", "checkpoints"])
|
||||
assert base == str(resolve_dirs["checkpoints"])
|
||||
assert subdirs == []
|
||||
|
||||
def test_models_slash_joined_new_shape(self, resolve_dirs):
|
||||
# The shape get_name_and_tags_from_asset_path now emits.
|
||||
base, subdirs = resolve_destination_from_tags(
|
||||
["models", "diffusers/kolors/text_encoder"]
|
||||
)
|
||||
assert base == str(resolve_dirs["diffusers"])
|
||||
assert subdirs == ["kolors", "text_encoder"]
|
||||
|
||||
def test_models_legacy_one_tag_per_dir(self, resolve_dirs):
|
||||
# The legacy shape must still resolve identically.
|
||||
base, subdirs = resolve_destination_from_tags(
|
||||
["models", "diffusers", "kolors", "text_encoder"]
|
||||
)
|
||||
assert base == str(resolve_dirs["diffusers"])
|
||||
assert subdirs == ["kolors", "text_encoder"]
|
||||
|
||||
def test_models_loras_slash_joined(self, resolve_dirs):
|
||||
base, subdirs = resolve_destination_from_tags(
|
||||
["models", "loras/my/custom/path"]
|
||||
)
|
||||
assert base == str(resolve_dirs["loras"])
|
||||
assert subdirs == ["my", "custom", "path"]
|
||||
|
||||
def test_input_no_subdir(self, resolve_dirs):
|
||||
base, subdirs = resolve_destination_from_tags(["input"])
|
||||
assert base == str(resolve_dirs["input"])
|
||||
assert subdirs == []
|
||||
|
||||
def test_input_slash_joined_subdir(self, resolve_dirs):
|
||||
base, subdirs = resolve_destination_from_tags(["input", "portraits/2026"])
|
||||
assert base == str(resolve_dirs["input"])
|
||||
assert subdirs == ["portraits", "2026"]
|
||||
|
||||
def test_output_slash_joined_subdir(self, resolve_dirs):
|
||||
base, subdirs = resolve_destination_from_tags(["output", "runs/abc"])
|
||||
assert base == str(resolve_dirs["output"])
|
||||
assert subdirs == ["runs", "abc"]
|
||||
|
||||
def test_unknown_category_rejected(self, resolve_dirs):
|
||||
with pytest.raises(ValueError, match="unknown model category"):
|
||||
resolve_destination_from_tags(["models", "not_a_real_category"])
|
||||
|
||||
def test_unknown_category_via_slash_joined(self, resolve_dirs):
|
||||
# First segment of a slash-joined tag must still match a registered category.
|
||||
with pytest.raises(ValueError, match="unknown model category 'bogus'"):
|
||||
resolve_destination_from_tags(["models", "bogus/sub/path"])
|
||||
|
||||
def test_traversal_in_subdir_rejected(self, resolve_dirs):
|
||||
with pytest.raises(ValueError, match="invalid path component"):
|
||||
resolve_destination_from_tags(["models", "checkpoints/..", "evil"])
|
||||
|
||||
@ -32,7 +32,7 @@ def test_seed_asset_removed_when_file_is_deleted(
|
||||
# Verify it is visible via API and carries no hash (seed)
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests/syncseed", "name_contains": name},
|
||||
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
|
||||
timeout=120,
|
||||
)
|
||||
body1 = r1.json()
|
||||
@ -52,7 +52,7 @@ def test_seed_asset_removed_when_file_is_deleted(
|
||||
# It should disappear (AssetInfo and seed Asset gone)
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests/syncseed", "name_contains": name},
|
||||
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
|
||||
timeout=120,
|
||||
)
|
||||
body2 = r2.json()
|
||||
@ -332,7 +332,7 @@ def test_fastpass_removes_stale_state_row_no_missing(
|
||||
|
||||
rl = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": f"unit-tests/{scope}"},
|
||||
params={"include_tags": f"unit-tests,{scope}"},
|
||||
timeout=120,
|
||||
)
|
||||
bl = rl.json()
|
||||
|
||||
@ -280,15 +280,9 @@ def test_metadata_filename_is_set_for_seed_asset_without_hash(
|
||||
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# Scanner emits tags as ``[root, "<dir1>/<dir2>/..."]`` — the second tag
|
||||
# is the slash-joined parent subpath. For ``<root>/unit-tests/<scope>/a/b/<name>``
|
||||
# the second tag is ``"unit-tests/<scope>/a/b"``.
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": f"unit-tests/{scope}/a/b",
|
||||
"name_contains": name,
|
||||
},
|
||||
params={"include_tags": f"unit-tests,{scope}", "name_contains": name},
|
||||
timeout=120,
|
||||
)
|
||||
body = r1.json()
|
||||
|
||||
@ -1,69 +0,0 @@
|
||||
"""Unit tests for app.assets.helpers."""
|
||||
|
||||
from app.assets.helpers import expand_bucket_prefixes
|
||||
|
||||
|
||||
class TestExpandBucketPrefixes:
|
||||
def test_flat_category_unchanged(self):
|
||||
# `checkpoints` is already a standalone token, no expansion needed.
|
||||
assert expand_bucket_prefixes(["models", "checkpoints"]) == [
|
||||
"models",
|
||||
"checkpoints",
|
||||
]
|
||||
|
||||
def test_nested_category_inserts_bucket(self):
|
||||
# Path-derived shape for `models/checkpoints/flux/foo.safetensors` —
|
||||
# the standalone bucket has to be present so the FE set-membership
|
||||
# filter (`include_tags=models,checkpoints`) matches the asset.
|
||||
assert expand_bucket_prefixes(["models", "checkpoints/flux"]) == [
|
||||
"models",
|
||||
"checkpoints/flux",
|
||||
"checkpoints",
|
||||
]
|
||||
|
||||
def test_deeply_nested_only_first_segment_expands(self):
|
||||
# Only the FIRST slash segment ever gets emitted as a standalone —
|
||||
# intermediate path segments don't have routing significance.
|
||||
assert expand_bucket_prefixes(
|
||||
["models", "diffusers/kolors/text_encoder"]
|
||||
) == ["models", "diffusers/kolors/text_encoder", "diffusers"]
|
||||
|
||||
def test_unknown_prefix_does_not_expand(self):
|
||||
# Free-form user labels with slashes whose first segment is not a
|
||||
# registered bucket pass through opaquely.
|
||||
assert expand_bucket_prefixes(["models", "my-org/team-a"]) == [
|
||||
"models",
|
||||
"my-org/team-a",
|
||||
]
|
||||
|
||||
def test_idempotent(self):
|
||||
# Re-applying the helper is a no-op once the bucket is in the set.
|
||||
expanded = expand_bucket_prefixes(["models", "checkpoints/flux"])
|
||||
assert expand_bucket_prefixes(expanded) == expanded
|
||||
|
||||
def test_does_not_duplicate_existing_bucket(self):
|
||||
# If the caller already supplied the standalone bucket, don't add a
|
||||
# second copy.
|
||||
assert expand_bucket_prefixes(
|
||||
["models", "checkpoints/flux", "checkpoints"]
|
||||
) == ["models", "checkpoints/flux", "checkpoints"]
|
||||
|
||||
def test_preserves_caller_order(self):
|
||||
# User tags after path tags must stay after; the inserted bucket
|
||||
# token slots in immediately after its slash-joined parent so the
|
||||
# microsecond stagger lands it at path-tier before user-tier.
|
||||
assert expand_bucket_prefixes(
|
||||
["models", "loras/style", "favorite", "v2"]
|
||||
) == ["models", "loras/style", "loras", "favorite", "v2"]
|
||||
|
||||
def test_empty_input(self):
|
||||
assert expand_bucket_prefixes([]) == []
|
||||
|
||||
def test_input_root_with_subpath_no_expansion(self):
|
||||
# `portraits` isn't a registered model category, so the input
|
||||
# subpath stays opaque (FE filter doesn't have a checkpoint-loader
|
||||
# analogue for input subfolders).
|
||||
assert expand_bucket_prefixes(["input", "portraits/2026"]) == [
|
||||
"input",
|
||||
"portraits/2026",
|
||||
]
|
||||
@ -29,10 +29,7 @@ def create_seed_file(comfy_tmp_base_dir: Path):
|
||||
def find_asset(http: requests.Session, api_base: str):
|
||||
"""Query API for assets matching scope and optional name."""
|
||||
def _find(scope: str, name: str | None = None) -> list[dict]:
|
||||
# Scanner now emits tags as ``[root, "<dir1>/<dir2>/..."]`` rather than
|
||||
# one tag per directory. For files at ``<root>/unit-tests/<scope>/...``
|
||||
# the second tag is exactly ``"unit-tests/<scope>"``.
|
||||
params = {"include_tags": f"unit-tests/{scope}"}
|
||||
params = {"include_tags": f"unit-tests,{scope}"}
|
||||
if name:
|
||||
params["name_contains"] = name
|
||||
r = http.get(f"{api_base}/api/assets", params=params, timeout=120)
|
||||
@ -141,7 +138,4 @@ def test_special_chars_in_path_escaped_correctly(
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# Scanner emits the full parent subpath as a single slash-joined tag, so
|
||||
# the lookup tag is ``unit-tests/<scope>`` even when <scope> itself
|
||||
# contains a slash (parent + special-char dirname).
|
||||
assert find_asset(scope, fp.name), "Asset with special chars should survive"
|
||||
assert find_asset(scope.split("/")[0], fp.name), "Asset with special chars should survive"
|
||||
|
||||
@ -1,135 +0,0 @@
|
||||
"""HTTP-layer smoke test: user-added tags via POST /api/assets/{id}/tags
|
||||
land after path tags when read back via GET /api/assets.
|
||||
|
||||
Exercises the full route handler -> service -> query path that the unit
|
||||
tests at tests-unit/assets_test/queries/test_asset_info.py only cover at
|
||||
the service layer.
|
||||
"""
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def smoke_asset(http: requests.Session, api_base: str):
|
||||
"""Upload a single asset into models/checkpoints/unit-tests/smoke
|
||||
and delete it on teardown."""
|
||||
name = "smoke_user_tag.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "smoke"]
|
||||
files = {"file": (name, b"S" * 4096, "application/octet-stream")}
|
||||
form_data = {
|
||||
"tags": json.dumps(tags),
|
||||
"name": name,
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
r = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
|
||||
assert r.status_code == 201, r.text
|
||||
body = r.json()
|
||||
yield body
|
||||
http.delete(
|
||||
f"{api_base}/api/assets/{body['id']}?delete_content=true", timeout=30
|
||||
)
|
||||
|
||||
|
||||
def _fetch_asset_tags(http, api_base, ref_id):
|
||||
r = http.get(f"{api_base}/api/assets/{ref_id}", timeout=30)
|
||||
assert r.status_code == 200, r.text
|
||||
return r.json()["tags"]
|
||||
|
||||
|
||||
def test_user_tag_lands_after_path_tags_via_http(
|
||||
http: requests.Session, api_base: str, smoke_asset: dict
|
||||
):
|
||||
ref_id = smoke_asset["id"]
|
||||
|
||||
initial_tags = _fetch_asset_tags(http, api_base, ref_id)
|
||||
# Path tags should already be at the front in upload order.
|
||||
assert initial_tags[:2] == ["models", "checkpoints"]
|
||||
|
||||
# Add a user tag that would jump to position 0 under alphabetical sort.
|
||||
r = http.post(
|
||||
f"{api_base}/api/assets/{ref_id}/tags",
|
||||
json={"tags": ["aaa-user-tag"]},
|
||||
timeout=30,
|
||||
)
|
||||
assert r.status_code in (200, 201), r.text
|
||||
|
||||
tags_after = _fetch_asset_tags(http, api_base, ref_id)
|
||||
# Path tags must still be at the front; user tag goes to the end.
|
||||
assert tags_after[0] == "models"
|
||||
assert tags_after[1] == "checkpoints"
|
||||
assert "aaa-user-tag" in tags_after
|
||||
assert tags_after[-1] == "aaa-user-tag"
|
||||
|
||||
|
||||
def test_user_tag_batch_lands_after_path_tags_via_http(
|
||||
http: requests.Session, api_base: str, smoke_asset: dict
|
||||
):
|
||||
ref_id = smoke_asset["id"]
|
||||
|
||||
# Add three user tags in a single request, in non-alphabetical input
|
||||
# order. They should all land after the path tags (microsecond stagger
|
||||
# in set_reference_tags / add_tags_to_reference is what makes this
|
||||
# work — without it, "aaa" would jump to position 0).
|
||||
r = http.post(
|
||||
f"{api_base}/api/assets/{ref_id}/tags",
|
||||
json={"tags": ["zzz-z", "favorite", "aaa-experiment"]},
|
||||
timeout=30,
|
||||
)
|
||||
assert r.status_code in (200, 201), r.text
|
||||
|
||||
tags_after = _fetch_asset_tags(http, api_base, ref_id)
|
||||
assert tags_after[0] == "models"
|
||||
assert tags_after[1] == "checkpoints"
|
||||
user_tail = tags_after[len({"models", "checkpoints", "unit-tests", "smoke"}):]
|
||||
assert set(user_tail) >= {"zzz-z", "favorite", "aaa-experiment"}
|
||||
# Critically: alphabetical sort would put 'aaa-experiment' at position 0.
|
||||
assert tags_after.index("aaa-experiment") > tags_after.index("models")
|
||||
assert tags_after.index("aaa-experiment") > tags_after.index("checkpoints")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def nested_checkpoint_asset(http: requests.Session, api_base: str):
|
||||
"""Upload a checkpoint at the slash-joined path shape cloud emits
|
||||
(`models/checkpoints/flux/...`), then delete it on teardown.
|
||||
"""
|
||||
name = "nested_checkpoint.safetensors"
|
||||
tags = ["models", "checkpoints/flux"]
|
||||
files = {"file": (name, b"S" * 4096, "application/octet-stream")}
|
||||
form_data = {
|
||||
"tags": json.dumps(tags),
|
||||
"name": name,
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
r = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
|
||||
assert r.status_code == 201, r.text
|
||||
body = r.json()
|
||||
yield body
|
||||
http.delete(
|
||||
f"{api_base}/api/assets/{body['id']}?delete_content=true", timeout=30
|
||||
)
|
||||
|
||||
|
||||
def test_nested_checkpoint_satisfies_fe_set_filter(
|
||||
http: requests.Session, api_base: str, nested_checkpoint_asset: dict
|
||||
):
|
||||
"""The case Simon flagged: a nested-path checkpoint must still match
|
||||
`include_tags=models,checkpoints` — the FE combo-widget filter.
|
||||
"""
|
||||
ref_id = nested_checkpoint_asset["id"]
|
||||
|
||||
stored = _fetch_asset_tags(http, api_base, ref_id)
|
||||
# tag[1] keeps cloud's slash-joined positional contract; tag[2] holds
|
||||
# the standalone bucket the FE filter looks for.
|
||||
assert stored[:3] == ["models", "checkpoints/flux", "checkpoints"]
|
||||
|
||||
# The actual FE query — exact set-membership across both tokens.
|
||||
r = http.get(
|
||||
f"{api_base}/api/assets",
|
||||
params=[("include_tags", "models"), ("include_tags", "checkpoints")],
|
||||
timeout=30,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
returned_ids = {a["id"] for a in r.json()["assets"]}
|
||||
assert ref_id in returned_ids
|
||||
Reference in New Issue
Block a user