Compare commits

..

15 Commits

Author SHA1 Message Date
be3e51250e openapi: add type enum to Workspace schema (cutover follow-up)
Cloud's Workspace runtime shape includes a 'type' field with enum
[personal, team] that vendor's Workspace was missing. Cloud handlers
reference the generated ingest.WorkspaceType Go enum.

Same kind of surgical addition as JobEntry.status / BillingStatus /
JobDetailResponse.status in this PR — adds cloud-runtime field to
existing vendor schema.
2026-05-22 18:05:19 -07:00
332acf6777 openapi: add enum values + FeedbackRequest schema for cloud cutover (PR E)
Adds missing cloud-runtime enum values to vendor schemas that the
cloud runtime emits but vendor declared as plain strings.

Changes:
  - JobEntry.status: enum [pending, in_progress, completed, failed, cancelled]
  - JobDetailResponse.status: same enum
  - BillingStatus: enum [awaiting_payment_method, pending_payment, paid,
      payment_failed, inactive]
  - FeedbackRequest schema added (with type enum)
  - /api/feedback POST: requestBody now $refs FeedbackRequest

All cloud-runtime-emitted; no impact on OSS-local semantics.

Identified via Comfy-Org/cloud's TestCutoverSafe gate (BE-1106) as
the remaining schema-level divergences after PRs A-D landed and got
synced.
2026-05-22 17:57:22 -07:00
c3c881f37b openapi: rename cloud-side response schemas to match runtime (PR D) (#14065)
* openapi: rename cloud-side response schemas to match runtime (PR D)

Follow-up to the BE-1106 stack (#14060/61/63). Cloud's Go handlers
reference response schemas by name (e.g., ingest.WorkflowResponse,
ingest.SubscribeResponse), but vendor's matching operations were
declaring those responses against differently-named vendor-side
schemas (CloudWorkflow, BillingSubscription, etc.). After the stack
landed, schemas like WorkflowResponse exist in vendor but weren't
referenced by any path, so codegen pruned the unreferenced types.

This PR:
  1. Updates 34 operation $refs in cloud-runtime paths to point to
     the schema names cloud's handlers expect (e.g., CloudWorkflow →
     WorkflowResponse on /api/workflows/{workflow_id}).
  2. Adds 12 cloud-only schemas that weren't in vendor yet but are
     referenced by these renames (e.g., SubscribeResponse,
     CancelSubscriptionResponse, BillingOpStatusResponse). Each
     copied verbatim from Comfy-Org/cloud's hand-written ingest spec
     and tagged x-runtime: [cloud] with a [cloud-only] description
     prefix.

Schema renames span the same domains as the operationId renames in
PR A: billing/subscriptions (7 schemas), workflows (5), userdata (3),
jobs (2), hub (2), history (2), auth/workspace (4), and misc cloud
endpoints (9).

Convergent safety check after this lands (against cloud's
TestCutoverSafe gate, BE-1106):
  Pre-PR D:   205 missing handler refs
  Post-PR D:  105 missing handler refs (-49%)
  Cumulative since the original 938-ref baseline: -89%

The remaining 105 are a Phase 3 follow-up (response headers,
text/plain responses, codegen-derived enum sub-types, and a small
set of inline-response-schema operations that vendor declares
inline where cloud has named-schema $refs).

* openapi: drop PR-label comment from new schemas block

PR-internal labels don't belong in committed code — future readers
won't know what 'PR D' means and the marker stops being useful the
moment this PR merges.
2026-05-22 16:34:52 -07:00
7984a6a38e openapi: rename 55 cloud-side operationIds to match runtime (PR A of 3) (#14060)
* openapi: rename 55 cloud-side operationIds to match runtime handlers

For the 55 operations below, vendor's operationId did not match the
name cloud's runtime handlers expect. Generated types from vendor
therefore had different names (e.g. CreateSubscription200JSONResponse)
than what cloud handlers reference (Subscribe200JSONResponse), which
blocks the post-cutover combined-spec codegen.

All 55 renames target the cloud-runtime-authoritative name. Several
of these endpoints are shared concepts (queue, settings, userdata,
object_info) that OSS local also serves — the rename aligns vendor
with the longstanding cloud handler-side convention to unblock the
shared codegen. No request/response *shape* changes in this PR; only
operationId labels.

Notable categories:
  - Billing/subscriptions: 7 renames (subscribe, getBillingPlans, ...)
  - Workspace + workflows: 13 renames (createWorkflow, ...)
  - Hub: 3 renames
  - Auth/users: 5 renames
  - Shared OSS surface (settings, queue, view, userdata): 12 renames
  - Misc cloud-only: 15 renames

Identified via Comfy-Org/cloud's TestCutoverSafe build-safety gate
(BE-1106), which compares handler type references against codegen
output from the combined spec.

* fix(openapi): resolve getHistory operationId collision

Spectral flagged: both /api/history (OSS local) and /api/history_v2
(cloud) had operationId 'getHistory' after the rename. Rename vendor's
/api/history to 'getPromptHistory' to disambiguate. Cloud's runtime
denies /api/history at the overlay level so combined codegen is
unaffected by this change.

* openapi: add 41 cloud-runtime schemas to components.schemas (PR B of 3) (#14061)

* openapi: add 41 cloud-runtime schemas to components.schemas (cutover prep)

Adds schemas that exist in Comfy-Org/cloud's hand-written ingest spec
but not yet in this vendored OSS spec. All tagged x-runtime: [cloud]
per the field-drift convention and prefixed with [cloud-only] in the
description.

These schemas are referenced by cloud's Go handlers via the generated
ingest.<Schema> Go type names. Codegen from the vendored spec didn't
produce those types because the schemas weren't declared here. Adding
them unblocks the post-cutover combined-spec codegen.

Schemas added (alphabetical):
  AssetDownloadResponse, AssetMetadataResponse, BillingBalanceResponse,
  BillingPlansResponse, BillingStatusResponse, GetUserDataResponseFull,
  HistoryDetailEntry, HistoryDetailResponse, HistoryResponse,
  HubLabelInfo, HubProfileSummary, HubWorkflowListResponse,
  HubWorkflowStatus, HubWorkflowSummary, HubWorkflowTemplateEntry,
  JobStatusResponse, JobsListResponse, LabelRef, LogsResponse, Member,
  OAuthRegisterBadRequestResponse, PendingInvite, Plan, PlanAvailability,
  PlanAvailabilityReason, PlanSeatSummary, PreviewPlanInfo,
  PreviewSubscribeResponse, PublishedWorkflowDetail, SecretResponse,
  SubscriptionDuration, SubscriptionTier, UserDataResponseFull,
  ValidationError, ValidationResult, WorkflowForkedFrom, WorkflowResponse,
  WorkflowVersionContentResponse, WorkspaceAPIKeyInfo, WorkspaceSummary,
  WorkspaceWithRole

Identified via Comfy-Org/cloud's TestCutoverSafe build-safety gate
(BE-1106). Companion to PR #14060 (operationId renames).

* fix(openapi): add BindingErrorResponse schema

OAuthRegisterBadRequestResponse references BindingErrorResponse but
that schema wasn't in the original add. Adding it now as a cloud-only
schema matching the cloud runtime's binding-error shape (single
'message' string field).

* openapi: add missing 4xx/5xx response bodies for cloud-emitting endpoints (#14063)

Vendor declares shared endpoints (e.g. /api/queue, /api/settings,
/api/assets/*, /api/billing/*) with success responses but is missing
many of the 4xx/5xx error response bodies that Comfy-Org/cloud's
runtime actually emits. Cloud's Go handlers reference the generated
ingest.Op<StatusCode>JSONResponse types for these missing statuses,
which currently fail to resolve when codegen runs against the
vendored spec.

This PR adds 237 response entries across 117 operations, restoring
the documented error responses that cloud emits. Bodies are copied
verbatim from Comfy-Org/cloud's hand-written ingest spec
(services/ingest/openapi.yaml) and reference a new ErrorResponse
schema also added in this PR (matches cloud's {code, message} runtime
shape, tagged x-runtime: [cloud]).

ErrorResponse is intentionally separate from the existing CloudError
schema. CloudError's shape ({error}) describes one runtime; cloud
emits a different shape ({code, message}). Existing CloudError refs
in vendor are untouched; new cloud-emitting error references use
ErrorResponse.

Identified via Comfy-Org/cloud's TestCutoverSafe build-safety gate
(BE-1106). Companion to PR #14060 (operationId renames) and PR #14061
(cloud-only schema additions).
2026-05-22 16:15:18 -07:00
e75b739c1d Delete the source branch after doing the backport. (#14062) 2026-05-22 15:47:03 -07:00
112fcd5f3b openapi: align response declarations with implementation (5 endpoints) (#14058)
* openapi: align response declarations with implementation (5 endpoints)

- POST /api/assets/download: replace 200 with 202 + tracking-task body
  (endpoint runs asynchronously and returns task_id/status/message).
- POST /api/assets/export: same 200 → 202 + tracking-task body.
- POST /api/assets/from-workflow: change 201 → 200 (handler responds 200,
  not 201; no Location header emitted).
- POST /api/feedback: change 200 → 201 (creates a feedback record).
- /api/jobs and /api/jobs/{job_id}: change timestamp fields from
  type: number to type: integer + format: int64. Values are Unix
  milliseconds — number causes oapi-codegen to emit float64, losing
  precision and producing the wrong Go type. Affected fields:
  create_time, update_time, execution_start_time, execution_end_time.

Verification: each change reflects what the endpoint observably returns;
no handler changes required. Backwards-compatible for existing clients
(integer is a subset of number; status code shifts within 2xx).

* openapi: align asset download/export 202 status enum with runtime + sibling schemas

CodeRabbit caught a vocabulary mismatch: the two new 202 response schemas
declared `[pending, running, completed, failed]` while the rest of the same
spec uses `[created, running, completed, failed]` for the identical task
lifecycle (download/export progress WebSocket events, /api/tasks, TaskEntry,
TaskResponse — 4 sites total). Cloud's runtime emits `created` on initial
creation (AssetDownloadResponseStatusCreated; task.Status sourced from the
DB enum whose initial value is Created). `pending` would have introduced a
fifth, contradictory vocabulary for the same lifecycle and pushed the spec
further from the implementation it is meant to align with.

Followup tracked separately: extract a shared TaskStatus enum so all five
sites move in lockstep instead of needing per-site edits.
2026-05-22 14:31:43 -07:00
1579bbb52d [Partner Nodes] add new Rodin2.5 nodes (#14051)
* [Partner Nodes] add new Rodin2.5 nodes

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* [Partner Nodes] fixed Quality Mesh Options

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* [Partner Nodes] fix: remove non-supported "usdz"

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* [Partner Nodes] fix: always pass seed to server

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* [Partner Nodes] fix: set the default "material" value to "Shaded"

Signed-off-by: bigcat88 <bigcat88@icloud.com>

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-22 09:07:21 -07:00
93888ae8e3 Move logic nodes into utils category (#14033) 2026-05-22 13:32:08 +08:00
38ebc19037 Adding in And, Or, and Not nodes. (#14004) 2026-05-22 11:01:12 +08:00
9650570378 Update Discord invite link in README.md (#14045) 2026-05-21 19:52:38 -07:00
f48c32871b fe: Consolidate warnings (#13970) 2026-05-22 10:18:13 +08:00
8edff549e3 Update backport workflow to use commit SHA input (#14043) 2026-05-21 18:22:47 -07:00
8fecef0686 Add validation for source branch in backport workflow (#14042) 2026-05-21 16:39:19 -07:00
5d681a5420 Fix SIGPIPE false negative in backport release validation (#14041) 2026-05-21 16:29:08 -07:00
32e58393b8 Add backport release workflow. (#14038) 2026-05-21 14:49:55 -07:00
25 changed files with 4413 additions and 3123 deletions

519
.github/workflows/backport_release.yaml vendored Normal file
View File

@ -0,0 +1,519 @@
name: Backport Release
on:
workflow_dispatch:
inputs:
commit:
description: 'Full 40-char SHA of the tip commit of the backport source branch (the PR head commit that passed tests). The branch is resolved from this SHA and must be unique.'
required: true
type: string
permissions:
contents: read
pull-requests: read
checks: read
jobs:
backport-release:
name: Create backport release
runs-on: ubuntu-latest
environment: backport release
steps:
- name: Generate GitHub App token
id: app-token
uses: actions/create-github-app-token@bcd2ba49218906704ab6c1aa796996da409d3eb1
with:
app-id: ${{ secrets.FEN_RELEASE_APP_ID }}
private-key: ${{ secrets.FEN_RELEASE_PRIVATE_KEY }}
- name: Checkout repository
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
with:
token: ${{ steps.app-token.outputs.token }}
fetch-depth: 0
fetch-tags: true
- name: Configure git
run: |
git config user.name "fen-release[bot]"
git config user.email "fen-release[bot]@users.noreply.github.com"
- name: Resolve source branch from commit SHA
id: resolve
env:
SOURCE_COMMIT: ${{ inputs.commit }}
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
run: |
set -euo pipefail
# Require a full 40-char lowercase-hex SHA. Short SHAs are ambiguous
# and we will be comparing this value against API responses (PR head
# SHA, ref tips) that always return the full form.
if [[ ! "${SOURCE_COMMIT}" =~ ^[0-9a-f]{40}$ ]]; then
echo "::error::Input commit '${SOURCE_COMMIT}' is not a full 40-char lowercase hex SHA."
exit 1
fi
# Fetch all remote branches so we can search for which one(s) point
# at this SHA. `actions/checkout` with fetch-depth: 0 fetches full
# history of the checked-out ref but does not necessarily populate
# every refs/remotes/origin/*, so do it explicitly.
git fetch --prune origin '+refs/heads/*:refs/remotes/origin/*'
# Verify the commit actually exists in this repo's object DB.
if ! git cat-file -e "${SOURCE_COMMIT}^{commit}" 2>/dev/null; then
echo "::error::Commit ${SOURCE_COMMIT} was not found in the repository."
exit 1
fi
# Find every remote branch whose tip == SOURCE_COMMIT. Exactly one
# branch must point at it. If zero, the commit isn't anyone's tip
# (likely stale, force-pushed past, or never the PR head). If more
# than one, the (branch -> SHA) mapping is ambiguous and we refuse
# to guess — the operator must give us a unique branch to release.
mapfile -t matching_branches < <(
git for-each-ref \
--format='%(refname:strip=3)' \
--points-at="${SOURCE_COMMIT}" \
refs/remotes/origin/ \
| grep -vx 'HEAD' || true
)
if [[ "${#matching_branches[@]}" -eq 0 ]]; then
echo "::error::No branch on origin has ${SOURCE_COMMIT} as its tip."
echo "::error::Either the branch was updated after you copied this SHA, or this commit was never the head of a branch."
exit 1
fi
if [[ "${#matching_branches[@]}" -gt 1 ]]; then
echo "::error::More than one branch on origin has ${SOURCE_COMMIT} as its tip; cannot pick one:"
for b in "${matching_branches[@]}"; do
echo "::error:: - ${b}"
done
echo "::error::Refusing to proceed with an ambiguous source branch."
exit 1
fi
source_branch="${matching_branches[0]}"
if [[ "${source_branch}" == "${DEFAULT_BRANCH}" ]]; then
echo "::error::Source branch must not be the default branch ('${DEFAULT_BRANCH}')."
exit 1
fi
echo "Resolved commit ${SOURCE_COMMIT} to branch '${source_branch}'."
echo "source_branch=${source_branch}" >> "$GITHUB_OUTPUT"
- name: Determine latest stable release
id: latest
env:
GH_TOKEN: ${{ steps.app-token.outputs.token }}
run: |
set -euo pipefail
# List all tags matching vMAJOR.MINOR.PATCH and pick the highest by numeric
# comparison of each component. We DO NOT use `sort -V` because it treats
# v0.19.99 as higher than v0.20.1.
latest_tag="$(
git tag --list 'v[0-9]*.[0-9]*.[0-9]*' \
| grep -E '^v[0-9]+\.[0-9]+\.[0-9]+$' \
| awk -F'[v.]' '{ printf "%010d %010d %010d %s\n", $2, $3, $4, $0 }' \
| sort -k1,1n -k2,2n -k3,3n \
| tail -n1 \
| awk '{print $4}'
)"
if [[ -z "${latest_tag}" ]]; then
echo "::error::No stable release tags (vMAJOR.MINOR.PATCH) were found."
exit 1
fi
# Parse components
ver="${latest_tag#v}"
major="${ver%%.*}"
rest="${ver#*.}"
minor="${rest%%.*}"
patch="${rest#*.}"
new_patch=$((patch + 1))
new_version="v${major}.${minor}.${new_patch}"
release_branch="release/v${major}.${minor}"
latest_sha="$(git rev-list -n 1 "refs/tags/${latest_tag}")"
echo "latest_tag=${latest_tag}" >> "$GITHUB_OUTPUT"
echo "latest_sha=${latest_sha}" >> "$GITHUB_OUTPUT"
echo "major=${major}" >> "$GITHUB_OUTPUT"
echo "minor=${minor}" >> "$GITHUB_OUTPUT"
echo "patch=${patch}" >> "$GITHUB_OUTPUT"
echo "new_version=${new_version}" >> "$GITHUB_OUTPUT"
echo "new_version_no_v=${major}.${minor}.${new_patch}" >> "$GITHUB_OUTPUT"
echo "release_branch=${release_branch}" >> "$GITHUB_OUTPUT"
echo "Latest stable release: ${latest_tag} (${latest_sha})"
echo "New version will be: ${new_version}"
echo "Release branch: ${release_branch}"
- name: Validate source branch is cut directly from the latest stable release
env:
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
LATEST_TAG_SHA: ${{ steps.latest.outputs.latest_sha }}
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
run: |
set -euo pipefail
# Use the user-provided SHA directly rather than re-resolving the branch
# tip — the resolve step already proved the branch tip equals SOURCE_COMMIT,
# and pinning to the SHA here makes the rest of the job TOCTOU-safe against
# someone pushing to the branch mid-run.
source_sha="${SOURCE_COMMIT}"
# Walking first-parent from the source tip must reach LATEST_TAG_SHA.
# We capture rev-list into a variable and grep against a here-string
# rather than piping `rev-list | grep -q`: under `set -o pipefail`,
# `grep -q` would exit on first match and SIGPIPE the still-streaming
# `rev-list`, propagating exit 141 as a spurious "not found".
first_parent_chain="$(git rev-list --first-parent "${source_sha}")"
if ! grep -Fxq "${LATEST_TAG_SHA}" <<< "${first_parent_chain}"; then
echo "::error::Source branch '${SOURCE_BRANCH}' is not cut from '${LATEST_TAG}'."
echo "::error::Its first-parent history does not include ${LATEST_TAG_SHA}."
exit 1
fi
# Additionally, every commit added on top of the tag (the set we are
# about to publish) must itself be a descendant of the tag along
# first-parent — i.e. no sibling commits from master sneak in via a
# non-first-parent path. Enforce by requiring that the symmetric
# difference is empty in one direction: commits in source that are
# NOT first-parent-reachable from source starting at the tag.
# We do this by intersecting:
# A = commits reachable from source but not from tag (full DAG)
# B = commits on the first-parent chain from source down to tag
# and requiring A == B.
all_added="$(git rev-list "${LATEST_TAG_SHA}..${source_sha}" | sort)"
first_parent_added="$(
git rev-list --first-parent "${LATEST_TAG_SHA}..${source_sha}" | sort
)"
if [[ "${all_added}" != "${first_parent_added}" ]]; then
echo "::error::Source branch '${SOURCE_BRANCH}' contains commits not on its first-parent chain from '${LATEST_TAG}'."
echo "::error::This usually means the branch was cut from master (not from the tag) or contains a merge from master."
echo "Commits reachable but not on first-parent chain:"
comm -23 <(printf '%s\n' "${all_added}") <(printf '%s\n' "${first_parent_added}") \
| while read -r sha; do
echo " $(git log -1 --format='%h %s' "${sha}")"
done
exit 1
fi
added_count="$(printf '%s\n' "${all_added}" | grep -c . || true)"
echo "Source branch is cut directly from ${LATEST_TAG} with ${added_count} commit(s) on top."
- name: Validate PR exists, is open, named correctly, has latest commit, and checks pass
env:
GH_TOKEN: ${{ steps.app-token.outputs.token }}
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
REPO: ${{ github.repository }}
run: |
set -euo pipefail
expected_title="ComfyUI backport release ${NEW_VERSION}"
# Find open PRs from this branch into master. The --state open filter
# is load-bearing: a closed/merged PR with passing checks must not be
# accepted as authorization for a new release.
pr_json="$(
gh pr list \
--repo "${REPO}" \
--state open \
--head "${SOURCE_BRANCH}" \
--base master \
--json number,title,headRefOid,state \
--limit 10
)"
pr_count="$(echo "${pr_json}" | jq 'length')"
if [[ "${pr_count}" -eq 0 ]]; then
echo "::error::No open PR found from '${SOURCE_BRANCH}' into 'master'. The PR must exist and be open."
exit 1
fi
# Pick the PR matching the expected title
pr_number="$(echo "${pr_json}" | jq -r --arg t "${expected_title}" '
map(select(.title == $t)) | .[0].number // empty
')"
pr_head_sha="$(echo "${pr_json}" | jq -r --arg t "${expected_title}" '
map(select(.title == $t)) | .[0].headRefOid // empty
')"
if [[ -z "${pr_number}" ]]; then
echo "::error::No open PR from '${SOURCE_BRANCH}' into 'master' is titled '${expected_title}'."
echo "Found PRs:"
echo "${pr_json}" | jq -r '.[] | " #\(.number): \(.title)"'
exit 1
fi
# The PR's current head commit must equal the SHA the operator gave us.
# This is what closes the door on releasing stale code: if anyone has
# pushed to the branch since the operator validated tests passed, the
# PR head will have advanced past SOURCE_COMMIT and we abort. (The
# resolve step already proved the branch tip == SOURCE_COMMIT; this
# ties that same SHA to the PR that authorizes the release.)
if [[ "${pr_head_sha}" != "${SOURCE_COMMIT}" ]]; then
echo "::error::PR #${pr_number} head commit is ${pr_head_sha}, but the operator-provided commit is ${SOURCE_COMMIT}."
echo "::error::The PR has new commits since this release was authorized. Re-run with the new head SHA after verifying its checks."
exit 1
fi
echo "Found open PR #${pr_number} titled '${expected_title}' at head ${pr_head_sha} (matches operator-provided commit)."
# Verify all check runs on the head commit have completed successfully.
# A check is considered passing if conclusion is success, neutral, or skipped.
checks_json="$(
gh api \
--paginate \
"repos/${REPO}/commits/${pr_head_sha}/check-runs" \
--jq '.check_runs[] | {name: .name, status: .status, conclusion: .conclusion}'
)"
if [[ -z "${checks_json}" ]]; then
echo "::error::No check runs found on PR head commit ${pr_head_sha}."
exit 1
fi
echo "Check runs on ${pr_head_sha}:"
echo "${checks_json}" | jq -s '.'
failing="$(echo "${checks_json}" | jq -s '
map(select(
.status != "completed"
or (.conclusion as $c
| ["success","neutral","skipped"]
| index($c) | not)
))
')"
failing_count="$(echo "${failing}" | jq 'length')"
if [[ "${failing_count}" -gt 0 ]]; then
echo "::error::One or more checks have not passed on PR head commit ${pr_head_sha}:"
echo "${failing}" | jq -r '.[] | " - \(.name): status=\(.status) conclusion=\(.conclusion)"'
exit 1
fi
echo "All checks have passed on ${pr_head_sha}."
- name: Prepare release branch
id: prepare
env:
GH_TOKEN: ${{ steps.app-token.outputs.token }}
REPO: ${{ github.repository }}
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
LATEST_TAG_SHA: ${{ steps.latest.outputs.latest_sha }}
PATCH: ${{ steps.latest.outputs.patch }}
run: |
set -euo pipefail
# Try to fetch the release branch. If patch == 0, it shouldn't exist yet
# and we'll create it from the latest stable tag. If patch > 0, it must
# already exist and its tip must equal the latest stable tag commit (i.e.
# the previous patch release).
if git ls-remote --exit-code --heads origin "${RELEASE_BRANCH}" >/dev/null 2>&1; then
echo "Release branch '${RELEASE_BRANCH}' already exists on origin."
git fetch origin "refs/heads/${RELEASE_BRANCH}:refs/remotes/origin/${RELEASE_BRANCH}"
git checkout -B "${RELEASE_BRANCH}" "refs/remotes/origin/${RELEASE_BRANCH}"
current_tip="$(git rev-parse HEAD)"
if [[ "${current_tip}" != "${LATEST_TAG_SHA}" ]]; then
echo "::error::Release branch '${RELEASE_BRANCH}' tip (${current_tip}) is not at the latest stable release '${LATEST_TAG}' (${LATEST_TAG_SHA})."
echo "::error::Refusing to release on top of a divergent branch."
exit 1
fi
echo "branch_existed=true" >> "$GITHUB_OUTPUT"
else
if [[ "${PATCH}" != "0" ]]; then
echo "::error::Release branch '${RELEASE_BRANCH}' does not exist on origin, but the latest stable release '${LATEST_TAG}' has patch=${PATCH} (>0). This is inconsistent."
exit 1
fi
echo "Release branch '${RELEASE_BRANCH}' does not exist. Creating from ${LATEST_TAG}."
git checkout -B "${RELEASE_BRANCH}" "refs/tags/${LATEST_TAG}"
echo "branch_existed=false" >> "$GITHUB_OUTPUT"
fi
- name: Fast-forward merge source branch into release branch
env:
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
run: |
set -euo pipefail
# --ff-only guarantees no merge commit is created. If a fast-forward is
# not possible (i.e. the release branch has commits the source branch
# doesn't), the merge will fail and we abort. Because we already validated
# that the source branch is rooted on the latest stable tag, and the
# release branch tip equals that same tag, this fast-forward should
# always succeed for a well-formed backport branch.
#
# We merge the operator-provided SHA, not the branch ref, so a push to
# the branch in the window between resolve and now cannot smuggle new
# commits into the release.
if ! git merge --ff-only "${SOURCE_COMMIT}"; then
echo "::error::Cannot fast-forward '${RELEASE_BRANCH}' to ${SOURCE_COMMIT} (tip of '${SOURCE_BRANCH}'). A merge commit would be required. Aborting."
exit 1
fi
echo "Fast-forwarded '${RELEASE_BRANCH}' to ${SOURCE_COMMIT} (tip of '${SOURCE_BRANCH}')."
- name: Bump version files
env:
NEW_VERSION_NO_V: ${{ steps.latest.outputs.new_version_no_v }}
run: |
set -euo pipefail
if [[ ! -f comfyui_version.py ]]; then
echo "::error::comfyui_version.py not found in repo root."
exit 1
fi
if [[ ! -f pyproject.toml ]]; then
echo "::error::pyproject.toml not found in repo root."
exit 1
fi
# Replace the version string in comfyui_version.py.
# Expected format: __version__ = "X.Y.Z"
python3 - "$NEW_VERSION_NO_V" <<'PY'
import re, sys, pathlib
new = sys.argv[1]
p = pathlib.Path("comfyui_version.py")
src = p.read_text()
new_src, n = re.subn(
r'(__version__\s*=\s*[\'"])[^\'"]+([\'"])',
lambda m: f'{m.group(1)}{new}{m.group(2)}',
src,
count=1,
)
if n != 1:
sys.exit("Could not find __version__ assignment in comfyui_version.py")
p.write_text(new_src)
p = pathlib.Path("pyproject.toml")
src = p.read_text()
# Replace the first `version = "..."` inside [project] or [tool.poetry].
new_src, n = re.subn(
r'(?m)^(version\s*=\s*")[^"]+(")',
lambda m: f'{m.group(1)}{new}{m.group(2)}',
src,
count=1,
)
if n != 1:
sys.exit("Could not find version assignment in pyproject.toml")
p.write_text(new_src)
PY
echo "Updated version to ${NEW_VERSION_NO_V} in comfyui_version.py and pyproject.toml."
git --no-pager diff -- comfyui_version.py pyproject.toml
- name: Commit version bump and tag release
env:
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
run: |
set -euo pipefail
git add comfyui_version.py pyproject.toml
git commit -m "ComfyUI ${NEW_VERSION}"
if git rev-parse -q --verify "refs/tags/${NEW_VERSION}" >/dev/null; then
echo "::error::Tag ${NEW_VERSION} already exists locally."
exit 1
fi
git tag "${NEW_VERSION}"
- name: Verify tag does not already exist on origin
env:
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
run: |
set -euo pipefail
if git ls-remote --exit-code --tags origin "refs/tags/${NEW_VERSION}" >/dev/null 2>&1; then
echo "::error::Tag ${NEW_VERSION} already exists on origin. Aborting."
exit 1
fi
- name: Push release branch and tag
env:
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
run: |
set -euo pipefail
# Push the branch first, then the tag. Atomic-ish: if the branch push
# fails we never publish the tag.
git push origin "refs/heads/${RELEASE_BRANCH}:refs/heads/${RELEASE_BRANCH}"
git push origin "refs/tags/${NEW_VERSION}"
echo "Released ${NEW_VERSION} on ${RELEASE_BRANCH}."
- name: Delete remote source branch
env:
GH_TOKEN: ${{ steps.app-token.outputs.token }}
REPO: ${{ github.repository }}
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
run: |
set -euo pipefail
# Belt-and-braces: the resolve step already refuses the default branch,
# but never delete the default or the release branch under any
# circumstances.
if [[ "${SOURCE_BRANCH}" == "${DEFAULT_BRANCH}" || "${SOURCE_BRANCH}" == "${RELEASE_BRANCH}" ]]; then
echo "::error::Refusing to delete '${SOURCE_BRANCH}' (matches default or release branch)."
exit 1
fi
# Delete the source branch on origin, but only if its tip is still the
# SHA we released from. If someone pushed new commits to it after we
# resolved it, leave it alone — those commits would be silently lost.
current_tip="$(git ls-remote origin "refs/heads/${SOURCE_BRANCH}" | awk '{print $1}')"
if [[ -z "${current_tip}" ]]; then
echo "Source branch '${SOURCE_BRANCH}' no longer exists on origin; nothing to delete."
exit 0
fi
if [[ "${current_tip}" != "${SOURCE_COMMIT}" ]]; then
echo "::warning::Source branch '${SOURCE_BRANCH}' tip (${current_tip}) no longer matches released commit (${SOURCE_COMMIT}). Leaving it in place."
exit 0
fi
git push origin --delete "refs/heads/${SOURCE_BRANCH}"
echo "Deleted remote branch '${SOURCE_BRANCH}'."
- name: Summary
if: always()
env:
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
run: |
# SOURCE_BRANCH is empty if the resolve step never produced an output
# (e.g. the workflow failed in or before that step). Show a placeholder
# in that case so the summary table still renders cleanly.
source_branch_display="${SOURCE_BRANCH:-(unresolved)}"
{
echo "## Backport release"
echo ""
echo "| Field | Value |"
echo "|---|---|"
echo "| Source commit | \`${SOURCE_COMMIT}\` |"
echo "| Source branch | \`${source_branch_display}\` |"
echo "| Previous stable | \`${LATEST_TAG}\` |"
echo "| New version | \`${NEW_VERSION}\` |"
echo "| Release branch | \`${RELEASE_BRANCH}\` |"
} >> "$GITHUB_STEP_SUMMARY"

View File

@ -20,7 +20,7 @@
[website-url]: https://www.comfy.org/
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
[discord-url]: https://www.comfy.org/discord
[discord-url]: https://discord.com/invite/comfyorg
[twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI
[twitter-url]: https://x.com/ComfyUI

View File

@ -62,6 +62,8 @@ def get_comfy_package_versions():
def check_comfy_packages_versions():
"""Warn for every comfy* package whose installed version is below requirements.txt."""
from packaging.version import InvalidVersion, parse as parse_pep440
outdated_packages = []
for pkg in get_comfy_package_versions():
installed_str = pkg["installed"]
required_str = pkg["required"]
@ -73,19 +75,26 @@ def check_comfy_packages_versions():
logging.error(f"Failed to check {pkg['name']} version: {e}")
continue
if outdated:
app.logger.log_startup_warning(
f"""
outdated_packages.append((pkg["name"], installed_str, required_str))
else:
logging.info("{} version: {}".format(pkg["name"], installed_str))
if outdated_packages:
package_warnings = "\n".join(
f"Installed {name} version {installed} is lower than the recommended version {required}."
for name, installed, required in outdated_packages
)
app.logger.log_startup_warning(
f"""
________________________________________________________________________
WARNING WARNING WARNING WARNING WARNING
Installed {pkg["name"]} version {installed_str} is lower than the recommended version {required_str}.
{package_warnings}
{get_missing_requirements_message()}
________________________________________________________________________
""".strip()
)
else:
logging.info("{} version: {}".format(pkg["name"], installed_str))
)
REQUEST_TIMEOUT = 10 # seconds

View File

@ -1,13 +1,7 @@
import torch
import torch.nn.functional as F
from comfy.text_encoders.bert import BertAttention
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy.ldm.depth_anything_3.reference_view_selector import (
select_reference_view, reorder_by_reference, restore_original_order,
THRESH_FOR_REF_SELECTION,
)
class Dino2AttentionOutput(torch.nn.Module):
@ -20,42 +14,13 @@ class Dino2AttentionOutput(torch.nn.Module):
class Dino2AttentionBlock(torch.nn.Module):
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations,
qk_norm=False):
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.heads = heads
self.head_dim = embed_dim // heads
self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
if qk_norm:
self.q_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device)
self.k_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device)
else:
self.q_norm = None
self.k_norm = None
def forward(self, x, mask, optimized_attention, pos=None, rope=None):
# Fast path used by the existing CLIP-vision DINOv2 (no DA3 extensions).
if self.q_norm is None and rope is None:
return self.output(self.attention(x, mask, optimized_attention))
# DA3 path: do QKV manually so we can apply per-head QK-norm and 2D RoPE.
attn = self.attention
B, N, C = x.shape
h = self.heads
d = self.head_dim
q = attn.query(x).view(B, N, h, d).transpose(1, 2)
k = attn.key(x).view(B, N, h, d).transpose(1, 2)
v = attn.value(x).view(B, N, h, d).transpose(1, 2)
if self.q_norm is not None:
q = self.q_norm(q)
k = self.k_norm(k)
if rope is not None and pos is not None:
q = rope(q, pos)
k = rope(k, pos)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out = out.transpose(1, 2).reshape(B, N, C)
return self.output(out)
def forward(self, x, mask, optimized_attention):
return self.output(self.attention(x, mask, optimized_attention))
class LayerScale(torch.nn.Module):
@ -99,11 +64,9 @@ class SwiGLUFFN(torch.nn.Module):
class Dino2Block(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn,
qk_norm=False):
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
super().__init__()
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations,
qk_norm=qk_norm)
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
if use_swiglu_ffn:
@ -113,90 +76,19 @@ class Dino2Block(torch.nn.Module):
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, x, optimized_attention, pos=None, rope=None, attn_mask=None):
x = x + self.layer_scale1(self.attention(self.norm1(x), attn_mask, optimized_attention,
pos=pos, rope=rope))
def forward(self, x, optimized_attention):
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
x = x + self.layer_scale2(self.mlp(self.norm2(x)))
return x
# -----------------------------------------------------------------------------
# 2D Rotary position embedding (DA3 extension)
# -----------------------------------------------------------------------------
class _PositionGetter:
"""Cache (h, w) -> flat (y, x) position grid used to feed ``rope``."""
def __init__(self):
self._cache: dict = {}
def __call__(self, batch_size: int, height: int, width: int, device) -> torch.Tensor:
key = (height, width, device)
if key not in self._cache:
y = torch.arange(height, device=device)
x = torch.arange(width, device=device)
self._cache[key] = torch.cartesian_prod(y, x)
cached = self._cache[key]
return cached.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
class RotaryPositionEmbedding2D(torch.nn.Module):
"""2D RoPE used by DA3-Small/Base. No learnable parameters."""
def __init__(self, frequency: float = 100.0):
super().__init__()
self.base_frequency = frequency
self._freq_cache: dict = {}
def _components(self, dim: int, seq_len: int, device, dtype):
key = (dim, seq_len, device, dtype)
if key not in self._freq_cache:
exp = torch.arange(0, dim, 2, device=device).float() / dim
inv_freq = 1.0 / (self.base_frequency ** exp)
pos = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
ang = torch.einsum("i,j->ij", pos, inv_freq)
ang = ang.to(dtype)
ang = torch.cat((ang, ang), dim=-1)
self._freq_cache[key] = (ang.cos().to(dtype), ang.sin().to(dtype))
return self._freq_cache[key]
@staticmethod
def _rotate(x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1]
x1, x2 = x[..., : d // 2], x[..., d // 2:]
return torch.cat((-x2, x1), dim=-1)
def _apply_1d(self, tokens, positions, cos_c, sin_c):
cos = F.embedding(positions, cos_c)[:, None, :, :]
sin = F.embedding(positions, sin_c)[:, None, :, :]
return (tokens * cos) + (self._rotate(tokens) * sin)
def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
feature_dim = tokens.size(-1) // 2
max_pos = int(positions.max()) + 1
cos_c, sin_c = self._components(feature_dim, max_pos, tokens.device, tokens.dtype)
v, h = tokens.chunk(2, dim=-1)
v = self._apply_1d(v, positions[..., 0], cos_c, sin_c)
h = self._apply_1d(h, positions[..., 1], cos_c, sin_c)
return torch.cat((v, h), dim=-1)
class Dino2Encoder(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn,
qknorm_start: int = -1):
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
super().__init__()
self.layer = torch.nn.ModuleList([
Dino2Block(
dim, num_heads, layer_norm_eps, dtype, device, operations,
use_swiglu_ffn=use_swiglu_ffn,
qk_norm=(qknorm_start != -1 and i >= qknorm_start),
)
for i in range(num_layers)
])
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
for _ in range(num_layers)])
def forward(self, x, intermediate_output=None):
# Backward-compat path used by ``ClipVisionModel`` (no DA3 extensions).
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
if intermediate_output is not None:
@ -230,27 +122,16 @@ class Dino2PatchEmbeddings(torch.nn.Module):
class Dino2Embeddings(torch.nn.Module):
def __init__(self, dim, dtype, device, operations,
patch_size: int = 14, image_size: int = 518,
use_mask_token: bool = True,
num_camera_tokens: int = 0):
def __init__(self, dim, dtype, device, operations):
super().__init__()
patch_size = 14
image_size = 518
self.patch_size = patch_size
self.image_size = image_size
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device)) # mask_token is a pre-training param, kept only so strict loading accepts the key.
if use_mask_token:
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
else:
self.mask_token = None
if num_camera_tokens > 0:
# DA3 stores (ref_token, src_token) pairs that get injected at the
# alt-attn boundary; see ``Dinov2Model._inject_camera_token``.
self.camera_token = torch.nn.Parameter(torch.empty(1, num_camera_tokens, dim, dtype=dtype, device=device))
else:
self.camera_token = None
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
def interpolate_pos_encoding(self, x, h_pixels, w_pixels):
pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, torch.float32)
@ -259,22 +140,12 @@ class Dino2Embeddings(torch.nn.Module):
patch_pos = pos_embed[:, 1:]
N = patch_pos.shape[1]
M = int(N ** 0.5)
assert N == M * M, f"DINOv2 position grid must be square, got N={N} patches (sqrt={M})"
h0 = h_pixels // self.patch_size
w0 = w_pixels // self.patch_size
# +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0).
# scale_factor is (height_scale, width_scale) -- height MUST come first;
# swapping these only happens to work for square inputs and breaks
# non-square paths like DA3-Small / DA3-Base multi-view.
scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M)
scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M) # +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0).
patch_pos = patch_pos.reshape(1, M, M, -1).permute(0, 3, 1, 2)
patch_pos = torch.nn.functional.interpolate(patch_pos, scale_factor=scale_factor, mode="bicubic", antialias=False)
assert (h0, w0) == patch_pos.shape[-2:], (
f"Interpolated pos-embed grid {tuple(patch_pos.shape[-2:])} does not match "
f"target patch grid ({h0}, {w0}) for input {h_pixels}x{w_pixels} (patch_size={self.patch_size}); "
f"check scale_factor axis order and +0.1 rounding workaround"
)
patch_pos = patch_pos.permute(0, 2, 3, 1).flatten(1, 2)
return torch.cat((class_pos, patch_pos), dim=1).to(x.dtype)
@ -290,21 +161,6 @@ class Dino2Embeddings(torch.nn.Module):
class Dinov2Model(torch.nn.Module):
"""DINOv2 vision backbone.
Supports two operating modes:
* **CLIP-vision DINOv2** (default): vanilla DINOv2-ViT used for
``ClipVisionModel`` and SigLIP-style image encoding.
* **Depth Anything 3** extensions (opt-in via config keys): 2D RoPE,
QK-norm, alternating local/global attention, camera-token injection,
``cat_token`` output and multi-layer feature extraction. These are
enabled when the corresponding fields (``alt_start``, ``qknorm_start``,
``rope_start``, ``cat_token``) are set in ``config_dict``. When all of
them are at their disabled defaults this module behaves identically to
the historical ``Dinov2Model``.
"""
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
num_layers = config_dict["num_hidden_layers"]
@ -312,51 +168,12 @@ class Dinov2Model(torch.nn.Module):
heads = config_dict["num_attention_heads"]
layer_norm_eps = config_dict["layer_norm_eps"]
use_swiglu_ffn = config_dict["use_swiglu_ffn"]
patch_size = config_dict.get("patch_size", 14)
image_size = config_dict.get("image_size", 518)
use_mask_token = config_dict.get("use_mask_token", True)
# DA3 extensions (all default to disabled).
self.alt_start = config_dict.get("alt_start", -1)
self.qknorm_start = config_dict.get("qknorm_start", -1)
self.rope_start = config_dict.get("rope_start", -1)
self.cat_token = config_dict.get("cat_token", False)
rope_freq = config_dict.get("rope_freq", 100.0)
self.embed_dim = dim
self.patch_size = patch_size
self.num_register_tokens = 0
self.patch_start_idx = 1
if self.rope_start != -1 and rope_freq > 0:
self.rope = RotaryPositionEmbedding2D(frequency=rope_freq)
self._position_getter = _PositionGetter()
else:
self.rope = None
self._position_getter = None
# camera_token shape: (1, 2, dim) -> (ref_token, src_token).
num_cam_tokens = 2 if self.alt_start != -1 else 0
self.embeddings = Dino2Embeddings(
dim, dtype, device, operations,
patch_size=patch_size, image_size=image_size,
use_mask_token=use_mask_token, num_camera_tokens=num_cam_tokens,
)
self.encoder = Dino2Encoder(
dim, heads, layer_norm_eps, num_layers, dtype, device, operations,
use_swiglu_ffn=use_swiglu_ffn,
qknorm_start=self.qknorm_start,
)
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
if self.alt_start != -1:
raise RuntimeError(
"Dinov2Model.forward() is the backward-compatible CLIP-vision path and does not "
"apply DA3 extensions (RoPE, alternating attention, camera-token injection). "
"Use get_intermediate_layers_da3() for Depth Anything 3 models."
)
x = self.embeddings(pixel_values)
x, i = self.encoder(x, intermediate_output=intermediate_output)
x = self.layernorm(x)
@ -364,21 +181,6 @@ class Dinov2Model(torch.nn.Module):
return x, i, pooled_output, None
def get_intermediate_layers(self, pixel_values, indices, apply_norm=True):
"""Single-view multi-layer feature extraction (MoGe / vanilla DINOv2).
For the multi-view Depth Anything 3 path (RoPE, alt-attention,
camera-token injection, ref-view selection, cat_token), use
:meth:`get_intermediate_layers_da3` instead.
Args:
pixel_values: ``(B, 3, H, W)`` single-view input.
indices: layer indices to extract; supports negative indexing.
apply_norm: if True, apply the final layernorm to each output.
Returns:
list of ``(patch_tokens, cls_token)`` tuples with shapes
``(B, N_patch, C)`` and ``(B, C)`` (one entry per ``indices``).
"""
x = self.embeddings(pixel_values)
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
n_layers = len(self.encoder.layer)
@ -395,166 +197,3 @@ class Dinov2Model(torch.nn.Module):
if i >= max_idx:
break
return [cache[i] for i in resolved]
# ------------------------------------------------------------------
# Depth Anything 3 forward
# ------------------------------------------------------------------
def _prepare_rope_positions(self, B, S, H, W, device):
if self.rope is None:
return None, None
ph, pw = H // self.patch_size, W // self.patch_size
pos = self._position_getter(B * S, ph, pw, device=device)
# Shift so the cls/cam token at position 0 is reserved for "no diff".
pos = pos + 1
cls_pos = torch.zeros(B * S, self.patch_start_idx, 2, device=device, dtype=pos.dtype)
# Per-view local: real grid positions for patches, 0 for cls token.
pos_local = torch.cat([cls_pos, pos], dim=1)
# Global (across views): same grid positions; cls token still at 0,
# but patches share the same positions in every view.
pos_global = torch.cat([cls_pos, torch.zeros_like(pos) + 1], dim=1)
return pos_local, pos_global
def _inject_camera_token(self, x: torch.Tensor, B: int, S: int,
cam_token: "torch.Tensor | None") -> torch.Tensor:
# x: (B, S, N, C). Replace token at index 0 with the camera token.
if cam_token is not None:
inj = cam_token
else:
ct = comfy.model_management.cast_to_device(self.embeddings.camera_token, x.device, x.dtype)
ref_token = ct[:, :1].expand(B, -1, -1)
src_token = ct[:, 1:].expand(B, max(S - 1, 0), -1)
inj = torch.cat([ref_token, src_token], dim=1)
x = x.clone()
x[:, :, 0] = inj
return x
def get_intermediate_layers_da3(self, pixel_values, out_layers, cam_token=None,
ref_view_strategy="saddle_balanced",
export_feat_layers=None):
"""Multi-view multi-layer feature extraction used by Depth Anything 3.
Adds RoPE positions, alternating local/global attention across views,
camera-token injection, reference-view selection/reordering,
``cat_token`` output and optional auxiliary feature exports on top of
the vanilla DINOv2 path. For the single-view MoGe / CLIP-vision use
case, see :meth:`get_intermediate_layers`.
Args:
pixel_values: ``(B, S, 3, H, W)`` views or ``(B, 3, H, W)``.
out_layers: indices into ``self.encoder.layer``.
cam_token: optional ``(B, S, dim)`` camera token to inject at
``alt_start``. If ``None`` and the model has its own
``camera_token`` parameter, that is used.
ref_view_strategy: when ``S >= 3`` and ``cam_token is None``,
pick a reference view via this strategy and move it to
position 0 right before the first alt-attention block.
The original view order is restored on the way out.
export_feat_layers: optional iterable of layer indices whose
local attention outputs to also return as auxiliary
features (``(B, S, N_patch, C)`` after final norm). Used
by the multi-view path to expose intermediate features
to the nested-architecture wrapper.
Returns:
``(layer_outputs, aux_outputs)`` where ``layer_outputs`` is a
list of ``(patch_tokens, cls_or_cam_token)`` tuples (one per
``out_layers`` entry) and ``aux_outputs`` is a list of
``(B, S, N_patch, C)`` features for ``export_feat_layers``
(empty list when not requested).
"""
if pixel_values.ndim == 4:
pixel_values = pixel_values.unsqueeze(1)
assert pixel_values.ndim == 5 and pixel_values.shape[2] == 3, \
f"expected (B,3,H,W) or (B,S,3,H,W); got {tuple(pixel_values.shape)}"
B, S, _, H, W = pixel_values.shape
# Patch + cls + (interpolated) pos embed for each view.
x = pixel_values.reshape(B * S, 3, H, W)
x = self.embeddings(x) # (B*S, 1+N, C)
x = x.reshape(B, S, x.shape[-2], x.shape[-1]) # (B, S, 1+N, C)
pos_local, pos_global = self._prepare_rope_positions(B, S, H, W, x.device)
# ``optimized_attention`` is only used by blocks without QK-norm/RoPE
# (vanilla DINOv2 path); enabling-aware blocks fall through to SDPA.
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
out_set = set(out_layers)
export_set = set(export_feat_layers) if export_feat_layers else set()
outputs: list[torch.Tensor] = []
aux_outputs: list[torch.Tensor] = []
local_x = x
b_idx = None
for i, blk in enumerate(self.encoder.layer):
apply_rope = self.rope is not None and i >= self.rope_start
block_rope = self.rope if apply_rope else None
l_pos = pos_local if apply_rope else None
g_pos = pos_global if apply_rope else None
# Reference-view selection threshold: matches the upstream constant
# ``THRESH_FOR_REF_SELECTION = 3``. Skipped when a user-supplied
# cam_token is provided (camera info already pins the geometry).
if (self.alt_start != -1 and i == self.alt_start - 1
and S >= THRESH_FOR_REF_SELECTION and cam_token is None):
b_idx = select_reference_view(x, strategy=ref_view_strategy)
x = reorder_by_reference(x, b_idx)
local_x = reorder_by_reference(local_x, b_idx)
if self.alt_start != -1 and i == self.alt_start:
x = self._inject_camera_token(x, B, S, cam_token)
if self.alt_start != -1 and i >= self.alt_start and (i % 2 == 1):
# Global attention across views: flatten S into the seq dim.
t = x.reshape(B, S * x.shape[-2], x.shape[-1])
p = g_pos.reshape(B, S * g_pos.shape[-2], g_pos.shape[-1]) if g_pos is not None else None
t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope)
x = t.reshape(B, S, x.shape[-2], x.shape[-1])
else:
# Per-view local attention.
t = x.reshape(B * S, x.shape[-2], x.shape[-1])
p = l_pos.reshape(B * S, l_pos.shape[-2], l_pos.shape[-1]) if l_pos is not None else None
t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope)
x = t.reshape(B, S, x.shape[-2], x.shape[-1])
local_x = x
if i in out_set:
if self.cat_token:
out_x = torch.cat([local_x, x], dim=-1)
else:
out_x = x
# Restore original view order on the way out so heads see views
# in the user's expected order.
if b_idx is not None and self.alt_start != -1:
out_x = restore_original_order(out_x, b_idx)
outputs.append(out_x)
if i in export_set:
aux = x
if b_idx is not None and self.alt_start != -1:
aux = restore_original_order(aux, b_idx)
aux_outputs.append(aux)
# Apply final norm. When ``cat_token`` is set, only the right half
# ("global" features) is normalised; the left half is left as-is to
# match the upstream DA3 head signature.
normed: list[torch.Tensor] = []
cls_tokens: list[torch.Tensor] = []
for out_x in outputs:
cls_tokens.append(out_x[:, :, 0])
if out_x.shape[-1] == self.embed_dim:
normed.append(self.layernorm(out_x))
elif out_x.shape[-1] == self.embed_dim * 2:
left = out_x[..., :self.embed_dim]
right = self.layernorm(out_x[..., self.embed_dim:])
normed.append(torch.cat([left, right], dim=-1))
else:
raise ValueError(f"Unexpected token width: {out_x.shape[-1]}")
# Drop cls/cam token from the patch sequence.
normed = [o[..., 1 + self.num_register_tokens:, :] for o in normed]
# Final layernorm + drop cls token from auxiliary features too.
aux_normed = [self.layernorm(o)[..., 1 + self.num_register_tokens:, :]
for o in aux_outputs]
return list(zip(normed, cls_tokens)), aux_normed

View File

@ -1,25 +0,0 @@
"""Colormap utilities for depth and geometry visualisation."""
from __future__ import annotations
import torch
def turbo(x: torch.Tensor) -> torch.Tensor:
"""Anton Mikhailov polynomial approximation of the Turbo colormap.
Args:
x: Float tensor with values in [0, 1].
Returns:
RGB tensor of the same shape as ``x`` with a trailing size-3 dimension.
"""
x = x.clamp(0.0, 1.0)
x2 = x * x
x3 = x2 * x
x4 = x2 * x2
x5 = x4 * x
r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5
g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5
b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5
return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0)

View File

@ -1,7 +0,0 @@
# Depth Anything 3 - native ComfyUI port (Apache-2.0 monocular variants only).
#
# Supported variants:
# DA3-Small, DA3-Base (vits/vitb backbone, DualDPT head)
# DA3Mono-Large, DA3Metric-Large (vitl backbone, DPT head + sky mask)
#
# Original repo: https://github.com/ByteDance-Seed/Depth-Anything-3

View File

@ -1,204 +0,0 @@
"""Camera-token encoder and decoder for Depth Anything 3.
* :class:`CameraEnc` takes per-view extrinsics + intrinsics and produces a
per-view camera token that gets injected at the alt-attention boundary
in the DINOv2 backbone (block ``alt_start``).
* :class:`CameraDec` takes the final-layer camera token output by the
backbone and predicts a 9-D pose encoding (translation, quaternion,
field-of-view).
The module/parameter names match the upstream ``cam_enc.py``/``cam_dec.py``
so HF safetensors load directly with no key remapping (the upstream uses
fused QKV linears, which we replicate here).
"""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from .transform import affine_inverse, extri_intri_to_pose_encoding
# -----------------------------------------------------------------------------
# Building blocks (mirror ``depth_anything_3.model.utils.{attention,block}``)
# -----------------------------------------------------------------------------
class _Mlp(nn.Module):
"""Standard 2-layer MLP with GELU. Matches upstream ``utils.attention.Mlp``."""
def __init__(self, in_features, hidden_features=None, out_features=None,
*, device=None, dtype=None, operations=None):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = operations.Linear(in_features, hidden_features, bias=True,
device=device, dtype=dtype)
self.fc2 = operations.Linear(hidden_features, out_features, bias=True,
device=device, dtype=dtype)
def forward(self, x):
return self.fc2(F.gelu(self.fc1(x)))
class _LayerScale(nn.Module):
"""Per-channel learnable scaling. Matches upstream ``LayerScale``."""
def __init__(self, dim, *, device=None, dtype=None):
super().__init__()
self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
def forward(self, x):
return x * self.gamma.to(dtype=x.dtype, device=x.device)
class _Attention(nn.Module):
"""Self-attention with fused QKV projection.
Mirrors upstream ``utils.attention.Attention``; layout matches the
HF safetensors (``attn.qkv.{weight,bias}`` and ``attn.proj.{weight,bias}``).
"""
def __init__(self, dim, num_heads,
*, device=None, dtype=None, operations=None):
super().__init__()
assert dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = operations.Linear(dim, dim * 3, bias=True,
device=device, dtype=dtype)
self.proj = operations.Linear(dim, dim, bias=True,
device=device, dtype=dtype)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # 3, B, h, N, d
q, k, v = qkv.unbind(0)
out = F.scaled_dot_product_attention(q, k, v)
out = out.transpose(1, 2).reshape(B, N, C)
return self.proj(out)
class _Block(nn.Module):
"""Pre-norm transformer block with LayerScale.
Used by :class:`CameraEnc`. Layout follows upstream ``utils.block.Block``.
"""
def __init__(self, dim, num_heads, mlp_ratio=4, init_values=0.01,
*, device=None, dtype=None, operations=None):
super().__init__()
self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype)
self.attn = _Attention(dim, num_heads,
device=device, dtype=dtype, operations=operations)
self.ls1 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity()
self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype)
self.mlp = _Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
device=device, dtype=dtype, operations=operations)
self.ls2 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity()
def forward(self, x):
x = x + self.ls1(self.attn(self.norm1(x)))
x = x + self.ls2(self.mlp(self.norm2(x)))
return x
class CameraEnc(nn.Module):
"""Encode per-view (extrinsics, intrinsics) into a camera token.
Maps a 9-D pose-encoding vector through a small MLP up to the backbone's
``embed_dim``, then runs ``trunk_depth`` transformer blocks. The output
has shape ``(B, S, embed_dim)`` and is injected at block ``alt_start``
of the DINOv2 backbone in place of the cls token.
Parameters mirror the upstream ``cam_enc.py`` so HF weights load directly.
"""
def __init__(
self,
dim_out: int = 1024,
dim_in: int = 9,
trunk_depth: int = 4,
target_dim: int = 9,
num_heads: int = 16,
mlp_ratio: int = 4,
init_values: float = 0.01,
*,
device=None, dtype=None, operations=None,
**_kwargs,
):
super().__init__()
self.target_dim = target_dim
self.trunk_depth = trunk_depth
self.trunk = nn.Sequential(*[
_Block(dim_out, num_heads=num_heads, mlp_ratio=mlp_ratio,
init_values=init_values,
device=device, dtype=dtype, operations=operations)
for _ in range(trunk_depth)
])
self.token_norm = operations.LayerNorm(dim_out, device=device, dtype=dtype)
self.trunk_norm = operations.LayerNorm(dim_out, device=device, dtype=dtype)
self.pose_branch = _Mlp(
in_features=dim_in,
hidden_features=dim_out // 2,
out_features=dim_out,
device=device, dtype=dtype, operations=operations,
)
def forward(self, extrinsics: torch.Tensor, intrinsics: torch.Tensor,
image_size_hw) -> torch.Tensor:
"""Encode camera parameters into ``(B, S, dim_out)`` tokens."""
c2ws = affine_inverse(extrinsics)
pose_encoding = extri_intri_to_pose_encoding(c2ws, intrinsics, image_size_hw)
tokens = self.pose_branch(pose_encoding.to(self.pose_branch.fc1.weight.dtype))
tokens = self.token_norm(tokens)
tokens = self.trunk(tokens)
tokens = self.trunk_norm(tokens)
return tokens
class CameraDec(nn.Module):
"""Decode the final cam token into a 9-D pose encoding.
Output layout: ``[T(3), quat_xyzw(4), fov_h, fov_w]``. The translation is
always predicted by the network; the quaternion and FoV can either be
predicted or supplied via ``camera_encoding`` (used at training time
when GT cameras are available -- not exercised at inference here).
Parameters mirror the upstream ``cam_dec.py`` so HF weights load directly.
"""
def __init__(self, dim_in: int = 1536,
*, device=None, dtype=None, operations=None, **_kwargs):
super().__init__()
d = dim_in
self.backbone = nn.Sequential(
operations.Linear(d, d, device=device, dtype=dtype),
nn.ReLU(),
operations.Linear(d, d, device=device, dtype=dtype),
nn.ReLU(),
)
self.fc_t = operations.Linear(d, 3, device=device, dtype=dtype)
self.fc_qvec = operations.Linear(d, 4, device=device, dtype=dtype)
self.fc_fov = nn.Sequential(
operations.Linear(d, 2, device=device, dtype=dtype),
nn.ReLU(),
)
def forward(self, feat: torch.Tensor,
camera_encoding: "torch.Tensor | None" = None) -> torch.Tensor:
"""Decode ``(B, N, dim_in)`` cam tokens into ``(B, N, 9)`` pose enc."""
B, N = feat.shape[:2]
feat = feat.reshape(B * N, -1)
feat = self.backbone(feat)
out_t = self.fc_t(feat.float()).reshape(B, N, 3)
if camera_encoding is None:
out_qvec = self.fc_qvec(feat.float()).reshape(B, N, 4)
out_fov = self.fc_fov(feat.float()).reshape(B, N, 2)
else:
out_qvec = camera_encoding[..., 3:7]
out_fov = camera_encoding[..., -2:]
return torch.cat([out_t, out_qvec, out_fov], dim=-1)

View File

@ -1,549 +0,0 @@
# DPT / DualDPT heads for Depth Anything 3.
#
# Ported from:
# src/depth_anything_3/model/dpt.py (DPT - single main head + sky head)
# src/depth_anything_3/model/dualdpt.py (DualDPT - depth + auxiliary "ray" head)
#
# In the monocular path we always discard the auxiliary "ray" output of
# DualDPT. The auxiliary branch is still constructed so that DA3 HF weights
# load cleanly without missing-key warnings.
from __future__ import annotations
from typing import List, Optional, Sequence, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# -----------------------------------------------------------------------------
# Helpers (matching upstream head_utils.py)
# -----------------------------------------------------------------------------
class Permute(nn.Module):
def __init__(self, dims: Tuple[int, ...]):
super().__init__()
self.dims = dims
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.permute(*self.dims)
def _custom_interpolate(
x: torch.Tensor,
size: Optional[Tuple[int, int]] = None,
scale_factor: Optional[float] = None,
mode: str = "bilinear",
align_corners: bool = True,
) -> torch.Tensor:
if size is None:
assert scale_factor is not None
size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
INT_MAX = 1610612736
total = size[0] * size[1] * x.shape[0] * x.shape[1]
if total > INT_MAX:
chunks = torch.chunk(x, chunks=(total // INT_MAX) + 1, dim=0)
outs = [F.interpolate(c, size=size, mode=mode, align_corners=align_corners) for c in chunks]
return torch.cat(outs, dim=0).contiguous()
return F.interpolate(x, size=size, mode=mode, align_corners=align_corners)
def _create_uv_grid(width: int, height: int, aspect_ratio: float,
dtype, device) -> torch.Tensor:
"""Normalised UV grid spanning (-x_span, -y_span)..(x_span, y_span)."""
diag_factor = (aspect_ratio ** 2 + 1.0) ** 0.5
span_x = aspect_ratio / diag_factor
span_y = 1.0 / diag_factor
left_x = -span_x * (width - 1) / width
right_x = span_x * (width - 1) / width
top_y = -span_y * (height - 1) / height
bottom_y = span_y * (height - 1) / height
x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
return torch.stack((uu, vv), dim=-1) # (H, W, 2)
def _make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100.0) -> torch.Tensor:
omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
omega = 1.0 / omega_0 ** (omega / (embed_dim / 2.0))
pos = pos.reshape(-1)
out = torch.einsum("m,d->md", pos, omega)
return torch.cat([out.sin(), out.cos()], dim=1).float()
def _position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int,
omega_0: float = 100.0) -> torch.Tensor:
H, W, _ = pos_grid.shape
pos_flat = pos_grid.reshape(-1, 2)
emb_x = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0)
emb_y = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0)
emb = torch.cat([emb_x, emb_y], dim=-1)
return emb.view(H, W, embed_dim)
def _add_pos_embed(x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
"""Stateless UV positional embedding added to a feature map (B, C, h, w)."""
pw, ph = x.shape[-1], x.shape[-2]
pe = _create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
pe = _position_grid_to_embed(pe, x.shape[1]) * ratio
pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1).to(dtype=x.dtype)
return x + pe
def _apply_activation(x: torch.Tensor, activation: str) -> torch.Tensor:
act = (activation or "linear").lower()
if act == "exp":
return torch.exp(x)
if act == "expp1":
return torch.exp(x) + 1
if act == "expm1":
return torch.expm1(x)
if act == "relu":
return torch.relu(x)
if act == "sigmoid":
return torch.sigmoid(x)
if act == "softplus":
return F.softplus(x)
if act == "tanh":
return torch.tanh(x)
return x
# -----------------------------------------------------------------------------
# Fusion building blocks
# -----------------------------------------------------------------------------
class ResidualConvUnit(nn.Module):
def __init__(self, features: int,
device=None, dtype=None, operations=None):
super().__init__()
self.conv1 = operations.Conv2d(features, features, 3, 1, 1, bias=True,
device=device, dtype=dtype)
self.conv2 = operations.Conv2d(features, features, 3, 1, 1, bias=True,
device=device, dtype=dtype)
self.activation = nn.ReLU(inplace=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self.activation(x)
out = self.conv1(out)
out = self.activation(out)
out = self.conv2(out)
return out + x
class FeatureFusionBlock(nn.Module):
def __init__(self, features: int, has_residual: bool = True,
align_corners: bool = True,
device=None, dtype=None, operations=None):
super().__init__()
self.align_corners = align_corners
self.has_residual = has_residual
if has_residual:
self.resConfUnit1 = ResidualConvUnit(features, device=device, dtype=dtype, operations=operations)
else:
self.resConfUnit1 = None
self.resConfUnit2 = ResidualConvUnit(features, device=device, dtype=dtype, operations=operations)
self.out_conv = operations.Conv2d(features, features, 1, 1, 0, bias=True,
device=device, dtype=dtype)
def forward(self, *xs: torch.Tensor, size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
y = xs[0]
if self.has_residual and len(xs) > 1 and self.resConfUnit1 is not None:
y = y + self.resConfUnit1(xs[1])
y = self.resConfUnit2(y)
if size is None:
up_kwargs = {"scale_factor": 2.0}
else:
up_kwargs = {"size": size}
y = _custom_interpolate(y, **up_kwargs, mode="bilinear",
align_corners=self.align_corners)
y = self.out_conv(y)
return y
class _Scratch(nn.Module):
"""Container that mirrors upstream ``scratch`` attribute layout."""
def _make_scratch(in_shape: List[int], out_shape: int,
device=None, dtype=None, operations=None) -> _Scratch:
scratch = _Scratch()
scratch.layer1_rn = operations.Conv2d(in_shape[0], out_shape, 3, 1, 1, bias=False,
device=device, dtype=dtype)
scratch.layer2_rn = operations.Conv2d(in_shape[1], out_shape, 3, 1, 1, bias=False,
device=device, dtype=dtype)
scratch.layer3_rn = operations.Conv2d(in_shape[2], out_shape, 3, 1, 1, bias=False,
device=device, dtype=dtype)
scratch.layer4_rn = operations.Conv2d(in_shape[3], out_shape, 3, 1, 1, bias=False,
device=device, dtype=dtype)
return scratch
def _make_fusion_block(features: int, has_residual: bool = True,
device=None, dtype=None, operations=None) -> FeatureFusionBlock:
return FeatureFusionBlock(features, has_residual=has_residual,
align_corners=True,
device=device, dtype=dtype, operations=operations)
# -----------------------------------------------------------------------------
# DPT (single head + optional sky head) -- used by DA3Mono/Metric
# -----------------------------------------------------------------------------
class DPT(nn.Module):
"""Single-head DPT used by DA3Mono-Large and DA3Metric-Large."""
def __init__(
self,
dim_in: int,
patch_size: int = 14,
output_dim: int = 1,
activation: str = "exp",
conf_activation: str = "expp1",
features: int = 256,
out_channels: Sequence[int] = (256, 512, 1024, 1024),
pos_embed: bool = False,
down_ratio: int = 1,
head_name: str = "depth",
use_sky_head: bool = True,
sky_name: str = "sky",
sky_activation: str = "relu",
norm_type: str = "idt",
device=None, dtype=None, operations=None,
):
super().__init__()
self.patch_size = patch_size
self.activation = activation
self.conf_activation = conf_activation
self.pos_embed = pos_embed
self.down_ratio = down_ratio
self.head_main = head_name
self.sky_name = sky_name
self.out_dim = output_dim
self.has_conf = output_dim > 1
self.use_sky_head = use_sky_head
self.sky_activation = sky_activation
self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
if norm_type == "layer":
self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype)
else:
self.norm = nn.Identity()
out_channels = list(out_channels)
self.projects = nn.ModuleList([
operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0,
device=device, dtype=dtype)
for oc in out_channels
])
self.resize_layers = nn.ModuleList([
operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0,
device=device, dtype=dtype),
operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0,
device=device, dtype=dtype),
nn.Identity(),
operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1,
device=device, dtype=dtype),
])
self.scratch = _make_scratch(out_channels, features,
device=device, dtype=dtype, operations=operations)
self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False,
device=device, dtype=dtype, operations=operations)
head_features_1 = features
head_features_2 = 32
self.scratch.output_conv1 = operations.Conv2d(
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1,
device=device, dtype=dtype,
)
self.scratch.output_conv2 = nn.Sequential(
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1,
device=device, dtype=dtype),
nn.ReLU(inplace=False),
operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0,
device=device, dtype=dtype),
)
if self.use_sky_head:
self.scratch.sky_output_conv2 = nn.Sequential(
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1,
device=device, dtype=dtype),
nn.ReLU(inplace=False),
operations.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0,
device=device, dtype=dtype),
)
def forward(self, feats: List[torch.Tensor], H: int, W: int,
patch_start_idx: int = 0, **_kwargs) -> dict:
# feats[i][0] is the patch-token tensor with shape (B, S, N_patch, C)
B, S, N, C = feats[0][0].shape
feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats]
ph, pw = H // self.patch_size, W // self.patch_size
resized = []
for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
x = feats_flat[take_idx][:, patch_start_idx:]
x = self.norm(x)
x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw)
x = self.projects[stage_idx](x)
if self.pos_embed:
x = _add_pos_embed(x, W, H)
x = self.resize_layers[stage_idx](x)
resized.append(x)
l1_rn = self.scratch.layer1_rn(resized[0])
l2_rn = self.scratch.layer2_rn(resized[1])
l3_rn = self.scratch.layer3_rn(resized[2])
l4_rn = self.scratch.layer4_rn(resized[3])
out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:])
out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:])
out = self.scratch.refinenet1(out, l1_rn)
h_out = int(ph * self.patch_size / self.down_ratio)
w_out = int(pw * self.patch_size / self.down_ratio)
fused = self.scratch.output_conv1(out)
fused = _custom_interpolate(fused, (h_out, w_out), mode="bilinear", align_corners=True)
if self.pos_embed:
fused = _add_pos_embed(fused, W, H)
feat = fused
main_logits = self.scratch.output_conv2(feat)
outs = {}
if self.has_conf:
fmap = main_logits.permute(0, 2, 3, 1)
pred = _apply_activation(fmap[..., :-1], self.activation)
conf = _apply_activation(fmap[..., -1], self.conf_activation)
outs[self.head_main] = pred.squeeze(-1).view(B, S, *pred.shape[1:-1])
outs[f"{self.head_main}_conf"] = conf.view(B, S, *conf.shape[1:])
else:
pred = _apply_activation(main_logits, self.activation)
outs[self.head_main] = pred.squeeze(1).view(B, S, *pred.shape[2:])
if self.use_sky_head:
sky_logits = self.scratch.sky_output_conv2(feat)
if self.sky_activation.lower() == "sigmoid":
sky = torch.sigmoid(sky_logits)
elif self.sky_activation.lower() == "relu":
sky = F.relu(sky_logits)
else:
sky = sky_logits
outs[self.sky_name] = sky.squeeze(1).view(B, S, *sky.shape[2:])
return outs
# -----------------------------------------------------------------------------
# DualDPT (depth + auxiliary "ray" head) -- used by DA3-Small / DA3-Base
# -----------------------------------------------------------------------------
class DualDPT(nn.Module):
"""Two-head DPT used by DA3-Small / DA3-Base.
The auxiliary "ray" head is constructed so that HF state-dict keys load
cleanly. It is only executed when :attr:`enable_aux` is set on the
instance (typically by ``DepthAnything3Net`` when running multi-view
with ``use_ray_pose=True``); otherwise the monocular path skips it for
speed and the auxiliary submodules sit idle.
"""
def __init__(
self,
dim_in: int,
patch_size: int = 14,
output_dim: int = 2,
activation: str = "exp",
conf_activation: str = "expp1",
features: int = 256,
out_channels: Sequence[int] = (256, 512, 1024, 1024),
pos_embed: bool = True,
down_ratio: int = 1,
aux_pyramid_levels: int = 4,
aux_out1_conv_num: int = 5,
head_names: Tuple[str, str] = ("depth", "ray"),
device=None, dtype=None, operations=None,
):
super().__init__()
self.patch_size = patch_size
self.activation = activation
self.conf_activation = conf_activation
self.pos_embed = pos_embed
self.down_ratio = down_ratio
self.aux_levels = aux_pyramid_levels
self.aux_out1_conv_num = aux_out1_conv_num
self.head_main, self.head_aux = head_names
self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
# Toggle the auxiliary ray branch at runtime. Default off (mono path).
# ``DepthAnything3Net`` flips this on when running multi-view + ray-pose.
self.enable_aux: bool = False
self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype)
out_channels = list(out_channels)
self.projects = nn.ModuleList([
operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0,
device=device, dtype=dtype)
for oc in out_channels
])
self.resize_layers = nn.ModuleList([
operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0,
device=device, dtype=dtype),
operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0,
device=device, dtype=dtype),
nn.Identity(),
operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1,
device=device, dtype=dtype),
])
self.scratch = _make_scratch(out_channels, features,
device=device, dtype=dtype, operations=operations)
# Main fusion chain
self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False,
device=device, dtype=dtype, operations=operations)
# Auxiliary fusion chain (separate copies)
self.scratch.refinenet1_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
self.scratch.refinenet2_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
self.scratch.refinenet3_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
self.scratch.refinenet4_aux = _make_fusion_block(features, has_residual=False,
device=device, dtype=dtype, operations=operations)
head_features_1 = features
head_features_2 = 32
# Main head neck + final projection
self.scratch.output_conv1 = operations.Conv2d(
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1,
device=device, dtype=dtype,
)
self.scratch.output_conv2 = nn.Sequential(
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1,
device=device, dtype=dtype),
nn.ReLU(inplace=False),
operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0,
device=device, dtype=dtype),
)
# Aux pre-head per level (multi-level pyramid)
self.scratch.output_conv1_aux = nn.ModuleList([
self._make_aux_out1_block(head_features_1, device=device, dtype=dtype, operations=operations)
for _ in range(self.aux_levels)
])
# Aux final projection per level (includes LayerNorm permute path).
ln_seq = [Permute((0, 2, 3, 1)),
operations.LayerNorm(head_features_2, device=device, dtype=dtype),
Permute((0, 3, 1, 2))]
self.scratch.output_conv2_aux = nn.ModuleList([
nn.Sequential(
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1,
device=device, dtype=dtype),
*ln_seq,
nn.ReLU(inplace=False),
operations.Conv2d(head_features_2, 7, kernel_size=1, stride=1, padding=0,
device=device, dtype=dtype),
)
for _ in range(self.aux_levels)
])
@staticmethod
def _make_aux_out1_block(in_ch: int, *, device=None, dtype=None, operations=None) -> nn.Sequential:
# aux_out1_conv_num=5 in all Apache-2.0 variants.
return nn.Sequential(
operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype),
operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype),
operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
)
def forward(self, feats: List[torch.Tensor], H: int, W: int,
patch_start_idx: int = 0, **_kwargs) -> dict:
B, S, N, C = feats[0][0].shape
feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats]
ph, pw = H // self.patch_size, W // self.patch_size
resized = []
for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
x = feats_flat[take_idx][:, patch_start_idx:]
x = self.norm(x)
x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw)
x = self.projects[stage_idx](x)
if self.pos_embed:
x = _add_pos_embed(x, W, H)
x = self.resize_layers[stage_idx](x)
resized.append(x)
l1_rn = self.scratch.layer1_rn(resized[0])
l2_rn = self.scratch.layer2_rn(resized[1])
l3_rn = self.scratch.layer3_rn(resized[2])
l4_rn = self.scratch.layer4_rn(resized[3])
# Main pyramid (output_conv1 is applied inside the upstream `_fuse`,
# before interpolation -- replicate that order here).
m = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
if self.enable_aux:
a4 = self.scratch.refinenet4_aux(l4_rn, size=l3_rn.shape[2:])
aux_pyr = [a4]
m = self.scratch.refinenet3(m, l3_rn, size=l2_rn.shape[2:])
if self.enable_aux:
aux_pyr.append(self.scratch.refinenet3_aux(aux_pyr[-1], l3_rn, size=l2_rn.shape[2:]))
m = self.scratch.refinenet2(m, l2_rn, size=l1_rn.shape[2:])
if self.enable_aux:
aux_pyr.append(self.scratch.refinenet2_aux(aux_pyr[-1], l2_rn, size=l1_rn.shape[2:]))
m = self.scratch.refinenet1(m, l1_rn)
if self.enable_aux:
aux_pyr.append(self.scratch.refinenet1_aux(aux_pyr[-1], l1_rn))
m = self.scratch.output_conv1(m)
h_out = int(ph * self.patch_size / self.down_ratio)
w_out = int(pw * self.patch_size / self.down_ratio)
m = _custom_interpolate(m, (h_out, w_out), mode="bilinear", align_corners=True)
if self.pos_embed:
m = _add_pos_embed(m, W, H)
main_logits = self.scratch.output_conv2(m)
fmap = main_logits.permute(0, 2, 3, 1)
depth_pred = _apply_activation(fmap[..., :-1], self.activation)
depth_conf = _apply_activation(fmap[..., -1], self.conf_activation)
outs = {
self.head_main: depth_pred.squeeze(-1).view(B, S, *depth_pred.shape[1:-1]),
f"{self.head_main}_conf": depth_conf.view(B, S, *depth_conf.shape[1:]),
}
if self.enable_aux:
# Auxiliary "ray" head (multi-level inside) -- only the last level
# is returned. Mirrors upstream ``DualDPT._fuse`` + ``_forward_impl``:
# each aux pyramid level goes through ``output_conv1_aux[i]``
# (5-layer conv stack that ends at ``features // 2`` channels),
# then the last level optionally gets a pos-embed and finally
# ``output_conv2_aux[-1]``.
aux_processed = [
self.scratch.output_conv1_aux[i](a) for i, a in enumerate(aux_pyr)
]
last_aux = aux_processed[-1]
if self.pos_embed:
last_aux = _add_pos_embed(last_aux, W, H)
last_aux_logits = self.scratch.output_conv2_aux[-1](last_aux)
fmap_last = last_aux_logits.permute(0, 2, 3, 1)
# Channels: [ray(6), ray_conf(1)]; ray uses 'linear' activation.
aux_pred = fmap_last[..., :-1]
aux_conf = _apply_activation(fmap_last[..., -1], self.conf_activation)
outs[self.head_aux] = aux_pred.view(B, S, *aux_pred.shape[1:])
outs[f"{self.head_aux}_conf"] = aux_conf.view(B, S, *aux_conf.shape[1:])
return outs

View File

@ -1,300 +0,0 @@
# DepthAnything3Net: top-level wrapper that combines backbone + head.
#
# Supports both the monocular and the multi-view + camera path:
#
# * Monocular: ``S = 1``, no camera encoder/decoder. Mirrors the original
# port that only handled ``DA3-MONO/METRIC-LARGE`` and the auxiliary-disabled
# ``DA3-SMALL/BASE`` configs.
# * Multi-view + camera: ``S > 1``. ``cam_enc`` (optional) maps user-supplied
# extrinsics + intrinsics into a per-view camera token; ``cam_dec`` decodes
# the final layer's camera token into a 9-D pose encoding. When the
# auxiliary "ray" head of ``DualDPT`` is enabled the predicted ray map can
# alternatively be used to estimate pose via RANSAC (``use_ray_pose=True``).
# The 3D-Gaussian head and the nested-architecture wrapper are intentionally
# left out of scope here; their state-dict keys are filtered in
# ``comfy.supported_models.DepthAnything3.process_unet_state_dict``.
#
# The backbone is shared with the CLIP-vision DINOv2 path
# (``comfy.image_encoders.dino2.Dinov2Model``); the DA3-specific extensions
# (RoPE, QK-norm, alternating local/global attention, camera token, multi-
# layer feature extraction, reference-view reordering) are opt-in via the
# config dict and are all disabled for the Mono/Metric variants.
from __future__ import annotations
from typing import Dict, Optional, Sequence
import torch
import torch.nn as nn
from comfy.image_encoders.dino2 import Dinov2Model
from .camera import CameraDec, CameraEnc
from .dpt import DPT, DualDPT
from .ray_pose import get_extrinsic_from_camray
from .transform import affine_inverse, pose_encoding_to_extri_intri
_HEAD_REGISTRY = {
"dpt": DPT,
"dualdpt": DualDPT,
}
# Backbone presets (mirror the upstream DINOv2 ViT variants).
_BACKBONE_PRESETS = {
"vits": dict(hidden_size=384, num_hidden_layers=12, num_attention_heads=6, use_swiglu_ffn=False),
"vitb": dict(hidden_size=768, num_hidden_layers=12, num_attention_heads=12, use_swiglu_ffn=False),
"vitl": dict(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, use_swiglu_ffn=False),
"vitg": dict(hidden_size=1536, num_hidden_layers=40, num_attention_heads=24, use_swiglu_ffn=True),
}
def _build_backbone_config(
backbone_name: str,
*,
alt_start: int,
qknorm_start: int,
rope_start: int,
cat_token: bool,
) -> dict:
if backbone_name not in _BACKBONE_PRESETS:
raise ValueError(f"Unknown DINOv2 backbone variant: {backbone_name!r}")
cfg = dict(_BACKBONE_PRESETS[backbone_name])
cfg.update(dict(
layer_norm_eps=1e-6,
patch_size=14,
image_size=518,
# No mask_token in DA3 weights; omit param to avoid load warnings.
use_mask_token=False,
alt_start=alt_start,
qknorm_start=qknorm_start,
rope_start=rope_start,
cat_token=cat_token,
rope_freq=100.0,
))
return cfg
class DepthAnything3Net(nn.Module):
"""ComfyUI-side DepthAnything3 network.
Parameters mirror the variant YAML configs from the upstream repo and
are auto-detected from the state dict by ``comfy/model_detection.py``.
The kwargs ``device``, ``dtype`` and ``operations`` are injected by
``BaseModel``.
"""
PATCH_SIZE = 14
def __init__(
self,
# --- Backbone ---
backbone_name: str = "vitl",
out_layers: Sequence[int] = (4, 11, 17, 23),
alt_start: int = -1,
qknorm_start: int = -1,
rope_start: int = -1,
cat_token: bool = False,
# --- Head ---
head_type: str = "dpt", # "dpt" or "dualdpt"
head_dim_in: int = 1024,
head_output_dim: int = 1, # 1 = depth only, 2 = depth+conf
head_features: int = 256,
head_out_channels: Sequence[int] = (256, 512, 1024, 1024),
head_use_sky_head: bool = True, # ignored by DualDPT
head_pos_embed: Optional[bool] = None, # default: True for DualDPT, False for DPT
# --- Camera (multi-view) ---
has_cam_enc: bool = False,
has_cam_dec: bool = False,
cam_dim_out: Optional[int] = None, # CameraEnc dim_out (defaults to embed_dim)
cam_dec_dim_in: Optional[int] = None, # CameraDec dim_in (defaults to 2*embed_dim with cat_token)
# ComfyUI plumbing
device=None, dtype=None, operations=None,
**_ignored,
):
super().__init__()
head_cls = _HEAD_REGISTRY[head_type.lower()]
self.head_type = head_type.lower()
self.has_sky = (self.head_type == "dpt") and head_use_sky_head
self.has_conf = head_output_dim > 1
self.out_layers = list(out_layers)
backbone_cfg = _build_backbone_config(
backbone_name,
alt_start=alt_start,
qknorm_start=qknorm_start,
rope_start=rope_start,
cat_token=cat_token,
)
self.backbone = Dinov2Model(backbone_cfg, dtype, device, operations)
head_kwargs = dict(
dim_in=head_dim_in,
patch_size=self.PATCH_SIZE,
output_dim=head_output_dim,
features=head_features,
out_channels=tuple(head_out_channels),
device=device, dtype=dtype, operations=operations,
)
if self.head_type == "dpt":
head_kwargs.update(
use_sky_head=head_use_sky_head,
pos_embed=(False if head_pos_embed is None else head_pos_embed),
)
else: # dualdpt
head_kwargs.update(
pos_embed=(True if head_pos_embed is None else head_pos_embed),
)
self.head = head_cls(**head_kwargs)
# Built only if checkpoint has weights; cam_enc output dim == embed_dim.
embed_dim = backbone_cfg["hidden_size"]
if has_cam_enc:
self.cam_enc = CameraEnc(
dim_out=cam_dim_out if cam_dim_out is not None else embed_dim,
num_heads=max(1, embed_dim // 64),
device=device, dtype=dtype, operations=operations,
)
else:
self.cam_enc = None
if has_cam_dec:
default_dim = embed_dim * (2 if cat_token else 1)
self.cam_dec = CameraDec(
dim_in=cam_dec_dim_in if cam_dec_dim_in is not None else default_dim,
device=device, dtype=dtype, operations=operations,
)
else:
self.cam_dec = None
self.dtype = dtype
def forward(
self,
image: torch.Tensor,
extrinsics: Optional[torch.Tensor] = None,
intrinsics: Optional[torch.Tensor] = None,
*,
use_ray_pose: bool = False,
ref_view_strategy: str = "saddle_balanced",
export_feat_layers: Optional[Sequence[int]] = None,
**_unused,
) -> Dict[str, torch.Tensor]:
"""Run depth (and optionally pose) prediction.
Args:
image: ``(B, 3, H, W)`` ImageNet-normalised image tensor, or
``(B, S, 3, H, W)`` for multi-view inputs. ``H`` and ``W``
must be multiples of 14.
extrinsics: optional ``(B, S, 4, 4)`` world-to-camera extrinsics.
When provided together with ``intrinsics``, ``CameraEnc``
converts them into per-view camera tokens that the backbone
injects at block ``alt_start``.
intrinsics: optional ``(B, S, 3, 3)`` pixel-space intrinsics.
use_ray_pose: if True, predict pose from the auxiliary "ray" head
(RANSAC over per-pixel rays). Only available on DualDPT
variants. If False (default) and ``cam_dec`` is present,
the final-layer cam token is decoded into pose instead.
ref_view_strategy: reference-view selection strategy used when
``S >= 3`` and no extrinsics are supplied. See
:mod:`comfy.ldm.depth_anything_3.reference_view_selector`.
export_feat_layers: optional list of backbone layer indices whose
local features to also return as auxiliary outputs (used by
downstream nested-architecture wrappers; empty by default).
Returns:
Dict with a subset of:
- ``depth`` ``(B*S, H, W)`` raw depth values.
- ``depth_conf`` ``(B*S, H, W)`` confidence (DualDPT only).
- ``sky`` ``(B*S, H, W)`` sky probability (DPT + sky head).
- ``ray`` ``(B, S, h, w, 6)`` per-pixel cam ray (DualDPT,
multi-view, ``use_ray_pose=True`` only).
- ``ray_conf`` ``(B, S, h, w)`` ray confidence.
- ``extrinsics`` ``(B, S, 4, 4)`` world-to-cam, when pose
prediction is active.
- ``intrinsics`` ``(B, S, 3, 3)`` pixel-space intrinsics.
- ``aux_features`` list of ``(B, S, h_p, w_p, C)`` features
when ``export_feat_layers`` is non-empty.
"""
if image.ndim == 4:
image = image.unsqueeze(1) # (B, 1, 3, H, W)
assert image.ndim == 5 and image.shape[2] == 3, \
f"image must be (B,3,H,W) or (B,S,3,H,W); got {tuple(image.shape)}"
B, S, _, H, W = image.shape
assert H % self.PATCH_SIZE == 0 and W % self.PATCH_SIZE == 0, \
f"image H,W must be multiples of {self.PATCH_SIZE}; got {(H, W)}"
# Camera-token preparation (multi-view path).
cam_token = None
if extrinsics is not None and intrinsics is not None and self.cam_enc is not None:
cam_token = self.cam_enc(extrinsics, intrinsics, (H, W))
# Toggle aux ray output on/off depending on what the caller asked for.
if isinstance(self.head, DualDPT):
self.head.enable_aux = bool(use_ray_pose)
feats, aux_feats = self.backbone.get_intermediate_layers_da3(
image, self.out_layers, cam_token=cam_token,
ref_view_strategy=ref_view_strategy,
export_feat_layers=export_feat_layers,
)
head_out = self.head(feats, H=H, W=W, patch_start_idx=0)
# Pose prediction.
out: Dict[str, torch.Tensor] = {}
if use_ray_pose and "ray" in head_out and "ray_conf" in head_out:
ray = head_out["ray"]
ray_conf = head_out["ray_conf"]
extr_c2w, focal, pp = get_extrinsic_from_camray(
ray, ray_conf, ray.shape[-3], ray.shape[-2],
)
# Match the upstream output: w2c, drop the homogeneous row.
extr_w2c = affine_inverse(extr_c2w)[:, :, :3, :]
# Build pixel-space intrinsics from the normalised focal/pp output.
intr = torch.eye(3, device=ray.device, dtype=ray.dtype)
intr = intr[None, None].expand(extr_c2w.shape[0], extr_c2w.shape[1], 3, 3).clone()
intr[:, :, 0, 0] = focal[:, :, 0] / 2 * W
intr[:, :, 1, 1] = focal[:, :, 1] / 2 * H
intr[:, :, 0, 2] = pp[:, :, 0] * W * 0.5
intr[:, :, 1, 2] = pp[:, :, 1] * H * 0.5
out["extrinsics"] = extr_w2c
out["intrinsics"] = intr
elif self.cam_dec is not None and S > 1:
# Decode the cam-token of the final out_layer into a pose encoding.
cam_feat = feats[-1][1] # (B, S, dim_in_to_cam_dec)
pose_enc = self.cam_dec(cam_feat)
c2w_3x4, intr = pose_encoding_to_extri_intri(pose_enc, (H, W))
# Match the upstream output convention: w2c (world->camera), 3x4.
c2w_4x4 = torch.cat([
c2w_3x4,
torch.tensor([0, 0, 0, 1], device=c2w_3x4.device, dtype=c2w_3x4.dtype)
.view(1, 1, 1, 4).expand(B, S, 1, 4),
], dim=-2)
out["extrinsics"] = affine_inverse(c2w_4x4)[:, :, :3, :]
out["intrinsics"] = intr
# Flatten the views axis for per-pixel outputs (depth/conf/sky) so the
# per-image consumer keeps its (B*S, H, W) interface.
for k, v in head_out.items():
if k in ("ray", "ray_conf"):
# Keep multi-view shape for downstream pose work.
out[k] = v
elif v.ndim >= 3 and v.shape[0] == B and v.shape[1] == S:
out[k] = v.reshape(B * S, *v.shape[2:])
else:
out[k] = v
if export_feat_layers:
out["aux_features"] = self._reshape_aux_features(aux_feats, H, W)
return out
def _reshape_aux_features(self, aux_feats, H: int, W: int):
"""Reshape ``(B, S, N, C)`` aux features into ``(B, S, h_p, w_p, C)``."""
ph, pw = H // self.PATCH_SIZE, W // self.PATCH_SIZE
out = []
for f in aux_feats:
B, S, N, C = f.shape
assert N == ph * pw, f"aux feature seq mismatch: {N} != {ph}*{pw}"
out.append(f.reshape(B, S, ph, pw, C))
return out

View File

@ -1,184 +0,0 @@
# Input/output preprocessing helpers for Depth Anything 3.
#
# Ported from:
# src/depth_anything_3/utils/io/input_processor.py (image normalisation)
# src/depth_anything_3/utils/alignment.py (sky-aware depth clip)
# src/depth_anything_3/model/da3.py::_process_mono_sky_estimation
#
# Resize: ``comfy.utils.common_upscale`` with ``upscale_method="lanczos"``.
# Upstream uses cv2 INTER_CUBIC (upscale) / INTER_AREA (downscale); a sweep
# across {bilinear, bicubic, area, lanczos, bislerp} on a 768->504 test image
# showed lanczos has the lowest max-abs-diff vs the upstream cv2 output
# (~0.13 vs 0.21-0.71 for the others), so we use it in both directions for
# simplicity. This keeps the path stateless, on-device, and free of any
# OpenCV dependency.
from __future__ import annotations
from typing import Tuple
import torch
import comfy.utils
PATCH_SIZE = 14
# ImageNet normalization constants used during DA3 training.
_IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
_IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])
def _round_to_patch(x: int, patch: int = PATCH_SIZE) -> int:
down = (x // patch) * patch
up = down + patch
return up if abs(up - x) <= abs(x - down) else down
def compute_target_size(orig_h: int, orig_w: int, process_res: int,
method: str = "upper_bound_resize") -> Tuple[int, int]:
"""Compute (target_h, target_w) for a single image.
Methods:
- "upper_bound_resize": scale longest side to ``process_res``, then
round each dim to nearest multiple of 14 (default upstream method).
- "lower_bound_resize": scale shortest side to ``process_res``, then
round.
"""
if method == "upper_bound_resize":
longest = max(orig_h, orig_w)
scale = process_res / float(longest)
elif method == "lower_bound_resize":
shortest = min(orig_h, orig_w)
scale = process_res / float(shortest)
else:
raise ValueError(f"Unsupported process_res_method: {method}")
new_w = max(1, _round_to_patch(int(round(orig_w * scale))))
new_h = max(1, _round_to_patch(int(round(orig_h * scale))))
return new_h, new_w
def preprocess_image(
image: torch.Tensor,
process_res: int = 504,
method: str = "upper_bound_resize",
) -> torch.Tensor:
"""Preprocess a ComfyUI ``IMAGE`` batch for DA3.
Args:
image: ``(B, H, W, 3)`` float in [0, 1] (ComfyUI ``IMAGE`` convention).
process_res: target resolution (longest or shortest side, depending
on ``method``).
method: resize strategy.
Returns:
``(B, 3, H', W')`` tensor with H' and W' multiples of 14, normalised
with ImageNet statistics. The tensor lives on the same device as
``image``.
"""
assert image.ndim == 4 and image.shape[-1] == 3, \
f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
B, H, W, _ = image.shape
target_h, target_w = compute_target_size(H, W, process_res, method)
# (B, H, W, 3) -> (B, 3, H, W)
x = image.movedim(-1, 1).contiguous()
if (target_h, target_w) != (H, W):
# Upstream uses cv2 INTER_CUBIC (upscale) / INTER_AREA (downscale).
# Lanczos in ``common_upscale`` is anti-aliased and produces the
# closest pixel-wise match in a sweep across {bilinear, bicubic,
# area, lanczos, bislerp}. Used in both directions for simplicity.
x = comfy.utils.common_upscale(
x.float(), target_w, target_h, "lanczos", "disabled",
)
x = x.clamp(0.0, 1.0)
mean = _IMAGENET_MEAN.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
std = _IMAGENET_STD.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
x = (x - mean) / std
return x
# -----------------------------------------------------------------------------
# Output post-processing (sky-aware clipping for Mono/Metric variants)
# -----------------------------------------------------------------------------
def compute_non_sky_mask(sky_prediction: torch.Tensor, threshold: float = 0.3) -> torch.Tensor:
"""Boolean mask: True for non-sky pixels (sky probability < threshold)."""
return sky_prediction < threshold
def apply_sky_aware_clip(
depth: torch.Tensor,
sky: torch.Tensor,
threshold: float = 0.3,
quantile: float = 0.99,
) -> torch.Tensor:
"""Replicates ``_process_mono_sky_estimation`` from upstream.
Clips sky regions to the 99th percentile of non-sky depth. Returns a new
depth tensor; ``depth`` is not modified in place.
"""
non_sky = compute_non_sky_mask(sky, threshold=threshold)
if non_sky.sum() <= 10 or (~non_sky).sum() <= 10:
return depth.clone()
non_sky_depth = depth[non_sky]
if non_sky_depth.numel() > 100_000:
idx = torch.randint(0, non_sky_depth.numel(), (100_000,), device=non_sky_depth.device)
sampled = non_sky_depth[idx]
else:
sampled = non_sky_depth
max_depth = torch.quantile(sampled, quantile)
out = depth.clone()
out[~non_sky] = max_depth
return out
def normalize_depth_v2_style(
depth: torch.Tensor,
sky: torch.Tensor | None = None,
low_quantile: float = 0.01,
high_quantile: float = 0.99,
) -> torch.Tensor:
"""V2-style normalization for ControlNet workflows.
Computes percentile bounds over non-sky pixels (when available),
then maps depth into [0, 1] with near = white (1.0).
"""
if sky is not None:
mask = compute_non_sky_mask(sky)
if mask.any():
valid = depth[mask]
else:
valid = depth.flatten()
else:
valid = depth.flatten()
if valid.numel() > 100_000:
idx = torch.randint(0, valid.numel(), (100_000,), device=valid.device)
sample = valid[idx]
else:
sample = valid
lo = torch.quantile(sample, low_quantile)
hi = torch.quantile(sample, high_quantile)
rng = (hi - lo).clamp(min=1e-6)
norm = ((depth - lo) / rng).clamp(0.0, 1.0)
# ControlNet convention: nearer pixels are brighter (1.0).
norm = 1.0 - norm
if sky is not None:
# Sky pixels become black (far / unknown).
sky_mask = ~compute_non_sky_mask(sky)
norm = torch.where(sky_mask, torch.zeros_like(norm), norm)
return norm
def normalize_depth_min_max(depth: torch.Tensor) -> torch.Tensor:
"""Simple per-frame min/max normalization with near=1.0 convention."""
lo = depth.amin(dim=(-2, -1), keepdim=True)
hi = depth.amax(dim=(-2, -1), keepdim=True)
rng = (hi - lo).clamp(min=1e-6)
return 1.0 - ((depth - lo) / rng).clamp(0.0, 1.0)

View File

@ -1,318 +0,0 @@
"""Ray-to-pose conversion for the multi-view path of Depth Anything 3.
Converts the auxiliary "ray" output of :class:`DualDPT` (per-pixel camera
ray vectors, predicted on the per-view local feature map) into per-view
extrinsics + intrinsics. Implementation is a 1:1 port of
``depth_anything_3.utils.ray_utils`` upstream, using a weighted-RANSAC
homography fit followed by a QL decomposition.
No learned parameters; pure tensor math. Output:
* ``R`` -- ``(B, S, 3, 3)`` rotation matrix
* ``T`` -- ``(B, S, 3)`` camera-space translation
* ``focal_lengths`` -- ``(B, S, 2)`` in normalised image space (image=2x2)
* ``principal_points`` -- ``(B, S, 2)`` ditto
:func:`get_extrinsic_from_camray` wraps these into a 4x4 extrinsic matrix
that the public node converts back into pixel-space intrinsics.
"""
from __future__ import annotations
from typing import Optional, Tuple
import torch
# qr/svd use fp32: CUDA often has no fp16/bf16 kernels for these ops.
def _ql_decomposition(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Decompose ``A = Q @ L`` with ``Q`` orthogonal and ``L`` lower-triangular.
Implemented in terms of QR by reversing the columns/rows; the standard
trick from the upstream reference. Inputs ``A`` are ``(3, 3)``.
"""
P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]],
device=A.device, dtype=A.dtype)
A_tilde = A @ P
# CUDA QR is not implemented for fp16/bf16; upcast just for this call.
Q_tilde, R_tilde = torch.linalg.qr(A_tilde.float())
Q_tilde = Q_tilde.to(A.dtype)
R_tilde = R_tilde.to(A.dtype)
Q = Q_tilde @ P
L = P @ R_tilde @ P
d = torch.diag(L)
sign = torch.sign(d)
Q = Q * sign[None, :] # scale columns of Q
L = L * sign[:, None] # scale rows of L
return Q, L
def _homogenize_points(points: torch.Tensor) -> torch.Tensor:
return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
# -----------------------------------------------------------------------------
# Weighted-LSQ + RANSAC homography (batched)
# -----------------------------------------------------------------------------
def _find_homography_weighted_lsq(
src_pts: torch.Tensor,
dst_pts: torch.Tensor,
confident_weight: torch.Tensor,
) -> torch.Tensor:
"""Solve a single ``H`` with weighted least-squares (DLT)."""
N = src_pts.shape[0]
if N < 4:
raise ValueError("At least 4 points are required to compute a homography.")
w = confident_weight.sqrt().unsqueeze(1) # (N, 1)
x = src_pts[:, 0:1]
y = src_pts[:, 1:2]
u = dst_pts[:, 0:1]
v = dst_pts[:, 1:2]
zeros = torch.zeros_like(x)
A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=1)
A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=1)
A = torch.cat([A1, A2], dim=0) # (2N, 9)
# CUDA SVD is not implemented for fp16/bf16; upcast just for this call.
_, _, Vh = torch.linalg.svd(A.float())
Vh = Vh.to(A.dtype)
H = Vh[-1].reshape(3, 3)
return H / H[-1, -1]
def _find_homography_weighted_lsq_batched(
src_pts_batch: torch.Tensor,
dst_pts_batch: torch.Tensor,
confident_weight_batch: torch.Tensor,
) -> torch.Tensor:
"""Batched DLT solver. Inputs ``(B, K, 2)`` / ``(B, K)``; output ``(B, 3, 3)``."""
B, K, _ = src_pts_batch.shape
w = confident_weight_batch.sqrt().unsqueeze(2)
x = src_pts_batch[:, :, 0:1]
y = src_pts_batch[:, :, 1:2]
u = dst_pts_batch[:, :, 0:1]
v = dst_pts_batch[:, :, 1:2]
zeros = torch.zeros_like(x)
A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=2)
A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=2)
A = torch.cat([A1, A2], dim=1) # (B, 2K, 9)
# CUDA SVD is not implemented for fp16/bf16; upcast just for this call.
_, _, Vh = torch.linalg.svd(A.float())
Vh = Vh.to(A.dtype)
H = Vh[:, -1].reshape(B, 3, 3)
return H / H[:, 2:3, 2:3]
def _ransac_find_homography_weighted_batched(
src_pts: torch.Tensor, # (B, N, 2)
dst_pts: torch.Tensor, # (B, N, 2)
confident_weight: torch.Tensor, # (B, N)
n_sample: int,
n_iter: int = 100,
reproj_threshold: float = 3.0,
num_sample_for_ransac: int = 8,
max_inlier_num: int = 10000,
rand_sample_iters_idx: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Batched weighted-RANSAC homography estimator.
Returns ``(B, 3, 3)`` homography matrices.
"""
B, N, _ = src_pts.shape
assert N >= 4
device = src_pts.device
sorted_idx = torch.argsort(confident_weight, descending=True, dim=1)
candidate_idx = sorted_idx[:, :n_sample] # (B, n_sample)
if rand_sample_iters_idx is None:
rand_sample_iters_idx = torch.stack(
[torch.randperm(n_sample, device=device)[:num_sample_for_ransac]
for _ in range(n_iter)],
dim=0,
)
rand_idx = candidate_idx[:, rand_sample_iters_idx] # (B, n_iter, k)
b_idx = (
torch.arange(B, device=device)
.view(B, 1, 1)
.expand(B, n_iter, num_sample_for_ransac)
)
src_b = src_pts[b_idx, rand_idx]
dst_b = dst_pts[b_idx, rand_idx]
w_b = confident_weight[b_idx, rand_idx]
cB, cN = src_b.shape[:2]
H_batch = _find_homography_weighted_lsq_batched(
src_b.flatten(0, 1), dst_b.flatten(0, 1), w_b.flatten(0, 1),
).unflatten(0, (cB, cN)) # (B, n_iter, 3, 3)
src_homo = torch.cat([src_pts, torch.ones(B, N, 1, device=device, dtype=src_pts.dtype)], dim=2)
proj = torch.bmm(
src_homo.unsqueeze(1).expand(B, n_iter, N, 3).reshape(-1, N, 3),
H_batch.reshape(-1, 3, 3).transpose(1, 2),
) # (B*n_iter, N, 3)
proj_xy = (proj[:, :, :2] / proj[:, :, 2:3]).reshape(B, n_iter, N, 2)
err = ((proj_xy - dst_pts.unsqueeze(1)) ** 2).sum(-1).sqrt() # (B, n_iter, N)
inlier_mask = err < reproj_threshold
score = (inlier_mask * confident_weight.unsqueeze(1)).sum(dim=2)
best_idx = torch.argmax(score, dim=1)
best_inlier_mask = inlier_mask[torch.arange(B, device=device), best_idx]
# Refit with the inlier set (per-batch, since the inlier counts vary).
H_inlier_list = []
for b in range(B):
mask = best_inlier_mask[b]
in_src = src_pts[b][mask]
in_dst = dst_pts[b][mask]
in_w = confident_weight[b][mask]
if in_src.shape[0] < 4:
# Fall back to identity when RANSAC fails to find enough inliers.
H_inlier_list.append(torch.eye(3, device=device, dtype=src_pts.dtype))
continue
sorted_w = torch.argsort(in_w, descending=True)
if len(sorted_w) > max_inlier_num:
keep = max(int(len(sorted_w) * 0.95), max_inlier_num)
sorted_w = sorted_w[:keep][torch.randperm(keep, device=device)[:max_inlier_num]]
H_inlier_list.append(
_find_homography_weighted_lsq(in_src[sorted_w], in_dst[sorted_w], in_w[sorted_w])
)
return torch.stack(H_inlier_list, dim=0)
# -----------------------------------------------------------------------------
# Camera-ray utilities
# -----------------------------------------------------------------------------
def _unproject_identity(num_y: int, num_x: int, B: int, S: int,
device, dtype) -> torch.Tensor:
"""Camera-space unit rays for an identity intrinsic on a 2x2 image plane.
Replicates ``unproject_depth(..., ixt_normalized=True)`` upstream: pixel
coords ``(x, y)`` in ``[dx, 2-dx] x [dy, 2-dy]`` get mapped to
camera-space rays ``(x-1, y-1, 1)`` via the identity intrinsic
``[[1,0,1],[0,1,1],[0,0,1]]``. Returns ``(B, S, num_y, num_x, 3)``.
"""
dx = 1.0 / num_x
dy = 1.0 / num_y
# Centered camera-space coords directly (skip the K^-1 step since it's
# just a translation by -1 on x and y when K is identity-with-center=1).
y = torch.linspace(-(1 - dy), (1 - dy), num_y, device=device, dtype=dtype)
x = torch.linspace(-(1 - dx), (1 - dx), num_x, device=device, dtype=dtype)
yy, xx = torch.meshgrid(y, x, indexing="ij")
grid = torch.stack((xx, yy), dim=-1) # (h, w, 2)
grid = grid.unsqueeze(0).unsqueeze(0).expand(B, S, num_y, num_x, 2)
return torch.cat([grid, torch.ones_like(grid[..., :1])], dim=-1)
def _camray_to_caminfo(
camray: torch.Tensor, # (B, S, h, w, 6)
confidence: Optional[torch.Tensor] = None, # (B, S, h, w)
reproj_threshold: float = 0.2,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Convert per-pixel camera rays to per-view (R, T, focal, principal)."""
if confidence is None:
confidence = torch.ones_like(camray[..., 0])
B, S, h, w, _ = camray.shape
device = camray.device
dtype = camray.dtype
rays_target = camray[..., :3] # (B, S, h, w, 3)
rays_origin = _unproject_identity(h, w, B, S, device, dtype)
# Flatten (B*S, h*w, *) for the RANSAC routine.
rays_target = rays_target.flatten(0, 1).flatten(1, 2)
rays_origin = rays_origin.flatten(0, 1).flatten(1, 2)
weights = confidence.flatten(0, 1).flatten(1, 2).clone()
# Project to 2D in homogeneous form (the upstream calls this "perspective division").
z_thresh = 1e-4
mask = (rays_target[:, :, 2].abs() > z_thresh) & (rays_origin[:, :, 2].abs() > z_thresh)
weights = torch.where(mask, weights, torch.zeros_like(weights))
src = rays_origin.clone()
dst = rays_target.clone()
src[..., 0] = torch.where(mask, src[..., 0] / src[..., 2], src[..., 0])
src[..., 1] = torch.where(mask, src[..., 1] / src[..., 2], src[..., 1])
dst[..., 0] = torch.where(mask, dst[..., 0] / dst[..., 2], dst[..., 0])
dst[..., 1] = torch.where(mask, dst[..., 1] / dst[..., 2], dst[..., 1])
src = src[..., :2]
dst = dst[..., :2]
N = src.shape[1]
n_iter = 100
sample_ratio = 0.3
num_sample_for_ransac = 8
n_sample = max(num_sample_for_ransac, int(N * sample_ratio))
rand_idx = torch.stack(
[torch.randperm(n_sample, device=device)[:num_sample_for_ransac] for _ in range(n_iter)],
dim=0,
)
# Chunk along the view axis to keep peak memory predictable.
chunk = 2
A_list = []
for i in range(0, src.shape[0], chunk):
A = _ransac_find_homography_weighted_batched(
src[i:i + chunk], dst[i:i + chunk], weights[i:i + chunk],
n_sample=n_sample, n_iter=n_iter,
num_sample_for_ransac=num_sample_for_ransac,
reproj_threshold=reproj_threshold,
rand_sample_iters_idx=rand_idx,
max_inlier_num=8000,
)
# Flip sign on dets that come out < 0 (so that the QL produces a
# right-handed rotation). ``det`` lacks fp16/bf16 CUDA kernels, so
# do the comparison in fp32.
flip = torch.linalg.det(A.float()) < 0
A = torch.where(flip[:, None, None], -A, A)
A_list.append(A)
A = torch.cat(A_list, dim=0) # (B*S, 3, 3)
R_list, f_list, pp_list = [], [], []
for i in range(A.shape[0]):
R, L = _ql_decomposition(A[i])
L = L / L[2][2]
f_list.append(torch.stack((L[0][0], L[1][1])))
pp_list.append(torch.stack((L[2][0], L[2][1])))
R_list.append(R)
R = torch.stack(R_list).reshape(B, S, 3, 3)
focal = torch.stack(f_list).reshape(B, S, 2)
pp = torch.stack(pp_list).reshape(B, S, 2)
# Translation: confidence-weighted average of camray direction(s).
cf = confidence.flatten(0, 1).flatten(1, 2)
T = (camray.flatten(0, 1).flatten(1, 2)[..., 3:] * cf.unsqueeze(-1)).sum(dim=1)
T = T / cf.sum(dim=-1, keepdim=True)
T = T.reshape(B, S, 3)
# Match upstream output convention: focal -> 1/focal, pp + 1.
return R, T, 1.0 / focal, pp + 1.0
def get_extrinsic_from_camray(
camray: torch.Tensor, # (B, S, h, w, 6)
conf: torch.Tensor, # (B, S, h, w, 1) or (B, S, h, w)
patch_size_y: int,
patch_size_x: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Wrap a 4x4 extrinsic + per-view focal + principal-point output.
Returns:
* extrinsic ``(B, S, 4, 4)`` camera-to-world (the inverse is
what gets stored in ``output.extrinsics``
by the caller).
* focals ``(B, S, 2)`` in normalised image space.
* pp ``(B, S, 2)`` in normalised image space.
"""
if conf.ndim == 5 and conf.shape[-1] == 1:
conf = conf.squeeze(-1)
R, T, focal, pp = _camray_to_caminfo(camray, confidence=conf)
extr = torch.cat([R, T.unsqueeze(-1)], dim=-1) # (B, S, 3, 4)
homo_row = torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device)
homo_row = homo_row.view(1, 1, 1, 4).expand(R.shape[0], R.shape[1], 1, 4)
extr = torch.cat([extr, homo_row], dim=-2) # (B, S, 4, 4)
return extr, focal, pp

View File

@ -1,116 +0,0 @@
"""Reference-view selection for the multi-view path of Depth Anything 3.
Pure tensor math, no learned parameters. Exposed as three free functions:
* :func:`select_reference_view` -- pick a reference view per batch.
* :func:`reorder_by_reference` -- move the reference view to position 0.
* :func:`restore_original_order` -- inverse of :func:`reorder_by_reference`.
Mirrors ``depth_anything_3.model.reference_view_selector`` upstream.
The default strategy (``"saddle_balanced"``) selects the view whose CLS
token features are closest to the median across multiple metrics.
"""
from __future__ import annotations
from typing import Literal
import torch
RefViewStrategy = Literal["first", "middle", "saddle_balanced", "saddle_sim_range"]
# Per the upstream constants module: ``THRESH_FOR_REF_SELECTION = 3``.
# Reference selection only runs when there are at least this many views.
THRESH_FOR_REF_SELECTION: int = 3
def select_reference_view(
x: torch.Tensor,
strategy: RefViewStrategy = "saddle_balanced",
) -> torch.Tensor:
"""Pick a reference view index per batch element.
Args:
x: ``(B, S, N, C)`` token tensor. Index 0 along ``N`` is the
cls/cam token used by the feature-based strategies.
strategy: One of ``"first" | "middle" | "saddle_balanced" |
"saddle_sim_range"``.
Returns:
``(B,)`` long tensor with the chosen reference view index for
each batch element.
"""
B, S, _, _ = x.shape
if S <= 1:
return torch.zeros(B, dtype=torch.long, device=x.device)
if strategy == "first":
return torch.zeros(B, dtype=torch.long, device=x.device)
if strategy == "middle":
return torch.full((B,), S // 2, dtype=torch.long, device=x.device)
# Feature-based strategies: normalised cls/cam token per view.
img_class_feat = x[:, :, 0] / x[:, :, 0].norm(dim=-1, keepdim=True) # (B,S,C)
if strategy == "saddle_balanced":
sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2)) # (B,S,S)
sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0)
sim_score = sim_no_diag.sum(dim=-1) / (S - 1) # (B,S)
feat_norm = x[:, :, 0].norm(dim=-1) # (B,S)
feat_var = img_class_feat.var(dim=-1) # (B,S)
def _normalize(metric):
mn = metric.min(dim=1, keepdim=True).values
mx = metric.max(dim=1, keepdim=True).values
return (metric - mn) / (mx - mn + 1e-8)
sim_n, norm_n, var_n = _normalize(sim_score), _normalize(feat_norm), _normalize(feat_var)
balance = (sim_n - 0.5).abs() + (norm_n - 0.5).abs() + (var_n - 0.5).abs()
return balance.argmin(dim=1)
if strategy == "saddle_sim_range":
sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2))
sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0)
sim_max = sim_no_diag.max(dim=-1).values
sim_min = sim_no_diag.min(dim=-1).values
return (sim_max - sim_min).argmax(dim=1)
raise ValueError(
f"Unknown reference view selection strategy: {strategy!r}. "
f"Must be one of: 'first', 'middle', 'saddle_balanced', 'saddle_sim_range'"
)
def reorder_by_reference(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
"""Reorder ``x`` so the reference view is at position 0 in axis ``S``."""
B, S = x.shape[0], x.shape[1]
if S <= 1:
return x
positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
b_idx_exp = b_idx.unsqueeze(1)
reorder = torch.where(
(positions > 0) & (positions <= b_idx_exp),
positions - 1,
positions,
)
reorder[:, 0] = b_idx
batch = torch.arange(B, device=x.device).unsqueeze(1)
return x[batch, reorder]
def restore_original_order(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
"""Inverse of :func:`reorder_by_reference`."""
B, S = x.shape[0], x.shape[1]
if S <= 1:
return x
target_positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
b_idx_exp = b_idx.unsqueeze(1)
restore = torch.where(target_positions < b_idx_exp,
target_positions + 1,
target_positions)
restore = torch.scatter(
restore, dim=1, index=b_idx_exp, src=torch.zeros_like(b_idx_exp),
)
batch = torch.arange(B, device=x.device).unsqueeze(1)
return x[batch, restore]

View File

@ -1,180 +0,0 @@
"""Geometry / camera transform helpers for Depth Anything 3.
Pure tensor math, no learned parameters. Mirrors the upstream upstream
``depth_anything_3.model.utils.transform`` and the parts of
``depth_anything_3.utils.geometry`` used at inference time on the
multi-view + camera path. Kept self-contained so the DA3 module is fully
ported and does not depend on the upstream repo at runtime.
"""
from __future__ import annotations
from typing import Tuple
import torch
import torch.nn.functional as F
# -----------------------------------------------------------------------------
# Affine 4x4 helpers
# -----------------------------------------------------------------------------
def as_homogeneous(ext: torch.Tensor) -> torch.Tensor:
"""Promote ``(...,3,4)`` extrinsics to ``(...,4,4)`` homogeneous form.
A no-op when the input is already ``(...,4,4)``.
"""
if ext.shape[-2:] == (4, 4):
return ext
if ext.shape[-2:] == (3, 4):
ones = torch.zeros_like(ext[..., :1, :4])
ones[..., 0, 3] = 1.0
return torch.cat([ext, ones], dim=-2)
raise ValueError(f"Invalid affine shape: {ext.shape}")
def affine_inverse(A: torch.Tensor) -> torch.Tensor:
"""Inverse of an affine matrix ``[R|T; 0 0 0 1]``."""
R = A[..., :3, :3]
T = A[..., :3, 3:]
P = A[..., 3:, :]
return torch.cat([torch.cat([R.mT, -R.mT @ T], dim=-1), P], dim=-2)
# -----------------------------------------------------------------------------
# Quaternion <-> rotation matrix (xyzw / scalar-last)
# -----------------------------------------------------------------------------
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
"""``sqrt(max(0, x))`` with a zero subgradient where ``x == 0``."""
ret = torch.zeros_like(x)
positive_mask = x > 0
if torch.is_grad_enabled():
ret[positive_mask] = torch.sqrt(x[positive_mask])
else:
ret = torch.where(positive_mask, torch.sqrt(x), ret)
return ret
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
"""Force the real part of a unit quaternion (xyzw) to be non-negative."""
return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
"""Convert quaternions (xyzw) to ``(...,3,3)`` rotation matrices."""
i, j, k, r = torch.unbind(quaternions, -1)
two_s = 2.0 / (quaternions * quaternions).sum(-1)
o = torch.stack(
(
1 - two_s * (j * j + k * k),
two_s * (i * j - k * r),
two_s * (i * k + j * r),
two_s * (i * j + k * r),
1 - two_s * (i * i + k * k),
two_s * (j * k - i * r),
two_s * (i * k - j * r),
two_s * (j * k + i * r),
1 - two_s * (i * i + j * j),
),
-1,
)
return o.reshape(quaternions.shape[:-1] + (3, 3))
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
"""Convert ``(...,3,3)`` rotation matrices to quaternions (xyzw)."""
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
batch_dim = matrix.shape[:-2]
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
matrix.reshape(batch_dim + (9,)), dim=-1
)
q_abs = _sqrt_positive_part(
torch.stack(
[
1.0 + m00 + m11 + m22,
1.0 + m00 - m11 - m22,
1.0 - m00 + m11 - m22,
1.0 - m00 - m11 + m22,
],
dim=-1,
)
)
quat_by_rijk = torch.stack(
[
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
],
dim=-2,
)
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(
batch_dim + (4,)
)
# Reorder rijk -> xyzw (i.e. ijkr).
out = out[..., [1, 2, 3, 0]]
return standardize_quaternion(out)
# -----------------------------------------------------------------------------
# Pose-encoding <-> extrinsics + intrinsics
# -----------------------------------------------------------------------------
def extri_intri_to_pose_encoding(
extrinsics: torch.Tensor,
intrinsics: torch.Tensor,
image_size_hw: Tuple[int, int],
) -> torch.Tensor:
"""Pack ``(extr, intr, image_size)`` into the 9-D pose-encoding vector.
``extrinsics`` are camera-to-world (c2w) ``(B,S,4,4)`` matrices,
``intrinsics`` are pixel-space ``(B,S,3,3)`` matrices, ``image_size_hw``
is a ``(H, W)`` pair. The encoding is ``[T(3), quat_xyzw(4), fov_h, fov_w]``.
"""
R = extrinsics[..., :3, :3]
T = extrinsics[..., :3, 3]
quat = mat_to_quat(R)
H, W = image_size_hw
fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
return torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
def pose_encoding_to_extri_intri(
pose_encoding: torch.Tensor,
image_size_hw: Tuple[int, int],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Inverse of :func:`extri_intri_to_pose_encoding`.
Returns a ``(B,S,3,4)`` c2w extrinsic matrix and a ``(B,S,3,3)``
pixel-space intrinsic matrix.
"""
T = pose_encoding[..., :3]
quat = pose_encoding[..., 3:7]
fov_h = pose_encoding[..., 7]
fov_w = pose_encoding[..., 8]
R = quat_to_mat(quat)
extrinsics = torch.cat([R, T[..., None]], dim=-1)
H, W = image_size_hw
fy = (H / 2.0) / torch.clamp(torch.tan(fov_h / 2.0), 1e-6)
fx = (W / 2.0) / torch.clamp(torch.tan(fov_w / 2.0), 1e-6)
intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3),
device=pose_encoding.device, dtype=pose_encoding.dtype)
intrinsics[..., 0, 0] = fx
intrinsics[..., 1, 1] = fy
intrinsics[..., 0, 2] = W / 2
intrinsics[..., 1, 2] = H / 2
intrinsics[..., 2, 2] = 1.0
return extrinsics, intrinsics

View File

@ -60,7 +60,6 @@ import comfy.ldm.ernie.model
import comfy.ldm.sam3.detector
import comfy.ldm.hidream_o1.model
from comfy.ldm.hidream_o1.conditioning import build_extra_conds
import comfy.ldm.depth_anything_3.model
import comfy.model_management
import comfy.patcher_extension
@ -2122,12 +2121,6 @@ class RT_DETR_v4(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)
class DepthAnything3(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device,
unet_model=comfy.ldm.depth_anything_3.model.DepthAnything3Net)
class ErnieImage(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ernie.model.ErnieImageModel)

View File

@ -805,108 +805,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
return dit_config
# Depth Anything 3
if '{}backbone.pretrained.patch_embed.proj.weight'.format(key_prefix) in state_dict_keys:
dit_config = {}
dit_config["image_model"] = "DepthAnything3"
patch_w = state_dict['{}backbone.pretrained.patch_embed.proj.weight'.format(key_prefix)]
embed_dim = patch_w.shape[0]
depth = count_blocks(state_dict_keys, '{}backbone.pretrained.blocks.'.format(key_prefix) + '{}.')
# Backbone preset is determined by embed_dim (matches vits/vitb/vitl/vitg).
backbone_name = {384: "vits", 768: "vitb", 1024: "vitl", 1536: "vitg"}.get(embed_dim)
if backbone_name is None:
return None
dit_config["backbone_name"] = backbone_name
# Detect DA3 extensions on top of vanilla DINOv2.
has_camera_token = '{}backbone.pretrained.camera_token'.format(key_prefix) in state_dict_keys
# qk-norm shows up as `attn.q_norm.weight` on enabled blocks.
qknorm_indices = [
i for i in range(depth)
if '{}backbone.pretrained.blocks.{}.attn.q_norm.weight'.format(key_prefix, i) in state_dict_keys
]
qknorm_start = qknorm_indices[0] if qknorm_indices else -1
# The DA3 main-series configs always set alt_start == qknorm_start == rope_start.
# cat_token=True is implied by the presence of camera_token.
if has_camera_token:
dit_config["alt_start"] = qknorm_start
dit_config["rope_start"] = qknorm_start
dit_config["qknorm_start"] = qknorm_start
dit_config["cat_token"] = True
else:
dit_config["alt_start"] = -1
dit_config["rope_start"] = -1
dit_config["qknorm_start"] = -1
dit_config["cat_token"] = False
# Detect head type and config.
has_aux = '{}head.scratch.refinenet1_aux.out_conv.weight'.format(key_prefix) in state_dict_keys
if has_aux:
dit_config["head_type"] = "dualdpt"
# DualDPT: dim_in = 2 * embed_dim (because cat_token doubles token width).
head_dim_in = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1]
out_channels = [
state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0]
for i in range(4)
]
features = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0]
dit_config["head_dim_in"] = head_dim_in
dit_config["head_output_dim"] = 2
dit_config["head_features"] = features
dit_config["head_out_channels"] = out_channels
dit_config["head_use_sky_head"] = False
else:
dit_config["head_type"] = "dpt"
head_dim_in = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1]
out_channels = [
state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0]
for i in range(4)
]
features = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0]
output_dim = state_dict[
'{}head.scratch.output_conv2.2.weight'.format(key_prefix)
].shape[0]
dit_config["head_dim_in"] = head_dim_in
dit_config["head_output_dim"] = output_dim
dit_config["head_features"] = features
dit_config["head_out_channels"] = out_channels
dit_config["head_use_sky_head"] = (
'{}head.scratch.sky_output_conv2.0.weight'.format(key_prefix) in state_dict_keys
)
# out_layers: hard-coded per upstream YAML config (depth-aware default).
if depth >= 24:
# vitl: depths used vary between DA3-Large (DualDPT) and Mono/Metric (DPT).
if has_aux:
dit_config["out_layers"] = [11, 15, 19, 23]
else:
dit_config["out_layers"] = [4, 11, 17, 23]
else:
# vits/vitb: 12 blocks
dit_config["out_layers"] = [5, 7, 9, 11]
# Camera encoder/decoder presence (multi-view + pose path).
has_cam_enc = '{}cam_enc.token_norm.weight'.format(key_prefix) in state_dict_keys
has_cam_dec = '{}cam_dec.fc_t.weight'.format(key_prefix) in state_dict_keys
dit_config["has_cam_enc"] = has_cam_enc
dit_config["has_cam_dec"] = has_cam_dec
if has_cam_enc:
cam_enc_w = state_dict.get(
'{}cam_enc.pose_branch.fc2.weight'.format(key_prefix)
)
if cam_enc_w is not None:
dit_config["cam_dim_out"] = cam_enc_w.shape[0]
if has_cam_dec:
cam_dec_w = state_dict.get(
'{}cam_dec.fc_t.weight'.format(key_prefix)
)
if cam_dec_w is not None:
dit_config["cam_dec_dim_in"] = cam_dec_w.shape[1]
return dit_config
if '{}layers.0.mlp.linear_fc2.weight'.format(key_prefix) in state_dict_keys: # Ernie Image
dit_config = {}
dit_config["image_model"] = "ernie"

View File

@ -1871,101 +1871,6 @@ class RT_DETR_v4(supported_models_base.BASE):
return None
class DepthAnything3(supported_models_base.BASE):
unet_config = {
"image_model": "DepthAnything3",
}
# Mono path: no num_heads / num_head_channels needed.
unet_extra_config = {}
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
def get_model(self, state_dict, prefix="", device=None):
return model_base.DepthAnything3(self, device=device)
def clip_target(self, state_dict={}):
return None
def process_unet_state_dict(self, state_dict):
# Drop Gaussian-head weights; remap fused backbone QKV to Dinov2Model layout.
drop_prefixes = ("gs_head.", "gs_adapter.")
for k in list(state_dict.keys()):
if k.startswith(drop_prefixes):
state_dict.pop(k)
return _da3_remap_backbone_keys(state_dict, prefix="backbone.")
def _da3_remap_backbone_keys(state_dict, prefix="backbone."):
"""Map ``backbone.pretrained.*`` (upstream DA3) keys to ``Dinov2Model`` under ``prefix``."""
pre = prefix + "pretrained."
src_keys = [k for k in state_dict.keys() if k.startswith(pre)]
if not src_keys:
return state_dict
static_renames = {
pre + "patch_embed.proj.weight": prefix + "embeddings.patch_embeddings.projection.weight",
pre + "patch_embed.proj.bias": prefix + "embeddings.patch_embeddings.projection.bias",
pre + "pos_embed": prefix + "embeddings.position_embeddings",
pre + "cls_token": prefix + "embeddings.cls_token",
pre + "camera_token": prefix + "embeddings.camera_token",
pre + "norm.weight": prefix + "layernorm.weight",
pre + "norm.bias": prefix + "layernorm.bias",
}
for src, dst in static_renames.items():
if src in state_dict:
state_dict[dst] = state_dict.pop(src)
block_pre = pre + "blocks."
block_keys = [k for k in state_dict.keys() if k.startswith(block_pre)]
for k in block_keys:
rest = k[len(block_pre):] # e.g. "5.attn.qkv.weight"
idx_str, _, sub = rest.partition(".")
target_block = "{}encoder.layer.{}.".format(prefix, idx_str)
# Fused QKV -> split query/key/value linears.
if sub == "attn.qkv.weight":
qkv = state_dict.pop(k)
c = qkv.shape[0] // 3
state_dict[target_block + "attention.attention.query.weight"] = qkv[:c].clone()
state_dict[target_block + "attention.attention.key.weight"] = qkv[c:2 * c].clone()
state_dict[target_block + "attention.attention.value.weight"] = qkv[2 * c:].clone()
continue
if sub == "attn.qkv.bias":
qkv = state_dict.pop(k)
c = qkv.shape[0] // 3
state_dict[target_block + "attention.attention.query.bias"] = qkv[:c].clone()
state_dict[target_block + "attention.attention.key.bias"] = qkv[c:2 * c].clone()
state_dict[target_block + "attention.attention.value.bias"] = qkv[2 * c:].clone()
continue
# Sub-key remap (suffix preserved).
if sub.startswith("attn.proj."):
tail = sub[len("attn.proj."):]
new = "attention.output.dense." + tail
elif sub.startswith("attn.q_norm."):
new = "attention.q_norm." + sub[len("attn.q_norm."):]
elif sub.startswith("attn.k_norm."):
new = "attention.k_norm." + sub[len("attn.k_norm."):]
elif sub == "ls1.gamma":
new = "layer_scale1.lambda1"
elif sub == "ls2.gamma":
new = "layer_scale2.lambda1"
elif sub.startswith("mlp.w12."):
new = "mlp.weights_in." + sub[len("mlp.w12."):]
elif sub.startswith("mlp.w3."):
new = "mlp.weights_out." + sub[len("mlp.w3."):]
elif sub.startswith(("norm1.", "norm2.", "mlp.fc1.", "mlp.fc2.")):
new = sub
else:
# Unrecognised key -- leave as-is so load_state_dict can complain.
continue
state_dict[target_block + new] = state_dict.pop(k)
return state_dict
class ErnieImage(supported_models_base.BASE):
unet_config = {
"image_model": "ernie",
@ -2202,5 +2107,4 @@ models = [
CogVideoX_I2V,
CogVideoX_T2V,
SVD_img2vid,
DepthAnything3,
]

View File

@ -1,7 +1,5 @@
from __future__ import annotations
from enum import Enum
from typing import Optional, List
from pydantic import BaseModel, Field
@ -11,44 +9,76 @@ class Rodin3DGenerateRequest(BaseModel):
material: str = Field(..., description="The material type.")
quality_override: int = Field(..., description="The poly count of the mesh.")
mesh_mode: str = Field(..., description="It controls the type of faces of generated models.")
TAPose: Optional[bool] = Field(None, description="")
TAPose: bool | None = Field(None, description="")
class Rodin3DGen25Request(BaseModel):
tier: str = Field(..., description="Gen-2.5 tier (e.g. Gen-2.5-High).")
prompt: str | None = Field(None, description="Required for Text-to-3D; ignored otherwise.")
seed: int | None = Field(None, description="0-65535.")
material: str | None = Field(None, description="PBR | Shaded | All | None.")
geometry_file_format: str | None = Field(None, description="glb | usdz | fbx | obj | stl.")
texture_mode: str | None = Field(None, description="legacy | extreme-low | low | medium | high.")
mesh_mode: str | None = Field(None, description="Raw (triangular) | Quad.")
quality_override: int | None = Field(None, description="Mesh face count override.")
geometry_instruct_mode: str | None = Field(None, description="faithful | creative.")
bbox_condition: list[int] | None = Field(None, description="Bounding box [Width(Y), Height(Z), Length(X)] in cm.")
height: int | None = Field(None, description="Approximate model height in cm.")
TAPose: bool | None = Field(None, description="T/A pose for human-like models.")
hd_texture: bool | None = Field(None, description="Enhanced texture quality.")
texture_delight: bool | None = Field(None, description="Remove baked lighting from textures.")
is_micro: bool | None = Field(None, description="Micro detail (Extreme-High only).")
use_original_alpha: bool | None = Field(None, description="Preserve image transparency.")
preview_render: bool | None = Field(None, description="Generate high-quality preview render.")
addons: list[str] | None = Field(None, description='Optional addons, e.g. ["HighPack"].')
class GenerateJobsData(BaseModel):
uuids: List[str] = Field(..., description="str LIST")
uuids: list[str] = Field(..., description="str LIST")
subscription_key: str = Field(..., description="subscription key")
class Rodin3DGenerateResponse(BaseModel):
message: Optional[str] = Field(None, description="Return message.")
prompt: Optional[str] = Field(None, description="Generated Prompt from image.")
submit_time: Optional[str] = Field(None, description="Submit Time")
uuid: Optional[str] = Field(None, description="Task str")
jobs: Optional[GenerateJobsData] = Field(None, description="Details of jobs")
message: str | None = Field(None, description="Return message.")
prompt: str | None = Field(None, description="Generated Prompt from image.")
submit_time: str | None = Field(None, description="Submit Time")
uuid: str | None = Field(None, description="Task str")
jobs: GenerateJobsData | None = Field(None, description="Details of jobs")
class JobStatus(str, Enum):
"""
Status for jobs
"""
Done = "Done"
Failed = "Failed"
Generating = "Generating"
Waiting = "Waiting"
class Rodin3DCheckStatusRequest(BaseModel):
subscription_key: str = Field(..., description="subscription from generate endpoint")
class JobItem(BaseModel):
uuid: str = Field(..., description="uuid")
status: JobStatus = Field(...,description="Status Currently")
status: JobStatus = Field(..., description="Status Currently")
class Rodin3DCheckStatusResponse(BaseModel):
jobs: List[JobItem] = Field(..., description="Job status List")
jobs: list[JobItem] = Field(..., description="Job status List")
class Rodin3DDownloadRequest(BaseModel):
task_uuid: str = Field(..., description="Task str")
class RodinResourceItem(BaseModel):
url: str = Field(..., description="Download Url")
name: str = Field(..., description="File name with ext")
class Rodin3DDownloadResponse(BaseModel):
list: List[RodinResourceItem] = Field(..., description="Source List")
items: list[RodinResourceItem] = Field(..., alias="list", description="Source List")

View File

@ -5,32 +5,37 @@ Rodin API docs: https://developer.hyper3d.ai/
"""
from inspect import cleandoc
import folder_paths as comfy_paths
import os
import logging
import math
import os
from inspect import cleandoc
from io import BytesIO
from typing_extensions import override
from typing import Any
import aiohttp
from PIL import Image
from typing_extensions import override
import folder_paths as comfy_paths
from comfy_api.latest import IO, ComfyExtension, Types
from comfy_api_nodes.apis.rodin import (
Rodin3DGenerateRequest,
Rodin3DGenerateResponse,
JobStatus,
Rodin3DCheckStatusRequest,
Rodin3DCheckStatusResponse,
Rodin3DDownloadRequest,
Rodin3DDownloadResponse,
JobStatus,
Rodin3DGen25Request,
Rodin3DGenerateRequest,
Rodin3DGenerateResponse,
)
from comfy_api_nodes.util import (
sync_op,
poll_op,
ApiEndpoint,
download_url_to_bytesio,
download_url_to_file_3d,
poll_op,
sync_op,
validate_string,
)
from comfy_api.latest import ComfyExtension, IO, Types
COMMON_PARAMETERS = [
IO.Int.Input(
@ -51,40 +56,30 @@ COMMON_PARAMETERS = [
]
def get_quality_mode(poly_count):
polycount = poly_count.split("-")
poly = polycount[1]
count = polycount[0]
if poly == "Triangle":
mesh_mode = "Raw"
elif poly == "Quad":
mesh_mode = "Quad"
else:
mesh_mode = "Quad"
if count == "4K":
quality_override = 4000
elif count == "8K":
quality_override = 8000
elif count == "18K":
quality_override = 18000
elif count == "50K":
quality_override = 50000
elif count == "2K":
quality_override = 2000
elif count == "20K":
quality_override = 20000
elif count == "150K":
quality_override = 150000
elif count == "500K":
quality_override = 500000
else:
quality_override = 18000
return mesh_mode, quality_override
_QUALITY_MESH_OPTIONS: dict[str, tuple[str, int]] = {
"4K-Quad": ("Quad", 4000),
"8K-Quad": ("Quad", 8000),
"18K-Quad": ("Quad", 18000),
"50K-Quad": ("Quad", 50000),
"200K-Quad": ("Quad", 200000),
"2K-Triangle": ("Raw", 2000),
"20K-Triangle": ("Raw", 20000),
"150K-Triangle": ("Raw", 150000),
"200K-Triangle": ("Raw", 200000),
"500K-Triangle": ("Raw", 500000),
"1M-Triangle": ("Raw", 1000000),
}
def tensor_to_filelike(tensor, max_pixels: int = 2048*2048):
def get_quality_mode(poly_count: str) -> tuple[str, int]:
"""Map a polygon-count preset like '18K-Quad' to (mesh_mode, quality_override).
Falls back to ('Quad', 18000) for unknown labels; legacy parity.
"""
return _QUALITY_MESH_OPTIONS.get(poly_count, ("Quad", 18000))
def tensor_to_filelike(tensor, max_pixels: int = 2048 * 2048):
"""
Converts a PyTorch tensor to a file-like object.
@ -96,8 +91,8 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048*2048):
- io.BytesIO: A file-like object containing the image data.
"""
array = tensor.cpu().numpy()
array = (array * 255).astype('uint8')
image = Image.fromarray(array, 'RGB')
array = (array * 255).astype("uint8")
image = Image.fromarray(array, "RGB")
original_width, original_height = image.size
original_pixels = original_width * original_height
@ -112,7 +107,7 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048*2048):
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
img_byte_arr = BytesIO()
image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression
image.save(img_byte_arr, format="PNG") # PNG is used for lossless compression
img_byte_arr.seek(0)
return img_byte_arr
@ -145,11 +140,9 @@ async def create_generate_task(
TAPose=ta_pose,
),
files=[
(
"images",
open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image)
)
for image in images if image is not None
("images", open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image))
for image in images
if image is not None
],
content_type="multipart/form-data",
)
@ -177,6 +170,7 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
return "DONE"
return "Generating"
def extract_progress(response: Rodin3DCheckStatusResponse) -> int | None:
if not response.jobs:
return None
@ -214,7 +208,7 @@ async def download_files(url_list, task_uuid: str) -> tuple[str | None, Types.Fi
model_file_path = None
file_3d = None
for i in url_list.list:
for i in url_list.items:
file_path = os.path.join(save_path, i.name)
if i.name.lower().endswith(".glb"):
model_file_path = os.path.join(result_folder_name, i.name)
@ -489,7 +483,16 @@ class Rodin3D_Gen2(IO.ComfyNode):
IO.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True),
IO.Combo.Input(
"Polygon_count",
options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"],
options=[
"4K-Quad",
"8K-Quad",
"18K-Quad",
"50K-Quad",
"2K-Triangle",
"20K-Triangle",
"150K-Triangle",
"500K-Triangle",
],
default="500K-Triangle",
optional=True,
),
@ -542,6 +545,566 @@ class Rodin3D_Gen2(IO.ComfyNode):
return IO.NodeOutput(model_path, file_3d)
def _rodin_multipart_parser(data: dict[str, Any]) -> aiohttp.FormData:
"""Convert a Rodin request dict to an aiohttp form, fixing bool/list serialization.
Booleans --> "true"/"false". Lists --> one field per element.
"""
form = aiohttp.FormData(default_to_multipart=True)
for key, value in data.items():
if value is None:
continue
if isinstance(value, bool):
form.add_field(key, "true" if value else "false")
elif isinstance(value, list):
for item in value:
form.add_field(key, str(item))
elif isinstance(value, (bytes, bytearray)):
form.add_field(key, value)
else:
form.add_field(key, str(value))
return form
async def _create_gen25_task(
cls: type[IO.ComfyNode],
request: Rodin3DGen25Request,
images: list | None,
) -> tuple[str, str]:
"""Submit a Gen-2.5 generate job; returns (task_uuid, subscription_key)."""
if images is not None and len(images) > 5:
raise ValueError("Rodin Gen-2.5 supports at most 5 input images.")
files = None
if images:
files = [
(
"images",
open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image),
)
for image in images
if image is not None
]
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/rodin/api/v2/rodin", method="POST"),
response_model=Rodin3DGenerateResponse,
data=request,
files=files,
content_type="multipart/form-data",
multipart_parser=_rodin_multipart_parser,
)
if not response.uuid or not response.jobs or not response.jobs.subscription_key:
raise RuntimeError(f"Rodin Gen-2.5 submit failed: message={response.message!r}")
return response.uuid, response.jobs.subscription_key
_PREVIEWABLE_3D_EXTS = {".glb", ".obj", ".fbx", ".stl", ".gltf"}
async def _download_gen25_files(
download_list: Rodin3DDownloadResponse,
task_uuid: str,
geometry_file_format: str,
) -> Types.File3D | None:
"""Download every file in the list; return the File3D matching the chosen format."""
folder_name = f"Rodin3D_Gen25_{task_uuid}"
save_dir = os.path.join(comfy_paths.get_output_directory(), folder_name)
os.makedirs(save_dir, exist_ok=True)
target_ext = f".{geometry_file_format.lower().lstrip('.')}"
file_3d: Types.File3D | None = None
for item in download_list.items:
file_path = os.path.join(save_dir, item.name)
ext = os.path.splitext(item.name.lower())[1]
# Prefer the file matching the user's chosen format; fall back below.
if file_3d is None and ext == target_ext and ext in _PREVIEWABLE_3D_EXTS:
file_3d = await download_url_to_file_3d(item.url, target_ext.lstrip("."))
with open(file_path, "wb") as f:
f.write(file_3d.get_bytes())
continue
await download_url_to_bytesio(item.url, file_path)
# If the chosen format wasn't found, surface any model file we did get.
if file_3d is None:
for item in download_list.items:
ext = os.path.splitext(item.name.lower())[1]
if ext in _PREVIEWABLE_3D_EXTS:
file_3d = await download_url_to_file_3d(item.url, ext.lstrip("."))
break
return file_3d
_MODE_REGULAR = "Regular"
_MODE_FAST = "Fast"
_MODE_EXTREME_HIGH = "Extreme-High"
_REGULAR_POLY_OPTIONS = [
"Default",
"4K-Quad",
"8K-Quad",
"18K-Quad",
"50K-Quad",
"2K-Triangle",
"20K-Triangle",
"150K-Triangle",
"500K-Triangle",
"1M-Triangle",
]
_TEXTURE_MODE_OPTIONS = ["Default", "legacy", "extreme-low", "low", "medium", "high"]
_GEOMETRY_FORMAT_OPTIONS = ["glb", "fbx", "obj", "stl"]
_MATERIAL_OPTIONS = ["PBR", "Shaded", "All", "None"]
def _build_mode_input(name: str = "mode") -> IO.DynamicCombo.Input:
return IO.DynamicCombo.Input(
name,
options=[
IO.DynamicCombo.Option(
_MODE_REGULAR,
[
IO.Combo.Input(
"tier",
options=["Gen-2.5-Low", "Gen-2.5-Medium", "Gen-2.5-High"],
default="Gen-2.5-High",
tooltip="Quality tier. Higher tiers produce higher-fidelity geometry.",
),
IO.Combo.Input(
"polygon_count",
options=_REGULAR_POLY_OPTIONS,
default="Default",
tooltip="Preset face count. 'Default' uses the server's default for the selected tier.",
),
IO.Boolean.Input(
"creative",
default=False,
tooltip="Creative mode (Medium/High only). Enhances generative robustness.",
),
],
),
IO.DynamicCombo.Option(
_MODE_FAST,
[
IO.Combo.Input(
"tier",
options=[
"Gen-2.5-Extreme-Low",
"Gen-2.5-Low",
"Gen-2.5-Medium",
"Gen-2.5-High",
],
default="Gen-2.5-Low",
),
IO.Int.Input(
"mesh_faces",
default=20000,
min=1000,
max=20000,
display_mode=IO.NumberDisplay.number,
tooltip="Mesh face count (1K-20K in Fast mode).",
),
],
),
IO.DynamicCombo.Option(
_MODE_EXTREME_HIGH,
[
IO.Combo.Input("mesh_mode", options=["Raw", "Quad"], default="Raw"),
IO.Int.Input(
"mesh_faces",
default=1000000,
min=20000,
max=2000000,
display_mode=IO.NumberDisplay.number,
tooltip=(
"Mesh face count. Raw mode: 20K-2M. "
"Quad mode: keep under 200K (upstream may reject higher values)."
),
),
IO.Boolean.Input(
"is_micro",
default=False,
tooltip="Enable micro detail (Extreme-High only).",
),
IO.Boolean.Input(
"creative",
default=False,
tooltip="Creative mode. Enhances generative robustness.",
),
],
),
],
tooltip=(
"Generation mode. Regular = balanced. Fast = 1K-20K faces for rapid prototyping. "
"Extreme-High = 20K-2M faces with optional micro details."
),
)
def _build_common_inputs(*, include_image_only: bool) -> list:
inputs: list = [
IO.Combo.Input("material", options=_MATERIAL_OPTIONS, default="Shaded"),
IO.Combo.Input("geometry_file_format", options=_GEOMETRY_FORMAT_OPTIONS, default="glb"),
IO.Combo.Input(
"texture_mode",
options=_TEXTURE_MODE_OPTIONS,
default="Default",
optional=True,
tooltip="Texture quality preset. 'Default' uses the server's default for the selected tier.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=65535,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
optional=True,
),
IO.Boolean.Input(
"TAPose", default=False, optional=True, advanced=True, tooltip="T/A pose for human-like models."
),
IO.Boolean.Input(
"hd_texture", default=False, optional=True, advanced=True, tooltip="High-quality texture enhancement."
),
IO.Boolean.Input(
"texture_delight",
default=False,
optional=True,
advanced=True,
tooltip="Remove baked lighting from textures.",
),
]
if include_image_only:
inputs.append(
IO.Boolean.Input(
"use_original_alpha",
default=False,
optional=True,
advanced=True,
tooltip="Preserve image transparency.",
)
)
inputs.extend(
[
IO.Boolean.Input(
"addon_highpack",
default=False,
optional=True,
advanced=True,
tooltip="HighPack addon: 4K textures and ~16x faces in Quad mode.",
),
IO.Int.Input(
"bbox_width",
default=0,
min=0,
max=300,
display_mode=IO.NumberDisplay.number,
optional=True,
advanced=True,
tooltip="Bounding-box width (Y axis). Set to 0 with the others to skip bbox.",
),
IO.Int.Input(
"bbox_height",
default=0,
min=0,
max=300,
display_mode=IO.NumberDisplay.number,
optional=True,
advanced=True,
tooltip="Bounding-box height (Z axis).",
),
IO.Int.Input(
"bbox_length",
default=0,
min=0,
max=300,
display_mode=IO.NumberDisplay.number,
optional=True,
advanced=True,
tooltip="Bounding-box length (X axis).",
),
IO.Int.Input(
"height_cm",
default=0,
min=0,
max=10000,
display_mode=IO.NumberDisplay.number,
optional=True,
advanced=True,
tooltip="Approximate model height in centimeters (0 to skip).",
),
]
)
return inputs
_PRICE_EXPR = """
(
$baseCredits := widgets.mode = "extreme-high" ? 1.0 : 0.5;
$addonCredits := widgets.addon_highpack ? 1.0 : 0.0;
$total := ($baseCredits * 1.5) + ($addonCredits * 0.8);
{"type":"usd","usd": $total}
)
"""
def _resolve_mode_params(mode_input: dict) -> dict:
"""Translate the DynamicCombo `mode` payload into Gen-2.5 request fields.
Returns a dict with: tier, quality_override, mesh_mode, geometry_instruct_mode, is_micro.
Missing keys mean "do not send" (so we don't override server defaults).
"""
selected = mode_input["mode"]
out: dict = {}
if selected == _MODE_REGULAR:
out["tier"] = mode_input["tier"]
polygon = mode_input.get("polygon_count", "Default")
if polygon != "Default":
mesh_mode, faces = get_quality_mode(polygon)
out["mesh_mode"] = mesh_mode
out["quality_override"] = faces
if mode_input.get("creative"):
out["geometry_instruct_mode"] = "creative"
elif selected == _MODE_FAST:
out["tier"] = mode_input["tier"]
out["mesh_mode"] = "Raw"
out["quality_override"] = int(mode_input["mesh_faces"])
elif selected == _MODE_EXTREME_HIGH:
out["tier"] = "Gen-2.5-Extreme-High"
out["mesh_mode"] = mode_input["mesh_mode"]
out["quality_override"] = int(mode_input["mesh_faces"])
if mode_input.get("is_micro"):
out["is_micro"] = True
if mode_input.get("creative"):
out["geometry_instruct_mode"] = "creative"
return out
def _build_request(
*,
mode_input: dict,
material: str,
geometry_file_format: str,
texture_mode: str,
seed: int,
TAPose: bool,
hd_texture: bool,
texture_delight: bool,
addon_highpack: bool,
bbox_width: int,
bbox_height: int,
bbox_length: int,
height_cm: int,
prompt: str | None = None,
use_original_alpha: bool = False,
) -> Rodin3DGen25Request:
mode_params = _resolve_mode_params(mode_input)
bbox = None
if bbox_width and bbox_height and bbox_length:
bbox = [bbox_width, bbox_height, bbox_length]
return Rodin3DGen25Request(
tier=mode_params["tier"],
prompt=prompt or None,
seed=seed,
material=material,
geometry_file_format=geometry_file_format,
texture_mode=None if texture_mode == "Default" else texture_mode,
mesh_mode=mode_params.get("mesh_mode"),
quality_override=mode_params.get("quality_override"),
geometry_instruct_mode=mode_params.get("geometry_instruct_mode"),
bbox_condition=bbox,
height=height_cm or None,
TAPose=TAPose or None,
hd_texture=hd_texture or None,
texture_delight=texture_delight or None,
is_micro=mode_params.get("is_micro"),
use_original_alpha=use_original_alpha or None,
addons=["HighPack"] if addon_highpack else None,
)
class Rodin3D_Gen25_Image(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Rodin3D_Gen25_Image",
display_name="Rodin 3D Gen-2.5 - Image to 3D",
category="api node/3d/Rodin",
description=(
"Generate a 3D model from 1-5 reference images via Rodin Gen-2.5. "
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."
),
inputs=[
IO.Autogrow.Input(
"images",
template=IO.Autogrow.TemplatePrefix(IO.Image.Input("image"), prefix="image", min=1, max=5),
tooltip="1-5 images. The first image is used for materials when multi-view.",
),
_build_mode_input(),
*_build_common_inputs(include_image_only=True),
],
outputs=[IO.File3DAny.Output(display_name="model_file")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["mode", "addon_highpack"]),
expr=_PRICE_EXPR,
),
)
@classmethod
async def execute(
cls,
images: IO.Autogrow.Type,
mode: dict,
material: str,
geometry_file_format: str,
texture_mode: str,
seed: int,
TAPose: bool,
hd_texture: bool,
texture_delight: bool,
use_original_alpha: bool,
addon_highpack: bool,
bbox_width: int,
bbox_height: int,
bbox_length: int,
height_cm: int,
) -> IO.NodeOutput:
image_tensors = [img for img in images.values() if img is not None]
if not image_tensors:
raise ValueError("Rodin Gen-2.5 Image-to-3D requires at least one image.")
# Flatten multi-image tensors into individual frames; the API accepts each as a separate part.
flat_images: list = []
for tensor in image_tensors:
if hasattr(tensor, "shape") and len(tensor.shape) == 4:
for i in range(tensor.shape[0]):
flat_images.append(tensor[i])
else:
flat_images.append(tensor)
if len(flat_images) > 5:
raise ValueError(f"Rodin Gen-2.5 accepts at most 5 images; received {len(flat_images)}.")
request = _build_request(
mode_input=mode,
material=material,
geometry_file_format=geometry_file_format,
texture_mode=texture_mode,
seed=seed,
TAPose=TAPose,
hd_texture=hd_texture,
texture_delight=texture_delight,
addon_highpack=addon_highpack,
bbox_width=bbox_width,
bbox_height=bbox_height,
bbox_length=bbox_length,
height_cm=height_cm,
prompt=None,
use_original_alpha=use_original_alpha,
)
task_uuid, subscription_key = await _create_gen25_task(cls, request, flat_images)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
file_3d = await _download_gen25_files(download_list, task_uuid, geometry_file_format)
return IO.NodeOutput(file_3d)
class Rodin3D_Gen25_Text(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Rodin3D_Gen25_Text",
display_name="Rodin 3D Gen-2.5 - Text to 3D",
category="api node/3d/Rodin",
description=(
"Generate a 3D model from a text prompt via Rodin Gen-2.5. "
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."
),
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for the 3D model.",
),
_build_mode_input(),
*_build_common_inputs(include_image_only=False),
],
outputs=[IO.File3DAny.Output(display_name="model_file")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["mode", "addon_highpack"]),
expr=_PRICE_EXPR,
),
)
@classmethod
async def execute(
cls,
prompt: str,
mode: dict,
material: str,
geometry_file_format: str,
texture_mode: str,
seed: int,
TAPose: bool,
hd_texture: bool,
texture_delight: bool,
addon_highpack: bool,
bbox_width: int,
bbox_height: int,
bbox_length: int,
height_cm: int,
) -> IO.NodeOutput:
validate_string(prompt, field_name="prompt", min_length=1, max_length=2500)
request = _build_request(
mode_input=mode,
material=material,
geometry_file_format=geometry_file_format,
texture_mode=texture_mode,
seed=seed,
TAPose=TAPose,
hd_texture=hd_texture,
texture_delight=texture_delight,
addon_highpack=addon_highpack,
bbox_width=bbox_width,
bbox_height=bbox_height,
bbox_length=bbox_length,
height_cm=height_cm,
prompt=prompt,
)
task_uuid, subscription_key = await _create_gen25_task(cls, request, images=None)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
file_3d = await _download_gen25_files(download_list, task_uuid, geometry_file_format)
return IO.NodeOutput(file_3d)
class Rodin3DExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -551,6 +1114,8 @@ class Rodin3DExtension(ComfyExtension):
Rodin3D_Smooth,
Rodin3D_Sketch,
Rodin3D_Gen2,
Rodin3D_Gen25_Image,
Rodin3D_Gen25_Text,
]

View File

@ -1,446 +0,0 @@
"""ComfyUI nodes for Depth Anything 3.
Adds these nodes:
* ``LoadDepthAnything3`` -- load a DA3 ``.safetensors`` file from the
``models/geometry_estimation/`` folder.
* ``DepthAnything3`` -- unified depth estimation node supporting both mono and
multi-view modes via a DynamicCombo selector. Returns a DA3_GEOMETRY dict of
raw tensors (depth, sky, confidence, camera). Feed into ``DepthAnything3Render``
to produce display images, or directly into ``MoGeRender`` for depth / mask views.
* ``DepthAnything3Render`` -- post-processes a DA3_GEOMETRY dict: applies optional
sky clipping, normalises depth and confidence, and returns display images.
Model capability matrix
-----------------------
Variant head_type has_sky has_conf cam_dec
DA3-Small dualdpt False True yes
DA3-Base dualdpt False True yes
DA3-Mono-Large dpt True False no
DA3-Metric-Large dpt True False no (raw output is metres)
The node raises a ``ValueError`` at execution time when the selected
parameters conflict with the loaded model's capabilities (e.g.
``apply_sky_clip=True`` on a model with no sky head).
"""
from __future__ import annotations
from typing_extensions import override
import torch
import comfy.model_management as mm
import comfy.sd
import folder_paths
from comfy.ldm.colormap import turbo as _turbo
from comfy.ldm.depth_anything_3 import preprocess as da3_preprocess
from comfy_api.latest import ComfyExtension, io
DA3ModelType = io.Custom("DA3_MODEL")
DA3Geometry = io.Custom("DA3_GEOMETRY")
# DA3_GEOMETRY is a dict with these optional keys (absent when the upstream model didn't produce them):
#
# Per-frame tensors — B = batch size in mono mode; B = S (number of views) in multi-view mode.
# "depth": torch.Tensor (B, H, W) -- raw model depth (always present; matches MoGe convention)
# "image": torch.Tensor (B, H, W, 3) -- source image in [0, 1], CPU (always present)
# "mode": str -- "mono" or "multiview" (always present)
# "sky": torch.Tensor (B, H, W) -- sky probability in [0, 1] (Mono/Metric variants only)
# "confidence": torch.Tensor (B, H, W) -- raw model confidence output (Small/Base variants only)
#
# Multi-view only — S = number of views; the leading 1 is the scene dimension from the model.
# "extrinsics": torch.Tensor (1, S, 4, 4) -- world-to-camera matrices
# "intrinsics": torch.Tensor (1, S, 3, 3) -- pixel-space intrinsics
class LoadDepthAnything3Model(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoadDepthAnything3Model",
display_name="Load Depth Anything 3",
category="loaders",
inputs=[
io.Combo.Input(
"model_name",
options=folder_paths.get_filename_list("geometry_estimation"),
),
io.Combo.Input(
"weight_dtype",
options=["default", "fp16", "bf16", "fp32"],
default="default",
),
],
outputs=[DA3ModelType.Output()],
)
@classmethod
def execute(cls, model_name, weight_dtype) -> io.NodeOutput:
model_options = {}
if weight_dtype == "fp16":
model_options["dtype"] = torch.float16
elif weight_dtype == "bf16":
model_options["dtype"] = torch.bfloat16
elif weight_dtype == "fp32":
model_options["dtype"] = torch.float32
path = folder_paths.get_full_path_or_raise("geometry_estimation", model_name)
model = comfy.sd.load_diffusion_model(path, model_options=model_options)
return io.NodeOutput(model)
def _run_da3(model_patcher, image: torch.Tensor, process_res: int,
method: str = "upper_bound_resize"):
"""Run DA3 on ``(B,H,W,3)`` IMAGE; returns depth/conf/sky at original resolution (or None)."""
assert image.ndim == 4 and image.shape[-1] == 3, \
f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
B, H, W, _ = image.shape
mm.load_model_gpu(model_patcher)
diffusion = model_patcher.model.diffusion_model
device = mm.get_torch_device()
dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32
depths, confs, skies = [], [], []
for i in range(B):
single = image[i:i + 1].to(device)
x = da3_preprocess.preprocess_image(single, process_res=process_res, method=method)
x = x.to(dtype=dtype)
with torch.no_grad():
out = diffusion(x)
depth_lr = out["depth"]
depth_full = torch.nn.functional.interpolate(
depth_lr.unsqueeze(1).float(), size=(H, W),
mode="bilinear", align_corners=False,
).squeeze(1).cpu()
depths.append(depth_full)
if "depth_conf" in out:
conf_full = torch.nn.functional.interpolate(
out["depth_conf"].unsqueeze(1).float(), size=(H, W),
mode="bilinear", align_corners=False,
).squeeze(1).cpu()
confs.append(conf_full)
if "sky" in out:
sky_full = torch.nn.functional.interpolate(
out["sky"].unsqueeze(1).float(), size=(H, W),
mode="bilinear", align_corners=False,
).squeeze(1).cpu()
skies.append(sky_full)
depth = torch.cat(depths, dim=0)
confidence = torch.cat(confs, dim=0) if confs else None
sky = torch.cat(skies, dim=0) if skies else None
return depth, confidence, sky
class DepthAnything3Inference(io.ComfyNode):
"""Raw Depth Anything 3 inference node.
Outputs a DA3_GEOMETRY dict of raw tensors. All display normalization
(sky clipping, depth scaling, confidence normalisation) is handled by
the companion ``DepthAnything3Render`` node.
Mono mode: each batch element is processed independently.
Multi-view mode: all frames share a single forward pass with cross-view
attention; adds ``extrinsics`` and ``intrinsics`` to the geometry dict.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="DepthAnything3Inference",
search_aliases=["depth", "geometry", "da3", "depth anything", "monocular", "pointmap", "sky", "3d", "metric depth", "disparity"],
display_name="Run Depth Anything 3",
category="image/geometry_estimation",
description="Run Depth Anything 3 on an image or image batch. In multi-view mode each frame is treated as a separate view of the same scene.",
inputs=[
DA3ModelType.Input("da3_model"),
io.Image.Input("image",
tooltip="Single image or image batch. "
"In multi-view mode each frame is treated as "
"a separate view of the same scene."),
io.Int.Input("process_res", default=504, min=140, max=2520, step=14,
tooltip="Resolution the model runs at (longest side, multiple of 14). "
"Lower = faster / less VRAM; higher = more detail. "
"Output is upsampled back to the original size."),
io.Combo.Input("resize_method",
options=["upper_bound_resize", "lower_bound_resize"],
default="upper_bound_resize",
tooltip="- upper_bound_resize: scale so the longest side = process_res (caps memory, default).\n"
"- lower_bound_resize: scale so the shortest side = process_res (preserves more detail on tall/wide images, uses more memory)."),
io.DynamicCombo.Input("mode",
tooltip="- mono: single image or independent batch — use with any model.\n"
"- multiview: all frames processed together with cross-view attention for geometric consistency; also outputs camera pose — requires DA3-Small or DA3-Base.",
options=[
io.DynamicCombo.Option("mono", []),
io.DynamicCombo.Option("multiview", [
io.Combo.Input("ref_view_strategy",
options=["saddle_balanced", "saddle_sim_range",
"first", "middle"],
default="saddle_balanced",
tooltip="Which view to use as the geometric anchor (only applied when S >= 3 and no extrinsics are provided).\n"
"- saddle_balanced: picks the view whose CLS-token features are closest to the median across similarity, norm and variance — best general choice.\n"
"- saddle_sim_range: picks the view with the widest similarity spread to other views — favours the most distinct viewpoint.\n"
"- first / middle: deterministic positional fallbacks."),
io.Combo.Input("pose_method",
options=["cam_dec", "ray_pose"],
default="cam_dec",
tooltip="- cam_dec: small MLP on the final camera token (DA3-Small/Base).\n"
"- ray_pose: RANSAC over the DualDPT ray output (DA3-Small/Base only)."),
]),
]),
],
outputs=[
DA3Geometry.Output("geometry",
tooltip="DA3_GEOMETRY dict of raw tensors.\n"
"- Always: 'depth' (B,H,W), 'image', 'mode'.\n"
"- Optional: 'sky' + 'mask' (Mono/Metric), 'confidence' raw (Small/Base), 'extrinsics' + 'intrinsics' (multi-view)."),
],
)
@classmethod
def execute(cls, da3_model, image, process_res, resize_method, mode) -> io.NodeOutput:
mode_val = mode["mode"] # "mono" or "multiview"
if mode_val == "mono":
return cls._execute_mono(da3_model, image, process_res, resize_method)
# Capability checks for multi-view pose.
diffusion = da3_model.model.diffusion_model
pose_method = mode["pose_method"]
ref_view_strategy = mode["ref_view_strategy"]
if pose_method == "cam_dec" and diffusion.cam_dec is None:
raise ValueError(
"pose_method='cam_dec' requires a camera decoder, but the loaded "
"model does not have one. Load a model with a camera decoder "
"(e.g. DA3-Small or DA3-Base), or set pose_method='ray_pose'."
)
if pose_method == "ray_pose" and diffusion.head_type != "dualdpt":
raise ValueError(
"pose_method='ray_pose' requires a DualDPT head, but the loaded "
"model has a DPT head. Load a model with a DualDPT head "
"(e.g. DA3-Small or DA3-Base), or set pose_method='cam_dec'."
)
return cls._execute_multiview(
da3_model, image, process_res, resize_method,
ref_view_strategy, pose_method,
)
@classmethod
def _execute_mono(cls, model, image, process_res, resize_method) -> io.NodeOutput:
depth, confidence, sky = _run_da3(model, image, process_res, method=resize_method)
geometry: dict = {
"depth": depth.contiguous(),
"image": image[..., :3].cpu(),
"mode": "mono",
}
if sky is not None:
geometry["sky"] = sky.contiguous()
if confidence is not None:
geometry["confidence"] = confidence.contiguous()
return io.NodeOutput(geometry)
@classmethod
def _execute_multiview(cls, model, image, process_res, resize_method,
ref_view_strategy, pose_method) -> io.NodeOutput:
assert image.ndim == 4 and image.shape[-1] == 3, \
f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
S, H, W, _ = image.shape
mm.load_model_gpu(model)
diffusion = model.model.diffusion_model
device = mm.get_torch_device()
dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32
# All views in a single forward pass: (1, S, 3, H', W').
x = image.to(device)
x = da3_preprocess.preprocess_image(x, process_res=process_res, method=resize_method)
x = x.to(dtype=dtype).unsqueeze(0)
use_ray_pose = (pose_method == "ray_pose")
with torch.no_grad():
out = diffusion(x, use_ray_pose=use_ray_pose,
ref_view_strategy=ref_view_strategy)
depth = torch.nn.functional.interpolate(
out["depth"].float().unsqueeze(1), size=(H, W),
mode="bilinear", align_corners=False,
).squeeze(1).cpu()
sky = None
if "sky" in out:
sky = torch.nn.functional.interpolate(
out["sky"].unsqueeze(1).float(), size=(H, W),
mode="bilinear", align_corners=False,
).squeeze(1).cpu()
if "extrinsics" in out and "intrinsics" in out:
extrinsics = out["extrinsics"].float().cpu()
intrinsics = out["intrinsics"].float().cpu()
else:
extrinsics = torch.eye(4)[None, None].expand(1, S, 4, 4).clone()
intrinsics = torch.eye(3)[None, None].expand(1, S, 3, 3).clone()
geometry: dict = {
"depth": depth.contiguous(),
"image": image[..., :3].cpu(),
"mode": "multiview",
"extrinsics": extrinsics.contiguous(),
"intrinsics": intrinsics.contiguous(),
}
if sky is not None:
geometry["sky"] = sky.contiguous()
if "depth_conf" in out:
conf = torch.nn.functional.interpolate(
out["depth_conf"].unsqueeze(1).float(), size=(H, W),
mode="bilinear", align_corners=False,
).squeeze(1).cpu()
geometry["confidence"] = conf.contiguous()
return io.NodeOutput(geometry)
class DepthAnything3Render(io.ComfyNode):
"""Visualise a DA3_GEOMETRY packet as a single image.
Mirrors the MoGeRender interface: one ``output`` selector, one IMAGE out.
Use multiple nodes in parallel to get depth + sky + confidence simultaneously.
"""
_DEPTH_RENDER_INPUTS = [
io.Combo.Input("normalization",
options=["v2_style", "min_max", "raw"],
default="v2_style",
tooltip="- v2_style: mean/std normalisation for perceptually balanced results (default).\n"
"- min_max: stretches the full depth range to [0, 1] for maximum contrast.\n"
"- raw: no scaling — preserves metric units for DA3-Metric-Large."),
io.Boolean.Input("apply_sky_clip", default=False,
tooltip="Clip sky-region depth to the 99th percentile of foreground depth before "
"normalisation. Requires a 'sky' tensor in the geometry "
"(DA3-Mono-Large or DA3-Metric-Large); raises an error otherwise."),
]
@classmethod
def define_schema(cls):
return io.Schema(
node_id="DepthAnything3Render",
display_name="Depth Anything 3 Render",
category="image/geometry_estimation",
description="Visualise a DA3_GEOMETRY packet. Drop multiple nodes to get different views simultaneously.",
inputs=[
DA3Geometry.Input("geometry"),
io.DynamicCombo.Input("output",
tooltip="- depth: normalised greyscale depth image.\n"
"- depth_colored: depth mapped through the Turbo colormap.\n"
"- sky_mask: sky probability in [0, 1] (Mono/Metric variants only).\n"
"- confidence: normalised depth confidence (Small/Base variants only).",
options=[
io.DynamicCombo.Option("depth", cls._DEPTH_RENDER_INPUTS),
io.DynamicCombo.Option("depth_colored", cls._DEPTH_RENDER_INPUTS),
io.DynamicCombo.Option("sky_mask", []),
io.DynamicCombo.Option("confidence", []),
]),
],
outputs=[io.Image.Output()],
)
@classmethod
def execute(cls, geometry, output) -> io.NodeOutput:
output_val = output["output"]
if output_val in ("depth", "depth_colored"):
normalization = output["normalization"]
apply_sky_clip = output["apply_sky_clip"]
if apply_sky_clip and "sky" not in geometry:
raise ValueError(
"apply_sky_clip=True requires a sky tensor in the geometry, but none is present. "
"Run with DA3-Mono-Large or DA3-Metric-Large, or set apply_sky_clip=False."
)
depth = geometry["depth"]
sky = geometry.get("sky")
if apply_sky_clip and sky is not None:
depth = torch.stack([
da3_preprocess.apply_sky_aware_clip(depth[i], sky[i])
for i in range(depth.shape[0])
], dim=0)
grey = cls._depth_to_image(depth, sky, normalization) # (B,H,W,3) greyscale
result = _turbo(grey[..., 0]) if output_val == "depth_colored" else grey
elif output_val == "sky_mask":
if "sky" not in geometry:
raise ValueError("geometry has no sky output; run with DA3-Mono-Large or DA3-Metric-Large.")
sky = geometry["sky"]
result = sky.unsqueeze(-1).expand(*sky.shape, 3).contiguous()
elif output_val == "confidence":
if "confidence" not in geometry:
raise ValueError("geometry has no confidence output; run with DA3-Small or DA3-Base.")
result = cls._normalize_confidence(geometry["confidence"])
result = result.unsqueeze(-1).expand(*result.shape, 3).contiguous()
else:
raise ValueError(f"Unknown output mode: {output_val}")
return io.NodeOutput(result.float())
@staticmethod
def _depth_to_image(depth: torch.Tensor, sky_for_norm: torch.Tensor | None,
normalization: str) -> torch.Tensor:
"""Normalise depth and pack as an (B,H,W,3) image tensor."""
N = depth.shape[0]
if normalization == "v2_style":
norm = torch.stack([
da3_preprocess.normalize_depth_v2_style(
depth[i], sky_for_norm[i] if sky_for_norm is not None else None)
for i in range(N)
], dim=0)
elif normalization == "min_max":
norm = da3_preprocess.normalize_depth_min_max(depth)
else:
norm = depth
out = norm.unsqueeze(-1).repeat(1, 1, 1, 3)
if normalization != "raw":
out = out.clamp(0.0, 1.0)
return out.contiguous()
@staticmethod
def _normalize_confidence(conf: torch.Tensor) -> torch.Tensor:
"""Map raw confidence (expp1 activaton, range [1, ∞)) to [0, 1] per image.
The model uses ``exp(x) + 1`` so every pixel is guaranteed to be ≥ 1.
Min-max normalization per image preserves the spatial pattern (high
confidence = brighter) while producing a valid mask in [0, 1].
"""
B = conf.shape[0]
out = []
for i in range(B):
c = conf[i]
c_min = c.min()
c_max = c.max()
if c_max > c_min:
out.append((c - c_min) / (c_max - c_min))
else:
out.append(torch.ones_like(c))
return torch.stack(out, dim=0)
class DepthAnything3Extension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
LoadDepthAnything3Model,
DepthAnything3Inference,
DepthAnything3Render,
]
async def comfy_entrypoint() -> DepthAnything3Extension:
return DepthAnything3Extension()

View File

@ -8,6 +8,82 @@ from comfy_api.latest import _io
MISSING = object()
class NotNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ComfyNotNode",
display_name="Not",
category="utils/logic",
description="Logical NOT operation. Returns true if the value is falsy. Uses Python's rules for truthiness.",
search_aliases=["invert", "toggle", "negate", "flip boolean"],
inputs=[
io.AnyType.Input("value"),
],
outputs=[
io.Boolean.Output(),
],
)
@classmethod
def execute(cls, value) -> io.NodeOutput:
return io.NodeOutput(not value)
class AndNode(io.ComfyNode):
@classmethod
def define_schema(cls):
template = io.Autogrow.TemplatePrefix(
input=io.AnyType.Input("value"),
prefix="value",
min=1,
)
return io.Schema(
node_id="ComfyAndNode",
display_name="And",
category="utils/logic",
description="Logical AND operation. Returns true if all of the values are truthy. Uses Python's rules for truthiness.",
search_aliases=["all", "every"],
inputs=[
io.Autogrow.Input("values", template=template),
],
outputs=[
io.Boolean.Output(),
],
)
@classmethod
def execute(cls, values: io.Autogrow.Type) -> io.NodeOutput:
return io.NodeOutput(all(values.values()))
class OrNode(io.ComfyNode):
@classmethod
def define_schema(cls):
template = io.Autogrow.TemplatePrefix(
input=io.AnyType.Input("value"),
prefix="value",
min=1,
)
return io.Schema(
node_id="ComfyOrNode",
display_name="Or",
category="utils/logic",
description="Logical OR operation. Returns true if any of the values are truthy. Uses Python's rules for truthiness.",
search_aliases=["any", "some"],
inputs=[
io.Autogrow.Input("values", template=template),
],
outputs=[
io.Boolean.Output(),
],
)
@classmethod
def execute(cls, values: io.Autogrow.Type) -> io.NodeOutput:
return io.NodeOutput(any(values.values()))
class SwitchNode(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -15,7 +91,7 @@ class SwitchNode(io.ComfyNode):
return io.Schema(
node_id="ComfySwitchNode",
display_name="Switch",
category="logic",
category="utils/logic",
is_experimental=True,
inputs=[
io.Boolean.Input("switch"),
@ -46,7 +122,7 @@ class SoftSwitchNode(io.ComfyNode):
return io.Schema(
node_id="ComfySoftSwitchNode",
display_name="Soft Switch",
category="logic",
category="utils/logic",
is_experimental=True,
inputs=[
io.Boolean.Input("switch"),
@ -136,7 +212,7 @@ class DCTestNode(io.ComfyNode):
return io.Schema(
node_id="DCTestNode",
display_name="DCTest",
category="logic",
category="utils/logic",
is_output_node=True,
inputs=[io.DynamicCombo.Input("combo", options=[
io.DynamicCombo.Option("option1", [io.String.Input("string")]),
@ -174,7 +250,7 @@ class AutogrowNamesTestNode(io.ComfyNode):
return io.Schema(
node_id="AutogrowNamesTestNode",
display_name="AutogrowNamesTest",
category="logic",
category="utils/logic",
inputs=[
_io.Autogrow.Input("autogrow", template=template)
],
@ -194,7 +270,7 @@ class AutogrowPrefixTestNode(io.ComfyNode):
return io.Schema(
node_id="AutogrowPrefixTestNode",
display_name="AutogrowPrefixTest",
category="logic",
category="utils/logic",
inputs=[
_io.Autogrow.Input("autogrow", template=template)
],
@ -213,7 +289,7 @@ class ComboOutputTestNode(io.ComfyNode):
return io.Schema(
node_id="ComboOptionTestNode",
display_name="ComboOptionTest",
category="logic",
category="utils/logic",
inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]),
io.Combo.Input("combo2", options=["option4", "option5", "option6"])],
outputs=[io.Combo.Output(), io.Combo.Output()],
@ -230,7 +306,7 @@ class ConvertStringToComboNode(io.ComfyNode):
node_id="ConvertStringToComboNode",
search_aliases=["string to dropdown", "text to combo"],
display_name="Convert String to Combo",
category="logic",
category="utils/logic",
inputs=[io.String.Input("string")],
outputs=[io.Combo.Output()],
)
@ -246,7 +322,7 @@ class InvertBooleanNode(io.ComfyNode):
node_id="InvertBooleanNode",
search_aliases=["not", "toggle", "negate", "flip boolean"],
display_name="Invert Boolean",
category="logic",
category="utils/logic",
inputs=[io.Boolean.Input("boolean")],
outputs=[io.Boolean.Output()],
)
@ -261,6 +337,9 @@ class LogicExtension(ComfyExtension):
return [
SwitchNode,
CustomComboNode,
NotNode,
AndNode,
OrNode,
# SoftSwitchNode,
# ConvertStringToComboNode,
# DCTestNode,

View File

@ -70,7 +70,7 @@ class MathExpressionNode(io.ComfyNode):
return io.Schema(
node_id="ComfyMathExpression",
display_name="Math Expression",
category="logic",
category="utils",
search_aliases=[
"expression", "formula", "calculate", "calculator",
"eval", "math",

View File

@ -9,7 +9,6 @@ import folder_paths
from comfy_api.latest import ComfyExtension, Types, io
from typing_extensions import override
from comfy.ldm.colormap import turbo as _turbo
from comfy.ldm.moge.model import MoGeModel
from comfy.ldm.moge.geometry import triangulate_grid_mesh
from comfy.ldm.moge.panorama import get_panorama_cameras, split_panorama_image, merge_panorama_depth, spherical_uv_to_directions, _uv_grid
@ -29,6 +28,19 @@ MoGeGeometry = io.Custom("MOGE_GEOMETRY")
# "image": torch.Tensor (B, H, W, 3) in [0, 1], CPU (always present)
def _turbo(x: torch.Tensor) -> torch.Tensor:
"""Anton Mikhailov polynomial approximation of the turbo colormap."""
x = x.clamp(0.0, 1.0)
x2 = x * x
x3 = x2 * x
x4 = x2 * x2
x5 = x4 * x
r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5
g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5
b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5
return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0)
def _normals_from_points(points: torch.Tensor) -> torch.Tensor:
"""Camera-space surface normals from a (B, H, W, 3) point map (v1 fallback)."""
finite = torch.isfinite(points).all(dim=-1)

View File

@ -14,7 +14,7 @@ class CreateList(io.ComfyNode):
return io.Schema(
node_id="CreateList",
display_name="Create List",
category="logic",
category="utils",
is_input_list=True,
search_aliases=["Image Iterator", "Text Iterator", "Iterator"],
inputs=[io.Autogrow.Input("inputs", template=template_autogrow)],

View File

@ -2445,7 +2445,6 @@ async def init_builtin_extra_nodes():
"nodes_save_3d.py",
"nodes_moge.py",
"nodes_mediapipe.py",
"nodes_depth_anything_3.py",
]
import_failed = []

File diff suppressed because it is too large Load Diff