mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-22 01:00:07 +08:00
Compare commits
12 Commits
trellis2
...
matt/asset
| Author | SHA1 | Date | |
|---|---|---|---|
| dc6190e8ba | |||
| 2d21956ac7 | |||
| 396bfe4056 | |||
| 00940fb24e | |||
| 7ff001d7c8 | |||
| 19ba85bb2e | |||
| 3ffc49aa0e | |||
| 36f9a6fdef | |||
| a0d1238829 | |||
| 1688a5e262 | |||
| 7ab346fc7b | |||
| 5b7288d700 |
@ -401,12 +401,16 @@ async def upload_asset(request: web.Request) -> web.Response:
|
||||
)
|
||||
|
||||
if spec.tags and spec.tags[0] == "models":
|
||||
# tag[1] may be the standalone category ("checkpoints") or the
|
||||
# slash-joined shape ("checkpoints/flux/...") that
|
||||
# `get_name_and_tags_from_asset_path` and cloud both emit. Match
|
||||
# `resolve_destination_from_tags` by extracting the first segment.
|
||||
category = spec.tags[1].split("/", 1)[0] if len(spec.tags) >= 2 else ""
|
||||
if (
|
||||
len(spec.tags) < 2
|
||||
or spec.tags[1] not in folder_paths.folder_names_and_paths
|
||||
or category not in folder_paths.folder_names_and_paths
|
||||
):
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
category = spec.tags[1] if len(spec.tags) >= 2 else ""
|
||||
return _build_error_response(
|
||||
400, "INVALID_BODY", f"unknown models category '{category}'"
|
||||
)
|
||||
|
||||
@ -327,7 +327,12 @@ def list_references_page(
|
||||
select(AssetReferenceTag.asset_reference_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetReferenceTag.tag_name)
|
||||
.where(AssetReferenceTag.asset_reference_id.in_(id_list))
|
||||
.order_by(AssetReferenceTag.tag_name.asc())
|
||||
# Preserve insertion order so the structural first tag (the root
|
||||
# category like "models") stays in position 0 and the path-derived
|
||||
# sub-path tag stays in position 1, matching cloud's behavior.
|
||||
# tag_name is a deterministic tiebreaker when multiple tags share
|
||||
# an added_at (same-batch insert via set_reference_tags).
|
||||
.order_by(AssetReferenceTag.added_at.asc(), AssetReferenceTag.tag_name.asc())
|
||||
)
|
||||
for ref_id, tag_name in rows.all():
|
||||
tag_map[ref_id].append(tag_name)
|
||||
@ -355,7 +360,8 @@ def fetch_reference_asset_and_tags(
|
||||
build_visible_owner_clause(owner_id),
|
||||
)
|
||||
.options(noload(AssetReference.tags))
|
||||
.order_by(Tag.name.asc())
|
||||
# See list_references_page for the rationale behind ordering by added_at.
|
||||
.order_by(AssetReferenceTag.added_at.asc(), Tag.name.asc())
|
||||
)
|
||||
|
||||
rows = session.execute(stmt).all()
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Iterable, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
@ -20,7 +21,12 @@ from app.assets.database.queries.common import (
|
||||
build_visible_owner_clause,
|
||||
iter_row_chunks,
|
||||
)
|
||||
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
|
||||
from app.assets.helpers import (
|
||||
escape_sql_like_string,
|
||||
expand_bucket_prefixes,
|
||||
get_utc_now,
|
||||
normalize_tags,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -44,6 +50,26 @@ class SetTagsResult:
|
||||
total: list[str]
|
||||
|
||||
|
||||
def _next_added_at_base(session: Session, reference_id: str) -> datetime:
|
||||
"""Return a timestamp strictly greater than any existing
|
||||
`added_at` for this reference. On platforms where the wall clock
|
||||
has insufficient resolution between back-to-back commits (notably
|
||||
Windows), two write batches on the same reference can otherwise
|
||||
share a microsecond — the `ORDER BY added_at, tag_name` retrieval
|
||||
then falls back to the alphabetic tiebreaker and user-tier tags
|
||||
sort ahead of path-tier tags they were meant to follow.
|
||||
"""
|
||||
existing_max = session.execute(
|
||||
sa.select(sa.func.max(AssetReferenceTag.added_at)).where(
|
||||
AssetReferenceTag.asset_reference_id == reference_id
|
||||
)
|
||||
).scalar()
|
||||
now = get_utc_now()
|
||||
if existing_max is None:
|
||||
return now
|
||||
return max(existing_max + timedelta(microseconds=1), now)
|
||||
|
||||
|
||||
def validate_tags_exist(session: Session, tags: list[str]) -> None:
|
||||
"""Raise ValueError if any of the given tag names do not exist."""
|
||||
existing_tag_names = set(
|
||||
@ -77,7 +103,13 @@ def get_reference_tags(session: Session, reference_id: str) -> list[str]:
|
||||
session.execute(
|
||||
select(AssetReferenceTag.tag_name)
|
||||
.where(AssetReferenceTag.asset_reference_id == reference_id)
|
||||
.order_by(AssetReferenceTag.tag_name.asc())
|
||||
# Match the response-path ordering used by
|
||||
# list_references_page / fetch_reference_asset_and_tags so
|
||||
# upload responses and subsequent GETs agree on tag order.
|
||||
.order_by(
|
||||
AssetReferenceTag.added_at.asc(),
|
||||
AssetReferenceTag.tag_name.asc(),
|
||||
)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
@ -89,7 +121,7 @@ def set_reference_tags(
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
) -> SetTagsResult:
|
||||
desired = normalize_tags(tags)
|
||||
desired = expand_bucket_prefixes(normalize_tags(tags))
|
||||
|
||||
current = set(get_reference_tags(session, reference_id))
|
||||
|
||||
@ -98,15 +130,22 @@ def set_reference_tags(
|
||||
|
||||
if to_add:
|
||||
ensure_tags_exist(session, to_add, tag_type="user")
|
||||
# Stagger added_at by microsecond per tag so the retrieval ORDER BY
|
||||
# added_at preserves input order. Per-tag get_utc_now() calls can
|
||||
# collide at microsecond resolution on fast machines, dropping the
|
||||
# query to the tag_name alphabetical tiebreaker — same fix as in
|
||||
# batch_insert_seed_assets. Read max(existing) so this batch sorts
|
||||
# strictly after any prior batch on the same reference.
|
||||
base_ts = _next_added_at_base(session, reference_id)
|
||||
session.add_all(
|
||||
[
|
||||
AssetReferenceTag(
|
||||
asset_reference_id=reference_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=get_utc_now(),
|
||||
added_at=base_ts + timedelta(microseconds=i),
|
||||
)
|
||||
for t in to_add
|
||||
for i, t in enumerate(to_add)
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
@ -136,7 +175,7 @@ def add_tags_to_reference(
|
||||
if not ref:
|
||||
raise ValueError(f"AssetReference {reference_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
norm = expand_bucket_prefixes(normalize_tags(tags))
|
||||
if not norm:
|
||||
total = get_reference_tags(session, reference_id=reference_id)
|
||||
return AddTagsResult(added=[], already_present=[], total_tags=total)
|
||||
@ -146,10 +185,17 @@ def add_tags_to_reference(
|
||||
|
||||
current = set(get_reference_tags(session, reference_id))
|
||||
|
||||
# Preserve the caller's insertion order rather than alphabetizing —
|
||||
# the retrieval ORDER BY added_at + microsecond stagger only meaningfully
|
||||
# preserves insertion order if "the order we insert in" actually matches
|
||||
# the caller's intent.
|
||||
want = set(norm)
|
||||
to_add = sorted(want - current)
|
||||
to_add = [t for t in norm if t not in current]
|
||||
|
||||
if to_add:
|
||||
# See set_reference_tags for the rationale behind the per-tag stagger
|
||||
# and the max(existing) seed.
|
||||
base_ts = _next_added_at_base(session, reference_id)
|
||||
with session.begin_nested() as nested:
|
||||
try:
|
||||
session.add_all(
|
||||
@ -158,9 +204,9 @@ def add_tags_to_reference(
|
||||
asset_reference_id=reference_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=get_utc_now(),
|
||||
added_at=base_ts + timedelta(microseconds=i),
|
||||
)
|
||||
for t in to_add
|
||||
for i, t in enumerate(to_add)
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
|
||||
@ -47,6 +47,50 @@ def normalize_tags(tags: list[str] | None) -> list[str]:
|
||||
return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip()))
|
||||
|
||||
|
||||
def _known_bucket_prefixes() -> set[str]:
|
||||
"""Lowercased model-category names eligible for standalone-prefix
|
||||
expansion. Tags whose first slash segment matches one of these get
|
||||
the bucket inserted as a separate token, so FE filters like
|
||||
``include_tags=models,checkpoints`` keep matching even when the
|
||||
asset lives in a nested subfolder (`models/checkpoints/flux/foo`).
|
||||
|
||||
Bare user labels with slashes whose first segment is not a registered
|
||||
bucket (e.g. ``my-org/team-a``) pass through unchanged.
|
||||
"""
|
||||
try:
|
||||
import folder_paths
|
||||
|
||||
return {
|
||||
name.lower()
|
||||
for name in folder_paths.folder_names_and_paths.keys()
|
||||
if name != "custom_nodes"
|
||||
}
|
||||
except Exception:
|
||||
return set()
|
||||
|
||||
|
||||
def expand_bucket_prefixes(tags: list[str]) -> list[str]:
|
||||
"""Insert standalone bucket tokens after any slash-joined tag whose
|
||||
first segment is a registered model category. Preserves caller order
|
||||
and is idempotent (existing bucket tokens are not duplicated).
|
||||
"""
|
||||
if not tags:
|
||||
return list(tags)
|
||||
buckets = _known_bucket_prefixes()
|
||||
if not buckets:
|
||||
return list(tags)
|
||||
seen = set(tags)
|
||||
result: list[str] = []
|
||||
for t in tags:
|
||||
result.append(t)
|
||||
if "/" in t:
|
||||
prefix = t.split("/", 1)[0]
|
||||
if prefix.lower() in buckets and prefix not in seen:
|
||||
result.append(prefix)
|
||||
seen.add(prefix)
|
||||
return result
|
||||
|
||||
|
||||
def validate_blake3_hash(s: str) -> str:
|
||||
"""Validate and normalize a blake3 hash string.
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@ -13,13 +13,14 @@ from app.assets.database.queries import (
|
||||
bulk_insert_references_ignore_conflicts,
|
||||
bulk_insert_tags_and_meta,
|
||||
delete_assets_by_ids,
|
||||
ensure_tags_exist,
|
||||
get_existing_asset_ids,
|
||||
get_reference_ids_by_ids,
|
||||
get_references_by_paths_and_asset_ids,
|
||||
get_unreferenced_unhashed_asset_ids,
|
||||
restore_references_by_paths,
|
||||
)
|
||||
from app.assets.helpers import get_utc_now
|
||||
from app.assets.helpers import expand_bucket_prefixes, get_utc_now
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.assets.services.metadata_extract import ExtractedMetadata
|
||||
@ -233,13 +234,20 @@ def batch_insert_seed_assets(
|
||||
if ref_id not in inserted_ref_ids:
|
||||
continue
|
||||
|
||||
for tag in ref_data["tags"]:
|
||||
# Stagger added_at by microsecond per tag within a reference so
|
||||
# the retrieval ORDER BY added_at preserves the input list order
|
||||
# (the path-derived root category stays at position 0). Without
|
||||
# this, every tag in a bulk-insert batch shares current_time and
|
||||
# the tag_name tiebreaker sorts them alphabetically — putting the
|
||||
# subpath tag ahead of "models" since "c"/"d"/"l" < "m".
|
||||
ref_tags = expand_bucket_prefixes(ref_data["tags"])
|
||||
for tag_idx, tag in enumerate(ref_tags):
|
||||
tag_rows.append(
|
||||
{
|
||||
"asset_reference_id": ref_id,
|
||||
"tag_name": tag,
|
||||
"origin": "automatic",
|
||||
"added_at": current_time,
|
||||
"added_at": current_time + timedelta(microseconds=tag_idx),
|
||||
}
|
||||
)
|
||||
|
||||
@ -261,6 +269,16 @@ def batch_insert_seed_assets(
|
||||
}
|
||||
)
|
||||
|
||||
if tag_rows:
|
||||
# Bucket-prefix expansion may have introduced tags the caller did
|
||||
# not register via the upstream tag_pool (e.g. `checkpoints` for a
|
||||
# nested `checkpoints/flux/foo` path). Pre-register the full set so
|
||||
# the AssetReferenceTag.tag_name FK is satisfied; the underlying
|
||||
# insert is ON CONFLICT DO NOTHING so re-registration is idempotent.
|
||||
ensure_tags_exist(
|
||||
session, {row["tag_name"] for row in tag_rows}, tag_type="user"
|
||||
)
|
||||
|
||||
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=metadata_rows)
|
||||
|
||||
return BulkInsertResult(
|
||||
|
||||
@ -3,7 +3,6 @@ from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import folder_paths
|
||||
from app.assets.helpers import normalize_tags
|
||||
|
||||
|
||||
_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"})
|
||||
@ -27,27 +26,51 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs).
|
||||
|
||||
Accepts both the legacy one-tag-per-directory shape
|
||||
(``["models", "diffusers", "Kolors", "text_encoder"]``) and the
|
||||
slash-joined shape emitted by :func:`get_name_and_tags_from_asset_path`
|
||||
(``["models", "diffusers/Kolors/text_encoder"]``). Hybrid shapes that
|
||||
mix the two within a single call (e.g.
|
||||
``["models", "diffusers", "Kolors/text_encoder"]``) are also
|
||||
accepted: each entry after ``tags[0]`` is split on ``/`` and
|
||||
concatenated, so the two shapes — and any mix of them — resolve to
|
||||
the same destination. The same safety checks are applied to each
|
||||
component after expansion.
|
||||
"""
|
||||
if not tags:
|
||||
raise ValueError("tags must not be empty")
|
||||
root = tags[0].lower()
|
||||
|
||||
# Expand any slash-joined entries into individual path components so
|
||||
# the rest of the function can treat both tag shapes uniformly. Each
|
||||
# component is also stripped, so " a / b " behaves like ["a", "b"].
|
||||
expanded: list[str] = []
|
||||
for t in tags[1:]:
|
||||
for part in str(t).split("/"):
|
||||
part = part.strip()
|
||||
if part:
|
||||
expanded.append(part)
|
||||
|
||||
if root == "models":
|
||||
if len(tags) < 2:
|
||||
if not expanded:
|
||||
raise ValueError("at least two tags required for model asset")
|
||||
category = expanded[0]
|
||||
try:
|
||||
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||
bases = folder_paths.folder_names_and_paths[category][0]
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||
raise ValueError(f"unknown model category '{category}'")
|
||||
if not bases:
|
||||
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||
raise ValueError(f"no base path configured for category '{category}'")
|
||||
base_dir = os.path.abspath(bases[0])
|
||||
raw_subdirs = tags[2:]
|
||||
raw_subdirs = expanded[1:]
|
||||
elif root == "input":
|
||||
base_dir = os.path.abspath(folder_paths.get_input_directory())
|
||||
raw_subdirs = tags[1:]
|
||||
raw_subdirs = expanded
|
||||
elif root == "output":
|
||||
base_dir = os.path.abspath(folder_paths.get_output_directory())
|
||||
raw_subdirs = tags[1:]
|
||||
raw_subdirs = expanded
|
||||
else:
|
||||
raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'")
|
||||
_sep_chars = frozenset(("/", "\\", os.sep))
|
||||
@ -160,7 +183,21 @@ def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
"""Return (name, tags) derived from a filesystem path.
|
||||
|
||||
- name: base filename with extension
|
||||
- tags: [root_category] + parent folder names in order
|
||||
- tags: [root_category] for paths with no parent subdirectories,
|
||||
[root_category, slash_joined_subpath] otherwise. The parent subpath
|
||||
(everything between the root category and the filename) is collapsed
|
||||
into a single tag rather than emitted as one tag per directory, so
|
||||
consumers can use ``tags[1]`` as a stable category identifier that
|
||||
survives nested directory layouts (e.g. diffusers components).
|
||||
|
||||
The subpath is lowercased to match the canonicalization applied by
|
||||
:func:`ensure_tags_exist`; without that, the
|
||||
``asset_reference_tags.tag_name`` FK to the lowercased ``tags.name``
|
||||
would fail for any path containing uppercase letters. The root
|
||||
category is lowercase by construction in
|
||||
:func:`get_asset_category_and_relative_path`, so no separate cast
|
||||
is applied here. Consumers that need to look up providers keyed on
|
||||
original-case paths should normalize their lookup key to lowercase.
|
||||
|
||||
Raises:
|
||||
ValueError: path does not belong to any known root.
|
||||
@ -170,4 +207,7 @@ def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
parent_parts = [
|
||||
part for part in p.parent.parts if part not in (".", "..", p.anchor)
|
||||
]
|
||||
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
|
||||
tags = [root_category]
|
||||
if parent_parts:
|
||||
tags.append("/".join(parent_parts).lower())
|
||||
return p.name, list(dict.fromkeys(t.strip() for t in tags if t.strip()))
|
||||
|
||||
@ -9,7 +9,6 @@ import comfy.model_management
|
||||
import comfy.utils
|
||||
import comfy.clip_model
|
||||
import comfy.image_encoders.dino2
|
||||
import comfy.image_encoders.dino3
|
||||
|
||||
class Output:
|
||||
def __getitem__(self, key):
|
||||
@ -24,7 +23,6 @@ IMAGE_ENCODERS = {
|
||||
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
|
||||
"dinov3": comfy.image_encoders.dino3.DINOv3ViTModel
|
||||
}
|
||||
|
||||
class ClipVisionModel():
|
||||
@ -136,8 +134,6 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
|
||||
elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
|
||||
elif 'layer.9.attention.o_proj.bias' in sd: # dinov3
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino3_large.json")
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@ -1,285 +0,0 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale
|
||||
|
||||
class DINOv3ViTMLP(nn.Module):
|
||||
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.act_fn = torch.nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.up_proj(x)))
|
||||
|
||||
def rotate_half(x):
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, **kwargs):
|
||||
num_tokens = q.shape[-2]
|
||||
num_patches = sin.shape[-2]
|
||||
num_prefix_tokens = num_tokens - num_patches
|
||||
|
||||
q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2)
|
||||
k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2)
|
||||
|
||||
q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin)
|
||||
k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin)
|
||||
|
||||
q = torch.cat((q_prefix_tokens, q_patches), dim=-2)
|
||||
k = torch.cat((k_prefix_tokens, k_patches), dim=-2)
|
||||
|
||||
return q, k
|
||||
|
||||
class DINOv3ViTAttention(nn.Module):
|
||||
def __init__(self, hidden_size, num_attention_heads, device, dtype, operations):
|
||||
super().__init__()
|
||||
self.embed_dim = hidden_size
|
||||
self.num_heads = num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
|
||||
self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False
|
||||
self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
|
||||
|
||||
self.q_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
|
||||
self.o_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
|
||||
batch_size, patches, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if position_embeddings is not None:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
attn = optimized_attention_for_device(query_states.device, mask=False)
|
||||
|
||||
attn_output = attn(
|
||||
query_states, key_states, value_states, self.num_heads, attention_mask, skip_reshape=True, skip_output_reshape=True
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
class DINOv3ViTGatedMLP(nn.Module):
|
||||
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.act_fn = torch.nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
def get_patches_center_coordinates(
|
||||
num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device
|
||||
) -> torch.Tensor:
|
||||
|
||||
coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device)
|
||||
coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device)
|
||||
coords_h = coords_h / num_patches_h
|
||||
coords_w = coords_w / num_patches_w
|
||||
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
|
||||
coords = coords.flatten(0, 1)
|
||||
coords = 2.0 * coords - 1.0
|
||||
return coords
|
||||
|
||||
class DINOv3ViTRopePositionEmbedding(nn.Module):
|
||||
inv_freq: torch.Tensor
|
||||
|
||||
def __init__(self, rope_theta, hidden_size, num_attention_heads, image_size, patch_size, device, dtype):
|
||||
super().__init__()
|
||||
self.base = rope_theta
|
||||
self.head_dim = hidden_size // num_attention_heads
|
||||
self.num_patches_h = image_size // patch_size
|
||||
self.num_patches_w = image_size // patch_size
|
||||
self.patch_size = patch_size
|
||||
|
||||
inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
_, _, height, width = pixel_values.shape
|
||||
num_patches_h = height // self.patch_size
|
||||
num_patches_w = width // self.patch_size
|
||||
|
||||
device = pixel_values.device
|
||||
device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"
|
||||
with torch.amp.autocast(device_type = device_type, enabled=False):
|
||||
patch_coords = get_patches_center_coordinates(
|
||||
num_patches_h, num_patches_w, dtype=torch.float32, device=device
|
||||
)
|
||||
|
||||
self.inv_freq = self.inv_freq.to(device)
|
||||
angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]
|
||||
angles = angles.flatten(1, 2)
|
||||
angles = angles.tile(2)
|
||||
|
||||
cos = torch.cos(angles)
|
||||
sin = torch.sin(angles)
|
||||
|
||||
dtype = pixel_values.dtype
|
||||
return cos.to(dtype=dtype), sin.to(dtype=dtype)
|
||||
|
||||
|
||||
class DINOv3ViTEmbeddings(nn.Module):
|
||||
def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_size, device=device, dtype=dtype))
|
||||
self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size, device=device, dtype=dtype))
|
||||
self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype))
|
||||
self.patch_embeddings = operations.Conv2d(
|
||||
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None):
|
||||
batch_size = pixel_values.shape[0]
|
||||
target_dtype = self.patch_embeddings.weight.dtype
|
||||
|
||||
patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
|
||||
patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)
|
||||
|
||||
if bool_masked_pos is not None:
|
||||
mask_token = self.mask_token.to(patch_embeddings.dtype)
|
||||
patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings)
|
||||
|
||||
cls_token = self.cls_token.expand(batch_size, -1, -1)
|
||||
register_tokens = self.register_tokens.expand(batch_size, -1, -1)
|
||||
device = patch_embeddings.device
|
||||
cls_token = cls_token.to(device)
|
||||
register_tokens = register_tokens.to(device)
|
||||
embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)
|
||||
|
||||
return embeddings
|
||||
|
||||
class DINOv3ViTLayer(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, mlp_bias, intermediate_size, num_attention_heads,
|
||||
device, dtype, operations):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
|
||||
self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
|
||||
|
||||
self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
|
||||
|
||||
if use_gated_mlp:
|
||||
self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations)
|
||||
else:
|
||||
self.mlp = DINOv3ViTMLP(hidden_size, intermediate_size=intermediate_size, mlp_bias=mlp_bias, device=device, dtype=dtype, operations=operations)
|
||||
self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.attention(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
hidden_states = self.layer_scale1(hidden_states)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = self.layer_scale2(hidden_states)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DINOv3ViTModel(nn.Module):
|
||||
def __init__(self, config, dtype, device, operations):
|
||||
super().__init__()
|
||||
use_bf16 = comfy.model_management.should_use_bf16(device, prioritize_performance=True)
|
||||
if dtype == torch.float16 and use_bf16:
|
||||
dtype = torch.bfloat16
|
||||
elif dtype == torch.float16 and not use_bf16:
|
||||
dtype = torch.float32
|
||||
num_hidden_layers = config["num_hidden_layers"]
|
||||
hidden_size = config["hidden_size"]
|
||||
num_attention_heads = config["num_attention_heads"]
|
||||
num_register_tokens = config["num_register_tokens"]
|
||||
intermediate_size = config["intermediate_size"]
|
||||
layer_norm_eps = config["layer_norm_eps"]
|
||||
num_channels = config["num_channels"]
|
||||
patch_size = config["patch_size"]
|
||||
rope_theta = config["rope_theta"]
|
||||
|
||||
self.embeddings = DINOv3ViTEmbeddings(
|
||||
hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(
|
||||
rope_theta, hidden_size, num_attention_heads, image_size=512, patch_size=patch_size, dtype=dtype, device=device
|
||||
)
|
||||
self.layer = nn.ModuleList(
|
||||
[DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=False, mlp_bias=True,
|
||||
intermediate_size=intermediate_size,num_attention_heads = num_attention_heads,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(num_hidden_layers)])
|
||||
self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.patch_embeddings
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
bool_masked_pos: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
|
||||
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
||||
position_embeddings = self.rope_embeddings(pixel_values)
|
||||
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
hidden_states = layer_module(
|
||||
hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
if kwargs.get("skip_norm_elementwise", False):
|
||||
sequence_output= F.layer_norm(hidden_states, hidden_states.shape[-1:])
|
||||
else:
|
||||
norm = self.norm.to(hidden_states.device)
|
||||
sequence_output = norm(hidden_states)
|
||||
pooled_output = sequence_output[:, 0, :]
|
||||
|
||||
return sequence_output, None, pooled_output, None
|
||||
@ -1,23 +0,0 @@
|
||||
{
|
||||
"model_type": "dinov3",
|
||||
"hidden_size": 1024,
|
||||
"image_size": 224,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 4096,
|
||||
"key_bias": false,
|
||||
"layer_norm_eps": 1e-05,
|
||||
"mlp_bias": true,
|
||||
"num_attention_heads": 16,
|
||||
"num_channels": 3,
|
||||
"num_hidden_layers": 24,
|
||||
"num_register_tokens": 4,
|
||||
"patch_size": 16,
|
||||
"pos_embed_rescale": 2.0,
|
||||
"proj_bias": true,
|
||||
"query_bias": true,
|
||||
"rope_theta": 100.0,
|
||||
"use_gated_mlp": false,
|
||||
"value_bias": true,
|
||||
"image_mean": [0.485, 0.456, 0.406],
|
||||
"image_std": [0.229, 0.224, 0.225]
|
||||
}
|
||||
@ -760,8 +760,6 @@ class Hunyuan3Dv2_1(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
class Trellis2(LatentFormat): # TODO
|
||||
latent_channels = 32
|
||||
class Hunyuan3Dv2mini(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
@ -1,282 +0,0 @@
|
||||
import torch
|
||||
import math
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from typing import Tuple, Union, List
|
||||
from comfy.ldm.trellis2.vae import VarLenTensor
|
||||
import comfy.ops
|
||||
|
||||
|
||||
# replica of the seedvr2 code
|
||||
def var_attn_arg(kwargs):
|
||||
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
|
||||
max_seqlen_q = kwargs.get("max_seqlen_q", None)
|
||||
cu_seqlens_k = kwargs.get("cu_seqlens_kv", cu_seqlens_q)
|
||||
max_seqlen_k = kwargs.get("max_kv_seqlen", max_seqlen_q)
|
||||
assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
|
||||
return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
|
||||
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
var_length = True
|
||||
if var_length:
|
||||
cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)
|
||||
if not skip_reshape:
|
||||
# assumes 2D q, k,v [total_tokens, embed_dim]
|
||||
total_tokens, embed_dim = q.shape
|
||||
head_dim = embed_dim // heads
|
||||
q = q.view(total_tokens, heads, head_dim)
|
||||
k = k.view(k.shape[0], heads, head_dim)
|
||||
v = v.view(v.shape[0], heads, head_dim)
|
||||
|
||||
b = q.size(0)
|
||||
dim_head = q.shape[-1]
|
||||
q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
|
||||
k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
|
||||
v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
|
||||
|
||||
mask = None
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
if mask is not None:
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
if var_length:
|
||||
return out.transpose(1, 2).values()
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
return out
|
||||
|
||||
def scaled_dot_product_attention(*args, **kwargs):
|
||||
num_all_args = len(args) + len(kwargs)
|
||||
|
||||
q = None
|
||||
if num_all_args == 1:
|
||||
qkv = args[0] if len(args) > 0 else kwargs.get('qkv')
|
||||
elif num_all_args == 2:
|
||||
q = args[0] if len(args) > 0 else kwargs.get('q')
|
||||
kv = args[1] if len(args) > 1 else kwargs.get('kv')
|
||||
elif num_all_args == 3:
|
||||
q = args[0] if len(args) > 0 else kwargs.get('q')
|
||||
k = args[1] if len(args) > 1 else kwargs.get('k')
|
||||
v = args[2] if len(args) > 2 else kwargs.get('v')
|
||||
|
||||
if q is not None:
|
||||
heads = q.shape[2]
|
||||
else:
|
||||
heads = qkv.shape[3]
|
||||
|
||||
if num_all_args == 1:
|
||||
q, k, v = qkv.unbind(dim=2)
|
||||
elif num_all_args == 2:
|
||||
k, v = kv.unbind(dim=2)
|
||||
|
||||
q = q.permute(0, 2, 1, 3)
|
||||
k = k.permute(0, 2, 1, 3)
|
||||
v = v.permute(0, 2, 1, 3)
|
||||
|
||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True, **kwargs)
|
||||
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
|
||||
return out
|
||||
|
||||
def sparse_windowed_scaled_dot_product_self_attention(
|
||||
qkv,
|
||||
window_size: int,
|
||||
shift_window: Tuple[int, int, int] = (0, 0, 0)
|
||||
):
|
||||
|
||||
serialization_spatial_cache_name = f'windowed_attention_{window_size}_{shift_window}'
|
||||
serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
|
||||
if serialization_spatial_cache is None:
|
||||
fwd_indices, bwd_indices, seq_lens, attn_func_args = calc_window_partition(qkv, window_size, shift_window)
|
||||
qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, attn_func_args))
|
||||
else:
|
||||
fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache
|
||||
|
||||
qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
|
||||
heads = qkv_feats.shape[2]
|
||||
|
||||
if optimized_attention.__name__ == 'attention_xformers':
|
||||
q, k, v = qkv_feats.unbind(dim=1)
|
||||
q = q.unsqueeze(0) # [1, M, H, C]
|
||||
k = k.unsqueeze(0) # [1, M, H, C]
|
||||
v = v.unsqueeze(0) # [1, M, H, C]
|
||||
#out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C]
|
||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||
elif optimized_attention.__name__ == 'attention_flash':
|
||||
if 'flash_attn' not in globals():
|
||||
import flash_attn
|
||||
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C]
|
||||
else:
|
||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||
|
||||
out = out[bwd_indices] # [T, H, C]
|
||||
|
||||
return qkv.replace(out)
|
||||
|
||||
def calc_window_partition(
|
||||
tensor,
|
||||
window_size: Union[int, Tuple[int, ...]],
|
||||
shift_window: Union[int, Tuple[int, ...]] = 0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
|
||||
|
||||
DIM = tensor.coords.shape[1] - 1
|
||||
shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
|
||||
window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
|
||||
shifted_coords = tensor.coords.clone().detach()
|
||||
shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)
|
||||
|
||||
MAX_COORDS = [i + j for i, j in zip(tensor.spatial_shape, shift_window)]
|
||||
NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
|
||||
OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]
|
||||
|
||||
shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
|
||||
shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
|
||||
fwd_indices = torch.argsort(shifted_indices)
|
||||
bwd_indices = torch.empty_like(fwd_indices)
|
||||
bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
|
||||
seq_lens = torch.bincount(shifted_indices)
|
||||
mask = seq_lens != 0
|
||||
seq_lens = seq_lens[mask]
|
||||
|
||||
if optimized_attention.__name__ == 'attention_xformers':
|
||||
if 'xops' not in globals():
|
||||
import xformers.ops as xops
|
||||
attn_func_args = {
|
||||
'attn_bias': xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
|
||||
}
|
||||
elif optimized_attention.__name__ == 'attention_flash':
|
||||
attn_func_args = {
|
||||
'cu_seqlens': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(),
|
||||
'max_seqlen': torch.max(seq_lens)
|
||||
}
|
||||
|
||||
return fwd_indices, bwd_indices, seq_lens, attn_func_args
|
||||
|
||||
|
||||
def sparse_scaled_dot_product_attention(*args, **kwargs):
|
||||
q=None
|
||||
arg_names_dict = {
|
||||
1: ['qkv'],
|
||||
2: ['q', 'kv'],
|
||||
3: ['q', 'k', 'v']
|
||||
}
|
||||
num_all_args = len(args) + len(kwargs)
|
||||
for key in arg_names_dict[num_all_args][len(args):]:
|
||||
assert key in kwargs, f"Missing argument {key}"
|
||||
|
||||
if num_all_args == 1:
|
||||
qkv = args[0] if len(args) > 0 else kwargs['qkv']
|
||||
device = qkv.device
|
||||
|
||||
s = qkv
|
||||
q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
|
||||
kv_seqlen = q_seqlen
|
||||
qkv = qkv.feats # [T, 3, H, C]
|
||||
|
||||
elif num_all_args == 2:
|
||||
q = args[0] if len(args) > 0 else kwargs['q']
|
||||
kv = args[1] if len(args) > 1 else kwargs['kv']
|
||||
device = q.device
|
||||
|
||||
if isinstance(q, VarLenTensor):
|
||||
s = q
|
||||
q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
|
||||
q = q.feats # [T_Q, H, C]
|
||||
else:
|
||||
s = None
|
||||
N, L, H, C = q.shape
|
||||
q_seqlen = [L] * N
|
||||
q = q.reshape(N * L, H, C) # [T_Q, H, C]
|
||||
|
||||
if isinstance(kv, VarLenTensor):
|
||||
kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
|
||||
kv = kv.feats # [T_KV, 2, H, C]
|
||||
else:
|
||||
N, L, _, H, C = kv.shape
|
||||
kv_seqlen = [L] * N
|
||||
kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C]
|
||||
|
||||
elif num_all_args == 3:
|
||||
q = args[0] if len(args) > 0 else kwargs['q']
|
||||
k = args[1] if len(args) > 1 else kwargs['k']
|
||||
v = args[2] if len(args) > 2 else kwargs['v']
|
||||
device = q.device
|
||||
|
||||
if isinstance(q, VarLenTensor):
|
||||
s = q
|
||||
q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
|
||||
q = q.feats # [T_Q, H, Ci]
|
||||
else:
|
||||
s = None
|
||||
N, L, H, CI = q.shape
|
||||
q_seqlen = [L] * N
|
||||
q = q.reshape(N * L, H, CI) # [T_Q, H, Ci]
|
||||
|
||||
if isinstance(k, VarLenTensor):
|
||||
kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
|
||||
k = k.feats # [T_KV, H, Ci]
|
||||
v = v.feats # [T_KV, H, Co]
|
||||
else:
|
||||
N, L, H, CI, CO = *k.shape, v.shape[-1]
|
||||
kv_seqlen = [L] * N
|
||||
k = k.reshape(N * L, H, CI) # [T_KV, H, Ci]
|
||||
v = v.reshape(N * L, H, CO) # [T_KV, H, Co]
|
||||
|
||||
# TODO: change
|
||||
if q is not None:
|
||||
heads = q
|
||||
else:
|
||||
heads = qkv
|
||||
heads = heads.shape[2]
|
||||
if optimized_attention.__name__ == 'attention_xformers':
|
||||
if 'xops' not in globals():
|
||||
import xformers.ops as xops
|
||||
if num_all_args == 1:
|
||||
q, k, v = qkv.unbind(dim=1)
|
||||
elif num_all_args == 2:
|
||||
k, v = kv.unbind(dim=1)
|
||||
q = q.unsqueeze(0)
|
||||
k = k.unsqueeze(0)
|
||||
v = v.unsqueeze(0)
|
||||
mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
|
||||
out = xops.memory_efficient_attention(q, k, v, mask)[0]
|
||||
elif optimized_attention.__name__ == 'attention_flash':
|
||||
if 'flash_attn' not in globals():
|
||||
import flash_attn
|
||||
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
|
||||
if num_all_args in [2, 3]:
|
||||
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
|
||||
if num_all_args == 1:
|
||||
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
|
||||
elif num_all_args == 2:
|
||||
out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
|
||||
elif num_all_args == 3:
|
||||
out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
|
||||
|
||||
elif optimized_attention.__name__ == "attention_pytorch":
|
||||
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
|
||||
if num_all_args in [2, 3]:
|
||||
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
|
||||
else:
|
||||
cu_seqlens_kv = cu_seqlens_q
|
||||
if num_all_args == 1:
|
||||
q, k, v = qkv.unbind(dim=1)
|
||||
elif num_all_args == 2:
|
||||
k, v = kv.unbind(dim=1)
|
||||
out = attention_pytorch(q, k, v, heads=heads,cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max(q_seqlen), max_kv_seqlen=max(kv_seqlen),
|
||||
skip_reshape=True, skip_output_reshape=True)
|
||||
|
||||
if s is not None:
|
||||
return s.replace(out)
|
||||
else:
|
||||
return out.reshape(N, L, H, -1)
|
||||
@ -1,298 +0,0 @@
|
||||
# will contain every cuda -> pytorch operation
|
||||
|
||||
from typing import Optional, Tuple
|
||||
import torch
|
||||
|
||||
UINT32_SENTINEL = 0xFFFFFFFF
|
||||
|
||||
|
||||
def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
|
||||
"""Kernel spatial offsets in the same order as the CUDA/Triton kernels."""
|
||||
offsets = []
|
||||
for vx in range(Kw):
|
||||
for vy in range(Kh):
|
||||
for vz in range(Kd):
|
||||
offsets.append((vx * Dw, vy * Dh, vz * Dd))
|
||||
return torch.tensor(offsets, device=device, dtype=torch.int32)
|
||||
|
||||
|
||||
class TorchHashMap:
|
||||
"""Sorted-array hashmap backed by torch.searchsorted."""
|
||||
|
||||
def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
|
||||
device = keys.device
|
||||
self.sorted_keys, order = torch.sort(keys.to(torch.long))
|
||||
self.sorted_vals = values.to(torch.long)[order]
|
||||
self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
|
||||
self._n = self.sorted_keys.numel()
|
||||
|
||||
def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
|
||||
flat = flat_keys.to(torch.long)
|
||||
if self._n == 0:
|
||||
return torch.full((flat.shape[0],), -1, device=flat.device, dtype=torch.int32)
|
||||
idx = torch.searchsorted(self.sorted_keys, flat)
|
||||
idx_safe = torch.clamp(idx, max=self._n - 1)
|
||||
found = (idx < self._n) & (self.sorted_keys[idx_safe] == flat)
|
||||
out = torch.full((flat.shape[0],), -1, device=flat.device, dtype=torch.int32)
|
||||
if found.any():
|
||||
out[found] = self.sorted_vals[idx_safe[found]].to(torch.int32)
|
||||
return out
|
||||
|
||||
|
||||
def build_submanifold_neighbor_map(
|
||||
hashmap,
|
||||
coords: torch.Tensor,
|
||||
W, H, D,
|
||||
Kw, Kh, Kd,
|
||||
Dw, Dh, Dd,
|
||||
):
|
||||
device = coords.device
|
||||
M = coords.shape[0]
|
||||
V = Kw * Kh * Kd
|
||||
half_V = V // 2 + 1
|
||||
INVALID = -1
|
||||
|
||||
# int32 neighbour map: 4 bytes/elem vs 8 bytes for int64
|
||||
neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.int32)
|
||||
|
||||
b = coords[:, 0].long()
|
||||
x = coords[:, 1].long()
|
||||
y = coords[:, 2].long()
|
||||
z = coords[:, 3].long()
|
||||
|
||||
offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device)
|
||||
|
||||
ox = x - (Kw // 2) * Dw
|
||||
oy = y - (Kh // 2) * Dh
|
||||
oz = z - (Kd // 2) * Dd
|
||||
|
||||
for v in range(half_V):
|
||||
if v == half_V - 1:
|
||||
# Center voxel always maps to itself
|
||||
neighbor[:, v] = torch.arange(M, device=device, dtype=torch.int32)
|
||||
continue
|
||||
|
||||
dx, dy, dz = offsets[v]
|
||||
|
||||
kx = ox + dx
|
||||
ky = oy + dy
|
||||
kz = oz + dz
|
||||
|
||||
valid = (
|
||||
(kx >= 0) & (kx < W) &
|
||||
(ky >= 0) & (ky < H) &
|
||||
(kz >= 0) & (kz < D)
|
||||
)
|
||||
|
||||
flat = (
|
||||
b[valid] * (W * H * D) +
|
||||
kx[valid] * (H * D) +
|
||||
ky[valid] * D +
|
||||
kz[valid]
|
||||
)
|
||||
|
||||
if flat.numel() > 0:
|
||||
found = hashmap.lookup_flat(flat)
|
||||
idx_in_M = torch.where(valid)[0]
|
||||
neighbor[idx_in_M, v] = found.to(torch.int32)
|
||||
|
||||
# BUG FIX: old code used found != hashmap.default_value which
|
||||
# compared int32 -1 against int64 4294967295 → always True.
|
||||
# We now explicitly check for valid indices.
|
||||
valid_found_mask = found >= 0
|
||||
if valid_found_mask.any():
|
||||
src_points = idx_in_M[valid_found_mask]
|
||||
dst_points = found[valid_found_mask].long()
|
||||
neighbor[dst_points, V - 1 - v] = src_points.to(torch.int32)
|
||||
|
||||
return neighbor
|
||||
|
||||
def get_recommended_chunk_mem(
|
||||
device=None,
|
||||
safety_fraction: float = 0.4,
|
||||
min_gb: float = 0.25,
|
||||
max_gb: float = 8.0,
|
||||
):
|
||||
|
||||
if device is None:
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
else:
|
||||
device = torch.device(device)
|
||||
|
||||
if device.type == 'cuda':
|
||||
try:
|
||||
idx = device.index if device.index is not None else 0
|
||||
free_bytes, total_bytes = torch.cuda.mem_get_info(idx)
|
||||
free_gb = free_bytes / (1024 ** 3)
|
||||
total_gb = total_bytes / (1024 ** 3)
|
||||
|
||||
recommended = free_gb * safety_fraction
|
||||
result = max(min_gb, min(recommended, max_gb))
|
||||
return result
|
||||
|
||||
except Exception:
|
||||
try:
|
||||
idx = device.index if device.index is not None else 0
|
||||
total_gb = torch.cuda.get_device_properties(idx).total_memory / (1024 ** 3)
|
||||
except Exception:
|
||||
total_gb = 16.0
|
||||
|
||||
if total_gb < 12:
|
||||
result = 0.5
|
||||
elif total_gb < 16:
|
||||
result = 0.75
|
||||
elif total_gb < 24:
|
||||
result = 1.0
|
||||
elif total_gb < 32:
|
||||
result = 2.0
|
||||
elif total_gb < 48:
|
||||
result = 4.0
|
||||
else:
|
||||
result = 6.0
|
||||
return result
|
||||
|
||||
else:
|
||||
try:
|
||||
import psutil
|
||||
avail_gb = psutil.virtual_memory().available / (1024 ** 3)
|
||||
recommended = avail_gb * safety_fraction
|
||||
result = max(min_gb, min(recommended, max_gb))
|
||||
return result
|
||||
except ImportError:
|
||||
return min_gb
|
||||
|
||||
def sparse_submanifold_conv3d(
|
||||
feats: torch.Tensor,
|
||||
coords: torch.Tensor,
|
||||
shape: tuple,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
neighbor_cache: Optional[torch.Tensor],
|
||||
dilation: tuple,
|
||||
max_chunk_mem_gb: float = 6.0,
|
||||
accumulate_f32: bool = True,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
|
||||
if feats.shape[0] == 0:
|
||||
Co = weight.shape[0]
|
||||
return torch.empty((0, Co), device=feats.device, dtype=feats.dtype), None
|
||||
|
||||
if len(shape) == 5:
|
||||
_, _, W, H, D = shape
|
||||
else:
|
||||
W, H, D = shape
|
||||
|
||||
Co, Kw, Kh, Kd, Ci = weight.shape
|
||||
V = Kw * Kh * Kd
|
||||
device = feats.device
|
||||
sentinel = -1
|
||||
max_chunk_mem_gb = get_recommended_chunk_mem(device)
|
||||
|
||||
if neighbor_cache is None:
|
||||
b_stride = W * H * D
|
||||
x_stride = H * D
|
||||
y_stride = D
|
||||
z_stride = 1
|
||||
|
||||
flat_keys = (coords[:, 0].long() * b_stride +
|
||||
coords[:, 1].long() * x_stride +
|
||||
coords[:, 2].long() * y_stride +
|
||||
coords[:, 3].long() * z_stride)
|
||||
vals = torch.arange(coords.shape[0], dtype=torch.int32, device=device)
|
||||
hashmap = TorchHashMap(flat_keys, vals, UINT32_SENTINEL)
|
||||
|
||||
neighbor = build_submanifold_neighbor_map(
|
||||
hashmap, coords, W, H, D, Kw, Kh, Kd,
|
||||
dilation[0], dilation[1], dilation[2]
|
||||
)
|
||||
else:
|
||||
neighbor = neighbor_cache
|
||||
|
||||
N_pts = feats.shape[0]
|
||||
|
||||
if accumulate_f32:
|
||||
weight_T = weight.view(Co, V * Ci).to(torch.float32).T.contiguous()
|
||||
output = torch.zeros(N_pts, Co, device=device, dtype=torch.float32)
|
||||
else:
|
||||
weight_T = weight.view(Co, V * Ci).to(feats.dtype).T.contiguous()
|
||||
output = torch.zeros(N_pts, Co, device=device, dtype=feats.dtype)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Chunk size from memory budget
|
||||
# ------------------------------------------------------------------
|
||||
bytes_per_elem = 4 if accumulate_f32 else feats.element_size()
|
||||
mem_per_row = V * Ci * bytes_per_elem
|
||||
max_chunk_mem = max_chunk_mem_gb * (1024 ** 3)
|
||||
chunk_size = max(1, int(max_chunk_mem / mem_per_row))
|
||||
chunk_size = min(chunk_size, N_pts)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Chunked forward pass
|
||||
# Each iteration:
|
||||
# 1. gather (chunk, V, Ci) – memory bound
|
||||
# 2. mask zero invalids – in-place, no extra alloc
|
||||
# 3. reshape (chunk, V*Ci)
|
||||
# 4. GEMM (chunk, V*Ci) @ (V*Ci, Co) → (chunk, Co) – cuBLAS
|
||||
# written directly into output slice via out= argument
|
||||
# ------------------------------------------------------------------
|
||||
for start in range(0, N_pts, chunk_size):
|
||||
end = min(start + chunk_size, N_pts)
|
||||
actual_chunk = end - start
|
||||
|
||||
# (chunk, V) int32
|
||||
chunk_neighbor = neighbor[start:end]
|
||||
chunk_valid = chunk_neighbor != sentinel
|
||||
|
||||
# Clamp sentinel -1 → 0 for safe indexing. No clone of the full map.
|
||||
chunk_idx = chunk_neighbor.clamp(min=0).long()
|
||||
|
||||
# Gather: (chunk, V, Ci). Memory-bound, single index_select.
|
||||
gathered = feats[chunk_idx]
|
||||
|
||||
# Zero invalid neighbours in-place. gathered is a fresh tensor from
|
||||
# advanced indexing, so in-place mutation is safe.
|
||||
gathered.mul_(chunk_valid.unsqueeze(-1))
|
||||
|
||||
# Reshape to (chunk, V*Ci)
|
||||
gathered_flat = gathered.view(actual_chunk, V * Ci)
|
||||
if accumulate_f32:
|
||||
gathered_flat = gathered_flat.to(torch.float32)
|
||||
|
||||
# Single GEMM call per chunk, written directly into output.
|
||||
# This avoids allocating a temporary (chunk, Co) tensor.
|
||||
torch.matmul(gathered_flat, weight_T, out=output[start:end])
|
||||
|
||||
if accumulate_f32:
|
||||
output = output.to(feats.dtype)
|
||||
|
||||
if bias is not None:
|
||||
output = output + bias.unsqueeze(0).to(output.dtype)
|
||||
|
||||
return output, neighbor
|
||||
|
||||
class Mesh:
|
||||
def __init__(self,
|
||||
vertices,
|
||||
faces,
|
||||
vertex_attrs=None
|
||||
):
|
||||
self.vertices = vertices.float()
|
||||
self.faces = faces.int()
|
||||
self.vertex_attrs = vertex_attrs
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.vertices.device
|
||||
|
||||
def to(self, device, non_blocking=False):
|
||||
return Mesh(
|
||||
self.vertices.to(device, non_blocking=non_blocking),
|
||||
self.faces.to(device, non_blocking=non_blocking),
|
||||
self.vertex_attrs.to(device, non_blocking=non_blocking) if self.vertex_attrs is not None else None,
|
||||
)
|
||||
|
||||
def cuda(self, non_blocking=False):
|
||||
return self.to('cuda', non_blocking=non_blocking)
|
||||
|
||||
def cpu(self):
|
||||
return self.to('cpu')
|
||||
@ -1,935 +0,0 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor
|
||||
from typing import Optional, Tuple, Literal, Union, List
|
||||
from comfy.ldm.trellis2.attention import (
|
||||
sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention
|
||||
)
|
||||
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
|
||||
from comfy.ldm.flux.math import apply_rope, apply_rope1
|
||||
|
||||
class SparseGELU(nn.GELU):
|
||||
def forward(self, input: VarLenTensor) -> VarLenTensor:
|
||||
return input.replace(super().forward(input.feats))
|
||||
|
||||
class SparseFeedForwardNet(nn.Module):
|
||||
def __init__(self, channels: int, mlp_ratio: float = 4.0, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
SparseLinear(channels, int(channels * mlp_ratio), device=device, dtype=dtype, operations=operations),
|
||||
SparseGELU(approximate="tanh"),
|
||||
SparseLinear(int(channels * mlp_ratio), channels, device=device, dtype=dtype, operations=operations),
|
||||
)
|
||||
|
||||
def forward(self, x: VarLenTensor) -> VarLenTensor:
|
||||
return self.mlp(x)
|
||||
|
||||
def manual_cast(obj, dtype):
|
||||
return obj.to(dtype=dtype)
|
||||
|
||||
class LayerNorm32(nn.LayerNorm):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x_dtype = x.dtype
|
||||
x = manual_cast(x, torch.float32)
|
||||
o = super().forward(x)
|
||||
return manual_cast(o, x_dtype)
|
||||
|
||||
|
||||
class SparseMultiHeadRMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, heads: int, device, dtype):
|
||||
super().__init__()
|
||||
self.scale = dim ** 0.5
|
||||
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
|
||||
x_type = x.dtype
|
||||
x = x.float()
|
||||
if isinstance(x, VarLenTensor):
|
||||
x = x.replace(F.normalize(x.feats, dim=-1) * self.gamma * self.scale)
|
||||
else:
|
||||
x = F.normalize(x, dim=-1) * self.gamma * self.scale
|
||||
return x.to(x_type)
|
||||
|
||||
class SparseRotaryPositionEmbedder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
head_dim: int,
|
||||
dim: int = 3,
|
||||
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
||||
device=None
|
||||
):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.dim = dim
|
||||
self.rope_freq = rope_freq
|
||||
self.freq_dim = head_dim // 2 // dim
|
||||
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32, device=device) / self.freq_dim
|
||||
self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs))
|
||||
|
||||
def _get_freqs_cis(self, coords: torch.Tensor) -> torch.Tensor:
|
||||
phases_list = []
|
||||
for i in range(self.dim):
|
||||
phases_list.append(torch.outer(coords[..., i], self.freqs.to(coords.device)))
|
||||
|
||||
phases = torch.cat(phases_list, dim=-1)
|
||||
|
||||
if phases.shape[-1] < self.head_dim // 2:
|
||||
padn = self.head_dim // 2 - phases.shape[-1]
|
||||
phases = torch.cat([phases, torch.zeros(*phases.shape[:-1], padn, device=phases.device)], dim=-1)
|
||||
|
||||
cos = torch.cos(phases)
|
||||
sin = torch.sin(phases)
|
||||
|
||||
f_cis_0 = torch.stack([cos, sin], dim=-1)
|
||||
f_cis_1 = torch.stack([-sin, cos], dim=-1)
|
||||
freqs_cis = torch.stack([f_cis_0, f_cis_1], dim=-1)
|
||||
|
||||
return freqs_cis
|
||||
|
||||
def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
|
||||
self.freqs = self.freqs.to(indices.device)
|
||||
phases = torch.outer(indices, self.freqs)
|
||||
phases = torch.polar(torch.ones_like(phases), phases)
|
||||
return phases
|
||||
|
||||
def forward(self, q, k=None):
|
||||
cache_name = f'rope_cis_{self.dim}d_f{self.rope_freq[1]}_hd{self.head_dim}'
|
||||
freqs_cis = q.get_spatial_cache(cache_name)
|
||||
|
||||
if freqs_cis is None:
|
||||
coords = q.coords[..., 1:].to(torch.float32)
|
||||
freqs_cis = self._get_freqs_cis(coords)
|
||||
q.register_spatial_cache(cache_name, freqs_cis)
|
||||
|
||||
if q.feats.ndim == 3:
|
||||
f_cis = freqs_cis.unsqueeze(1)
|
||||
else:
|
||||
f_cis = freqs_cis
|
||||
|
||||
if k is None:
|
||||
return q.replace(apply_rope1(q.feats, f_cis))
|
||||
|
||||
q_feats, k_feats = apply_rope(q.feats, k.feats, f_cis)
|
||||
return q.replace(q_feats), k.replace(k_feats)
|
||||
|
||||
@staticmethod
|
||||
def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
|
||||
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||
x_rotated = x_complex * phases.unsqueeze(-2)
|
||||
x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
|
||||
return x_embed
|
||||
|
||||
class RotaryPositionEmbedder(SparseRotaryPositionEmbedder):
|
||||
def forward(self, indices: torch.Tensor) -> torch.Tensor:
|
||||
phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
|
||||
if torch.is_complex(phases):
|
||||
phases = phases.to(torch.complex64)
|
||||
else:
|
||||
phases = phases.to(torch.float32)
|
||||
if phases.shape[-1] < self.head_dim // 2:
|
||||
padn = self.head_dim // 2 - phases.shape[-1]
|
||||
phases = torch.cat([phases, torch.polar(
|
||||
torch.ones(*phases.shape[:-1], padn, device=phases.device, dtype=torch.float32),
|
||||
torch.zeros(*phases.shape[:-1], padn, device=phases.device, dtype=torch.float32)
|
||||
)], dim=-1)
|
||||
return phases
|
||||
|
||||
class SparseMultiHeadAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
num_heads: int,
|
||||
ctx_channels: Optional[int] = None,
|
||||
type: Literal["self", "cross"] = "self",
|
||||
attn_mode: Literal["full", "windowed", "double_windowed"] = "full",
|
||||
window_size: Optional[int] = None,
|
||||
shift_window: Optional[Tuple[int, int, int]] = None,
|
||||
qkv_bias: bool = True,
|
||||
use_rope: bool = False,
|
||||
rope_freq: Tuple[int, int] = (1.0, 10000.0),
|
||||
qk_rms_norm: bool = False,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.channels = channels
|
||||
self.head_dim = channels // num_heads
|
||||
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
|
||||
self.num_heads = num_heads
|
||||
self._type = type
|
||||
self.attn_mode = attn_mode
|
||||
self.window_size = window_size
|
||||
self.shift_window = shift_window
|
||||
self.use_rope = use_rope
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
|
||||
if self._type == "self":
|
||||
self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, device=device, dtype=dtype)
|
||||
else:
|
||||
self.to_q = operations.Linear(channels, channels, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype)
|
||||
|
||||
if self.qk_rms_norm:
|
||||
self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
||||
self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
||||
|
||||
self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype)
|
||||
|
||||
if use_rope:
|
||||
self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq, device=device)
|
||||
|
||||
@staticmethod
|
||||
def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
|
||||
if isinstance(x, VarLenTensor):
|
||||
return x.replace(module(x.feats))
|
||||
else:
|
||||
return module(x)
|
||||
|
||||
@staticmethod
|
||||
def _reshape_chs(x: Union[VarLenTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[VarLenTensor, torch.Tensor]:
|
||||
if isinstance(x, VarLenTensor):
|
||||
return x.reshape(*shape)
|
||||
else:
|
||||
return x.reshape(*x.shape[:2], *shape)
|
||||
|
||||
def _fused_pre(self, x: Union[VarLenTensor, torch.Tensor], num_fused: int) -> Union[VarLenTensor, torch.Tensor]:
|
||||
if isinstance(x, VarLenTensor):
|
||||
x_feats = x.feats.unsqueeze(0)
|
||||
else:
|
||||
x_feats = x
|
||||
x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
|
||||
return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats
|
||||
|
||||
def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor:
|
||||
if self._type == "self":
|
||||
dtype = next(self.to_qkv.parameters()).dtype
|
||||
x = x.to(dtype)
|
||||
qkv = self._linear(self.to_qkv, x)
|
||||
qkv = self._fused_pre(qkv, num_fused=3)
|
||||
if self.qk_rms_norm or self.use_rope:
|
||||
q, k, v = qkv.unbind(dim=-3)
|
||||
if self.qk_rms_norm:
|
||||
q = self.q_rms_norm(q)
|
||||
k = self.k_rms_norm(k)
|
||||
if self.use_rope:
|
||||
q, k = self.rope(q, k)
|
||||
qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
|
||||
if self.attn_mode == "full":
|
||||
h = sparse_scaled_dot_product_attention(qkv)
|
||||
elif self.attn_mode == "windowed":
|
||||
h = sparse_windowed_scaled_dot_product_self_attention(
|
||||
qkv, self.window_size, shift_window=self.shift_window
|
||||
)
|
||||
elif self.attn_mode == "double_windowed":
|
||||
qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:])
|
||||
qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2])
|
||||
h0 = sparse_windowed_scaled_dot_product_self_attention(
|
||||
qkv0, self.window_size, shift_window=(0, 0, 0)
|
||||
)
|
||||
h1 = sparse_windowed_scaled_dot_product_self_attention(
|
||||
qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3)
|
||||
)
|
||||
h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1))
|
||||
else:
|
||||
q = self._linear(self.to_q, x)
|
||||
q = self._reshape_chs(q, (self.num_heads, -1))
|
||||
dtype = next(self.to_kv.parameters()).dtype
|
||||
context = context.to(dtype)
|
||||
kv = self._linear(self.to_kv, context)
|
||||
kv = self._fused_pre(kv, num_fused=2)
|
||||
if self.qk_rms_norm:
|
||||
q = self.q_rms_norm(q)
|
||||
k, v = kv.unbind(dim=-3)
|
||||
k = self.k_rms_norm(k)
|
||||
h = sparse_scaled_dot_product_attention(q, k, v)
|
||||
else:
|
||||
h = sparse_scaled_dot_product_attention(q, kv)
|
||||
h = self._reshape_chs(h, (-1,))
|
||||
h = self._linear(self.to_out, h)
|
||||
return h
|
||||
|
||||
class ModulatedSparseTransformerCrossBlock(nn.Module):
|
||||
"""
|
||||
Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
ctx_channels: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
attn_mode: Literal["full", "swin"] = "full",
|
||||
window_size: Optional[int] = None,
|
||||
shift_window: Optional[Tuple[int, int, int]] = None,
|
||||
use_checkpoint: bool = False,
|
||||
use_rope: bool = False,
|
||||
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
||||
qk_rms_norm: bool = False,
|
||||
qk_rms_norm_cross: bool = False,
|
||||
qkv_bias: bool = True,
|
||||
share_mod: bool = False,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.share_mod = share_mod
|
||||
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
||||
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device)
|
||||
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
||||
self.self_attn = SparseMultiHeadAttention(
|
||||
channels,
|
||||
num_heads=num_heads,
|
||||
type="self",
|
||||
attn_mode=attn_mode,
|
||||
window_size=window_size,
|
||||
shift_window=shift_window,
|
||||
qkv_bias=qkv_bias,
|
||||
use_rope=use_rope,
|
||||
rope_freq=rope_freq,
|
||||
qk_rms_norm=qk_rms_norm,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
self.cross_attn = SparseMultiHeadAttention(
|
||||
channels,
|
||||
ctx_channels=ctx_channels,
|
||||
num_heads=num_heads,
|
||||
type="cross",
|
||||
attn_mode="full",
|
||||
qkv_bias=qkv_bias,
|
||||
qk_rms_norm=qk_rms_norm_cross,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
self.mlp = SparseFeedForwardNet(
|
||||
channels,
|
||||
mlp_ratio=mlp_ratio,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
if not share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(channels, 6 * channels, bias=True, device=device, dtype=dtype)
|
||||
)
|
||||
else:
|
||||
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
|
||||
|
||||
def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
|
||||
if self.share_mod:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
|
||||
else:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
||||
h = x.replace(self.norm1(x.feats))
|
||||
h = h * (1 + scale_msa) + shift_msa
|
||||
h = self.self_attn(h)
|
||||
h = h * gate_msa
|
||||
x = x + h
|
||||
h = x.replace(self.norm2(x.feats))
|
||||
h = self.cross_attn(h, context)
|
||||
x = x + h
|
||||
h = x.replace(self.norm3(x.feats))
|
||||
h = h * (1 + scale_mlp) + shift_mlp
|
||||
h = self.mlp(h)
|
||||
h = h * gate_mlp
|
||||
x = x + h
|
||||
return x
|
||||
|
||||
def forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
|
||||
return self._forward(x, mod, context)
|
||||
|
||||
|
||||
class SLatFlowModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
resolution: int,
|
||||
in_channels: int,
|
||||
model_channels: int,
|
||||
cond_channels: int,
|
||||
out_channels: int,
|
||||
num_blocks: int,
|
||||
num_heads: Optional[int] = None,
|
||||
num_head_channels: Optional[int] = 64,
|
||||
mlp_ratio: float = 4,
|
||||
pe_mode: Literal["ape", "rope"] = "rope",
|
||||
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
||||
use_checkpoint: bool = False,
|
||||
share_mod: bool = False,
|
||||
initialization: str = 'vanilla',
|
||||
qk_rms_norm: bool = False,
|
||||
qk_rms_norm_cross: bool = False,
|
||||
dtype = None,
|
||||
device = None,
|
||||
operations = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.num_heads = num_heads or model_channels // num_head_channels
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.pe_mode = pe_mode
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.share_mod = share_mod
|
||||
self.initialization = initialization
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
self.qk_rms_norm_cross = qk_rms_norm_cross
|
||||
self.dtype = dtype
|
||||
|
||||
self.t_embedder = TimestepEmbedder(model_channels, device=device, dtype=dtype, operations=operations)
|
||||
if share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(model_channels, 6 * model_channels, bias=True, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
self.input_layer = SparseLinear(in_channels, model_channels, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
ModulatedSparseTransformerCrossBlock(
|
||||
model_channels,
|
||||
cond_channels,
|
||||
num_heads=self.num_heads,
|
||||
mlp_ratio=self.mlp_ratio,
|
||||
attn_mode='full',
|
||||
use_checkpoint=self.use_checkpoint,
|
||||
use_rope=(pe_mode == "rope"),
|
||||
rope_freq=rope_freq,
|
||||
share_mod=self.share_mod,
|
||||
qk_rms_norm=self.qk_rms_norm,
|
||||
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
for _ in range(num_blocks)
|
||||
])
|
||||
|
||||
self.out_layer = SparseLinear(model_channels, out_channels, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: SparseTensor,
|
||||
t: torch.Tensor,
|
||||
cond: Union[torch.Tensor, List[torch.Tensor]],
|
||||
concat_cond: Optional[SparseTensor] = None,
|
||||
**kwargs
|
||||
) -> SparseTensor:
|
||||
if concat_cond is not None:
|
||||
x = sparse_cat([x, concat_cond], dim=-1)
|
||||
if isinstance(cond, list):
|
||||
cond = VarLenTensor.from_tensor_list(cond)
|
||||
|
||||
dtype = next(self.input_layer.parameters()).dtype
|
||||
x = x.to(dtype)
|
||||
h = self.input_layer(x)
|
||||
h = manual_cast(h, self.dtype)
|
||||
t = t.to(dtype)
|
||||
t_embedder = self.t_embedder.to(dtype)
|
||||
t_emb = t_embedder(t, out_dtype = t.dtype)
|
||||
if self.share_mod:
|
||||
t_emb = self.adaLN_modulation(t_emb)
|
||||
t_emb = manual_cast(t_emb, self.dtype)
|
||||
cond = manual_cast(cond, self.dtype)
|
||||
|
||||
for block in self.blocks:
|
||||
h = block(h, t_emb, cond)
|
||||
|
||||
h = manual_cast(h, x.dtype)
|
||||
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
|
||||
h = self.out_layer(h)
|
||||
return h
|
||||
|
||||
class FeedForwardNet(nn.Module):
|
||||
def __init__(self, channels: int, mlp_ratio: float = 4.0, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(channels, int(channels * mlp_ratio), device=device, dtype=dtype),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(int(channels * mlp_ratio), channels, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.mlp(x)
|
||||
|
||||
class MultiHeadRMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, heads: int, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.scale = dim ** 0.5
|
||||
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
num_heads: int,
|
||||
ctx_channels: Optional[int]=None,
|
||||
type: Literal["self", "cross"] = "self",
|
||||
attn_mode: Literal["full", "windowed"] = "full",
|
||||
window_size: Optional[int] = None,
|
||||
shift_window: Optional[Tuple[int, int, int]] = None,
|
||||
qkv_bias: bool = True,
|
||||
use_rope: bool = False,
|
||||
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
||||
qk_rms_norm: bool = False,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.channels = channels
|
||||
self.head_dim = channels // num_heads
|
||||
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
|
||||
self.num_heads = num_heads
|
||||
self._type = type
|
||||
self.attn_mode = attn_mode
|
||||
self.window_size = window_size
|
||||
self.shift_window = shift_window
|
||||
self.use_rope = use_rope
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
|
||||
if self._type == "self":
|
||||
self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
else:
|
||||
self.to_q = operations.Linear(channels, channels, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype)
|
||||
|
||||
if self.qk_rms_norm:
|
||||
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
||||
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
||||
|
||||
self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
B, L, C = x.shape
|
||||
if self._type == "self":
|
||||
x = x.to(next(self.to_qkv.parameters()).dtype)
|
||||
qkv = self.to_qkv(x)
|
||||
qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
|
||||
|
||||
if self.attn_mode == "full":
|
||||
if self.qk_rms_norm or self.use_rope:
|
||||
q, k, v = qkv.unbind(dim=2)
|
||||
if self.qk_rms_norm:
|
||||
q = self.q_rms_norm(q)
|
||||
k = self.k_rms_norm(k)
|
||||
if self.use_rope:
|
||||
assert phases is not None, "Phases must be provided for RoPE"
|
||||
q = RotaryPositionEmbedder.apply_rotary_embedding(q, phases)
|
||||
k = RotaryPositionEmbedder.apply_rotary_embedding(k, phases)
|
||||
h = scaled_dot_product_attention(q, k, v)
|
||||
else:
|
||||
h = scaled_dot_product_attention(qkv)
|
||||
else:
|
||||
Lkv = context.shape[1]
|
||||
q = self.to_q(x)
|
||||
context = context.to(next(self.to_kv.parameters()).dtype)
|
||||
kv = self.to_kv(context)
|
||||
q = q.reshape(B, L, self.num_heads, -1)
|
||||
kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
|
||||
if self.qk_rms_norm:
|
||||
q = self.q_rms_norm(q)
|
||||
k, v = kv.unbind(dim=2)
|
||||
k = self.k_rms_norm(k)
|
||||
h = scaled_dot_product_attention(q, k, v)
|
||||
else:
|
||||
h = scaled_dot_product_attention(q, kv)
|
||||
h = h.reshape(B, L, -1)
|
||||
h = self.to_out(h)
|
||||
return h
|
||||
|
||||
class ModulatedTransformerCrossBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
ctx_channels: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
attn_mode: Literal["full", "windowed"] = "full",
|
||||
window_size: Optional[int] = None,
|
||||
shift_window: Optional[Tuple[int, int, int]] = None,
|
||||
use_checkpoint: bool = False,
|
||||
use_rope: bool = False,
|
||||
rope_freq: Tuple[int, int] = (1.0, 10000.0),
|
||||
qk_rms_norm: bool = False,
|
||||
qk_rms_norm_cross: bool = False,
|
||||
qkv_bias: bool = True,
|
||||
share_mod: bool = False,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.share_mod = share_mod
|
||||
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
||||
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device)
|
||||
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
||||
self.self_attn = MultiHeadAttention(
|
||||
channels,
|
||||
num_heads=num_heads,
|
||||
type="self",
|
||||
attn_mode=attn_mode,
|
||||
window_size=window_size,
|
||||
shift_window=shift_window,
|
||||
qkv_bias=qkv_bias,
|
||||
use_rope=use_rope,
|
||||
rope_freq=rope_freq,
|
||||
qk_rms_norm=qk_rms_norm,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
self.cross_attn = MultiHeadAttention(
|
||||
channels,
|
||||
ctx_channels=ctx_channels,
|
||||
num_heads=num_heads,
|
||||
type="cross",
|
||||
attn_mode="full",
|
||||
qkv_bias=qkv_bias,
|
||||
qk_rms_norm=qk_rms_norm_cross,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
self.mlp = FeedForwardNet(
|
||||
channels,
|
||||
mlp_ratio=mlp_ratio,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
if not share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
else:
|
||||
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
|
||||
|
||||
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if self.share_mod:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
|
||||
else:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
||||
h = self.norm1(x)
|
||||
h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
|
||||
h = self.self_attn(h, phases=phases)
|
||||
h = h * gate_msa.unsqueeze(1)
|
||||
x = x + h
|
||||
h = self.norm2(x)
|
||||
h = self.cross_attn(h, context)
|
||||
x = x + h
|
||||
h = self.norm3(x)
|
||||
h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
|
||||
h = self.mlp(h)
|
||||
h = h * gate_mlp.unsqueeze(1)
|
||||
x = x + h
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return self._forward(x, mod, context, phases)
|
||||
|
||||
|
||||
class SparseStructureFlowModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
resolution: int,
|
||||
in_channels: int,
|
||||
model_channels: int,
|
||||
cond_channels: int,
|
||||
out_channels: int,
|
||||
num_blocks: int,
|
||||
num_heads: Optional[int] = None,
|
||||
num_head_channels: Optional[int] = 64,
|
||||
mlp_ratio: float = 4,
|
||||
pe_mode: Literal["ape", "rope"] = "rope",
|
||||
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
||||
use_checkpoint: bool = False,
|
||||
share_mod: bool = False,
|
||||
initialization: str = 'vanilla',
|
||||
qk_rms_norm: bool = False,
|
||||
qk_rms_norm_cross: bool = False,
|
||||
operations=None,
|
||||
device = None,
|
||||
dtype = torch.float32,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.num_heads = num_heads or model_channels // num_head_channels
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.pe_mode = pe_mode
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.share_mod = share_mod
|
||||
self.initialization = initialization
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
self.qk_rms_norm_cross = qk_rms_norm_cross
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
self.t_embedder = TimestepEmbedder(model_channels, dtype=dtype, device=device, operations=operations)
|
||||
if share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(model_channels, 6 * model_channels, bias=True, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3, device=device)
|
||||
coords = torch.meshgrid(*[torch.arange(res, device=self.device, dtype=dtype) for res in [resolution] * 3], indexing='ij')
|
||||
coords = torch.stack(coords, dim=-1).reshape(-1, 3)
|
||||
rope_phases = pos_embedder(coords)
|
||||
self.register_buffer("rope_phases", rope_phases, persistent=False)
|
||||
|
||||
if pe_mode != "rope":
|
||||
self.rope_phases = None
|
||||
|
||||
self.input_layer = operations.Linear(in_channels, model_channels, device=device, dtype=dtype)
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
ModulatedTransformerCrossBlock(
|
||||
model_channels,
|
||||
cond_channels,
|
||||
num_heads=self.num_heads,
|
||||
mlp_ratio=self.mlp_ratio,
|
||||
attn_mode='full',
|
||||
use_checkpoint=self.use_checkpoint,
|
||||
use_rope=(pe_mode == "rope"),
|
||||
rope_freq=rope_freq,
|
||||
share_mod=share_mod,
|
||||
qk_rms_norm=self.qk_rms_norm,
|
||||
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
for _ in range(num_blocks)
|
||||
])
|
||||
|
||||
self.out_layer = operations.Linear(model_channels, out_channels, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
|
||||
x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3)
|
||||
|
||||
h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous()
|
||||
|
||||
h = h.to(next(self.input_layer.parameters()).dtype)
|
||||
h = self.input_layer(h)
|
||||
t_emb = self.t_embedder(t, out_dtype = t.dtype)
|
||||
if self.share_mod:
|
||||
t_emb = self.adaLN_modulation(t_emb)
|
||||
t_emb = manual_cast(t_emb, self.dtype)
|
||||
h = manual_cast(h, self.dtype)
|
||||
cond = manual_cast(cond, self.dtype)
|
||||
for block in self.blocks:
|
||||
h = block(h, t_emb, cond, self.rope_phases)
|
||||
h = manual_cast(h, x.dtype)
|
||||
h = F.layer_norm(h, h.shape[-1:])
|
||||
h = h.to(next(self.out_layer.parameters()).dtype)
|
||||
h = self.out_layer(h)
|
||||
|
||||
h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous()
|
||||
|
||||
return h
|
||||
|
||||
def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0):
|
||||
t_shifted = t_shifted / 1000.0
|
||||
t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1))
|
||||
t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear)
|
||||
t_new *= 1000.0
|
||||
return t_new
|
||||
|
||||
class Trellis2(nn.Module):
|
||||
def __init__(self, resolution,
|
||||
in_channels = 32,
|
||||
out_channels = 32,
|
||||
model_channels = 1536,
|
||||
cond_channels = 1024,
|
||||
num_blocks = 30,
|
||||
num_heads = 12,
|
||||
mlp_ratio = 5.3334,
|
||||
share_mod = True,
|
||||
qk_rms_norm = True,
|
||||
qk_rms_norm_cross = True,
|
||||
init_txt_model=False, # for now
|
||||
dtype=None, device=None, operations=None, **kwargs):
|
||||
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
operations = operations or nn
|
||||
# for some reason it passes num_heads = -1
|
||||
if num_heads == -1:
|
||||
num_heads = 12
|
||||
args = {
|
||||
"out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels,
|
||||
"model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,
|
||||
"qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations
|
||||
}
|
||||
txt_only = kwargs.get("txt_only", False)
|
||||
if not txt_only:
|
||||
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args)
|
||||
self.shape2txt = None
|
||||
if init_txt_model:
|
||||
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
|
||||
self.img2shape_512 = SLatFlowModel(resolution=32, in_channels=in_channels, **args)
|
||||
args.pop("out_channels")
|
||||
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
|
||||
else:
|
||||
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
|
||||
self.guidance_interval = [0.6, 1.0]
|
||||
self.guidance_interval_txt = [0.6, 0.9]
|
||||
|
||||
def forward(self, x, timestep, context, **kwargs):
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
model_options = {}
|
||||
if hasattr(self, "meta"):
|
||||
model_options = self.meta
|
||||
timestep = timestep.to(x.dtype)
|
||||
embeds = kwargs.get("embeds")
|
||||
if embeds is None:
|
||||
raise ValueError("Trellis2.forward requires 'embeds' in kwargs")
|
||||
|
||||
is_1024 = True#self.img2shape.resolution == 1024
|
||||
coords = model_options.get("coords", None)
|
||||
coord_counts = model_options.get("coord_counts", None)
|
||||
mode = model_options.get("generation_mode", "structure_generation")
|
||||
|
||||
is_512_run = False
|
||||
if mode == "shape_generation_512":
|
||||
is_512_run = True
|
||||
mode = "shape_generation"
|
||||
|
||||
if coords is not None:
|
||||
if x.ndim == 4:
|
||||
x = x.squeeze(-1).transpose(1, 2)
|
||||
not_struct_mode = True
|
||||
else:
|
||||
mode = "structure_generation"
|
||||
not_struct_mode = False
|
||||
|
||||
if x.size(-1) == 16 and x.size(-2) == 16:
|
||||
mode = "structure_generation"
|
||||
not_struct_mode = False
|
||||
|
||||
if not not_struct_mode:
|
||||
bsz = x.size(0)
|
||||
x = x[:, :8]
|
||||
x = x.view(bsz, 8, 16, 16, 16)
|
||||
|
||||
if is_1024 and not_struct_mode and not is_512_run:
|
||||
context = embeds
|
||||
|
||||
sigmas = transformer_options.get("sigmas")[0].item()
|
||||
if sigmas < 1.00001:
|
||||
timestep *= 1000.0
|
||||
|
||||
if context.size(0) > 1:
|
||||
cond = context.chunk(2)[1]
|
||||
else:
|
||||
cond = context
|
||||
|
||||
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
|
||||
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
|
||||
|
||||
if not_struct_mode:
|
||||
orig_bsz = x.shape[0]
|
||||
rule = txt_rule if mode == "texture_generation" else shape_rule
|
||||
|
||||
# CFG Bypass Slicing
|
||||
if rule and orig_bsz > 1:
|
||||
half = orig_bsz // 2
|
||||
x_eval = x[half:]
|
||||
t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep
|
||||
c_eval = cond
|
||||
else:
|
||||
x_eval = x
|
||||
t_eval = timestep
|
||||
c_eval = context
|
||||
|
||||
B, N, C = x_eval.shape
|
||||
|
||||
# Vectorized SparseTensor Construction
|
||||
if mode in ["shape_generation", "texture_generation"]:
|
||||
if coord_counts is not None:
|
||||
logical_batch = coord_counts.shape[0]
|
||||
# Duplicate coords if CFG is active
|
||||
if B > logical_batch:
|
||||
c_pos = coords.clone()
|
||||
c_pos[:, 0] += logical_batch
|
||||
batched_coords = torch.cat([coords, c_pos], dim=0)
|
||||
counts_eval = torch.cat([coord_counts, coord_counts], dim=0)
|
||||
else:
|
||||
batched_coords = coords
|
||||
counts_eval = coord_counts
|
||||
|
||||
# Create boolean mask [B, N] to drop the padded zeros instantly
|
||||
mask = torch.arange(N, device=x.device).unsqueeze(0) < counts_eval.unsqueeze(1)
|
||||
feats_flat = x_eval[mask]
|
||||
else:
|
||||
feats_flat = x_eval.reshape(-1, C)
|
||||
coords_list =[]
|
||||
for i in range(B):
|
||||
c = coords.clone()
|
||||
c[:, 0] = i
|
||||
coords_list.append(c)
|
||||
batched_coords = torch.cat(coords_list, dim=0)
|
||||
mask = None
|
||||
else:
|
||||
batched_coords = coords
|
||||
feats_flat = x_eval
|
||||
mask = None
|
||||
|
||||
x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32))
|
||||
|
||||
if mode == "shape_generation":
|
||||
if is_512_run:
|
||||
out = self.img2shape_512(x_st, t_eval, c_eval)
|
||||
else:
|
||||
out = self.img2shape(x_st, t_eval, c_eval)
|
||||
|
||||
elif mode == "texture_generation":
|
||||
if self.shape2txt is None:
|
||||
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
|
||||
slat = model_options.get("shape_slat")
|
||||
if slat is None:
|
||||
raise ValueError("shape_slat can't be None")
|
||||
|
||||
slat_feats = slat
|
||||
# Duplicate shape context if CFG is active
|
||||
if coord_counts is not None and B > coord_counts.shape[0]:
|
||||
slat_feats = torch.cat([slat_feats, slat_feats], dim=0)
|
||||
elif coord_counts is None:
|
||||
slat_feats = slat_feats[:N].repeat(B, 1)
|
||||
|
||||
x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats.to(x_st.feats.device)], dim=-1))
|
||||
out = self.shape2txt(x_st, t_eval, c_eval)
|
||||
|
||||
else: # structure
|
||||
orig_bsz = x.shape[0]
|
||||
if shape_rule and orig_bsz > 1:
|
||||
half = orig_bsz // 2
|
||||
x_eval = x[half:]
|
||||
t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep
|
||||
out = self.structure_model(x_eval, t_eval, cond)
|
||||
out = out.repeat(2, 1, 1, 1, 1)
|
||||
else:
|
||||
out = self.structure_model(x, timestep, context)
|
||||
|
||||
if not_struct_mode:
|
||||
if mask is not None:
|
||||
# Instantly scatter the valid tokens back into a padded rectangular tensor
|
||||
padded_out = torch.zeros((B, N, out.feats.shape[-1]), device=x.device, dtype=out.feats.dtype)
|
||||
padded_out[mask] = out.feats
|
||||
out_tensor = padded_out.transpose(1, 2).unsqueeze(-1)
|
||||
else:
|
||||
out_tensor = out.feats.view(B, N, -1).transpose(1, 2).unsqueeze(-1)
|
||||
|
||||
if rule and orig_bsz > 1:
|
||||
out_tensor = out_tensor.repeat(2, 1, 1, 1)
|
||||
return out_tensor
|
||||
else:
|
||||
out = torch.nn.functional.pad(out, (0, 0, 0, 0, 0, 0, 0, 24))
|
||||
|
||||
return out
|
||||
File diff suppressed because it is too large
Load Diff
@ -53,7 +53,6 @@ import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
import comfy.ldm.anima.model
|
||||
import comfy.ldm.trellis2.model
|
||||
import comfy.ldm.ace.ace_step15
|
||||
import comfy.ldm.cogvideo.model
|
||||
import comfy.ldm.rt_detr.rtdetr_v4
|
||||
@ -1638,16 +1637,6 @@ class WAN22(WAN21):
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
return latent_image
|
||||
|
||||
class Trellis2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None, unet_model=comfy.ldm.trellis2.model.Trellis2):
|
||||
super().__init__(model_config, model_type, device, unet_model)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
embeds = kwargs.get("embeds")
|
||||
out["embeds"] = comfy.conds.CONDRegular(embeds)
|
||||
return out
|
||||
|
||||
class WAN21_FlowRVS(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
|
||||
model_config.unet_config["model_type"] = "t2v"
|
||||
@ -1689,6 +1678,7 @@ class WAN21_SCAIL(WAN21):
|
||||
pose_latents = kwargs.get("pose_video_latent", None)
|
||||
if pose_latents is not None:
|
||||
out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]]
|
||||
|
||||
return out
|
||||
|
||||
class WAN22_WanDancer(WAN21):
|
||||
|
||||
@ -113,30 +113,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
|
||||
return unet_config
|
||||
|
||||
if '{}img2shape.blocks.1.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
|
||||
unet_config = {}
|
||||
unet_config["image_model"] = "trellis2"
|
||||
|
||||
unet_config["init_txt_model"] = False
|
||||
if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
|
||||
unet_config["init_txt_model"] = True
|
||||
|
||||
unet_config["resolution"] = 64
|
||||
if metadata is not None:
|
||||
if "is_512" in metadata:
|
||||
unet_config["resolution"] = 32
|
||||
|
||||
unet_config["num_heads"] = 12
|
||||
return unet_config
|
||||
|
||||
if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: # trellis2 texture
|
||||
unet_config = {}
|
||||
unet_config["image_model"] = "trellis2"
|
||||
unet_config["resolution"] = 64
|
||||
unet_config["num_heads"] = 12
|
||||
unet_config["txt_only"] = True
|
||||
return unet_config
|
||||
|
||||
if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
|
||||
unet_config = {}
|
||||
unet_config["audio_model"] = "dit1.0"
|
||||
|
||||
13
comfy/sd.py
13
comfy/sd.py
@ -15,7 +15,6 @@ import comfy.ldm.lightricks.vae.causal_video_autoencoder
|
||||
import comfy.ldm.lightricks.vae.audio_vae
|
||||
import comfy.ldm.cosmos.vae
|
||||
import comfy.ldm.wan.vae
|
||||
import comfy.ldm.trellis2.vae
|
||||
import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
@ -529,18 +528,6 @@ class VAE:
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.downscale_ratio = 32
|
||||
self.latent_channels = 16
|
||||
elif "shape_dec.blocks.1.16.to_subdiv.weight" in sd or "txt_dec.blocks.3.4.conv2.weight" in sd: # trellis2 or trellis2 texture only
|
||||
init_txt_model = False
|
||||
init_txt_model_only = False
|
||||
if "shape_dec.blocks.1.16.to_subdiv.weight" not in sd:
|
||||
init_txt_model_only = True
|
||||
if "txt_dec.blocks.1.16.norm1.weight" in sd:
|
||||
init_txt_model = True
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
# TODO
|
||||
self.memory_used_decode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model, init_txt_model_only= init_txt_model_only)
|
||||
elif "decoder.conv_in.weight" in sd:
|
||||
if sd['decoder.conv_in.weight'].shape[1] == 64:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
|
||||
@ -1318,29 +1318,6 @@ class WAN22_T2V(WAN21_T2V):
|
||||
out = model_base.WAN22(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
class Trellis2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "trellis2"
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
memory_usage_factor = 3.5
|
||||
|
||||
latent_format = latent_formats.Trellis2
|
||||
vae_key_prefix = ["vae."]
|
||||
clip_vision_prefix = "conditioner.main_image_encoder.model."
|
||||
# this is only needed for the texture model
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.Trellis2(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
class WAN21_FlowRVS(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@ -1807,7 +1784,6 @@ class Kandinsky5Image(Kandinsky5):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||
|
||||
|
||||
|
||||
class ACEStep15(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"audio_model": "ace1.5",
|
||||
@ -1847,6 +1823,7 @@ class ACEStep15(supported_models_base.BASE):
|
||||
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
|
||||
|
||||
|
||||
class LongCatImage(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "flux",
|
||||
@ -1924,7 +1901,6 @@ class ErnieImage(supported_models_base.BASE):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect))
|
||||
|
||||
|
||||
|
||||
class SAM3(supported_models_base.BASE):
|
||||
unet_config = {"image_model": "SAM3"}
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
@ -2044,6 +2020,7 @@ class CogVideoX_Inpaint(CogVideoX_T2V):
|
||||
out = model_base.CogVideoX(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
|
||||
models = [
|
||||
LotusD,
|
||||
Stable_Zero123,
|
||||
@ -2130,5 +2107,4 @@ models = [
|
||||
CogVideoX_I2V,
|
||||
CogVideoX_T2V,
|
||||
SVD_img2vid,
|
||||
Trellis2
|
||||
]
|
||||
|
||||
@ -7,10 +7,9 @@ import torch
|
||||
|
||||
|
||||
class VOXEL:
|
||||
def __init__(self, data: torch.Tensor, voxel_colors=None, resolution=None):
|
||||
def __init__(self, data: torch.Tensor):
|
||||
self.data = data
|
||||
self.voxel_colors = voxel_colors
|
||||
self.resolution = resolution # each 3d model has its own resolution
|
||||
|
||||
|
||||
class MESH:
|
||||
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor,
|
||||
|
||||
@ -543,7 +543,7 @@ class AudioConcat(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="AudioConcat",
|
||||
search_aliases=["join audio", "combine audio", "append audio"],
|
||||
display_name="Concatenate Audio",
|
||||
display_name="Audio Concat",
|
||||
description="Concatenates the audio1 to audio2 in the specified direction.",
|
||||
category="audio",
|
||||
inputs=[
|
||||
@ -597,7 +597,7 @@ class AudioMerge(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="AudioMerge",
|
||||
search_aliases=["mix audio", "overlay audio", "layer audio"],
|
||||
display_name="Merge Audio",
|
||||
display_name="Audio Merge",
|
||||
description="Combine two audio tracks by overlaying their waveforms.",
|
||||
category="audio",
|
||||
inputs=[
|
||||
@ -667,9 +667,8 @@ class AudioAdjustVolume(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="AudioAdjustVolume",
|
||||
search_aliases=["audio gain", "loudness", "audio level"],
|
||||
display_name="Adjust Audio Volume",
|
||||
display_name="Audio Adjust Volume",
|
||||
category="audio",
|
||||
description="Adjust the volume of the audio by a specified amount in decibels (dB).",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.Int.Input(
|
||||
|
||||
@ -47,10 +47,8 @@ class LoadImageDataSetFromFolderNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadImageDataSetFromFolder",
|
||||
search_aliases=["load folder", "load from folder", "load dataset", "load images", "import dataset"],
|
||||
display_name="Load Image (from Folder)",
|
||||
category="image",
|
||||
description="Load a dataset of images from a specified folder and return a list of images. Supported formats: PNG, JPG, JPEG, WEBP.",
|
||||
display_name="Load Image Dataset from Folder",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
@ -86,16 +84,14 @@ class LoadImageTextDataSetFromFolderNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadImageTextDataSetFromFolder",
|
||||
search_aliases=["load folder", "load from folder", "load dataset", "load images", "import dataset"],
|
||||
display_name="Load Image-Text (from Folder)",
|
||||
category="image",
|
||||
description="Load a dataset of pairs of images and text captions from a specified folder and return them as a list. Supported formats: PNG, JPG, JPEG, WEBP.",
|
||||
display_name="Load Image and Text Dataset from Folder",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"folder",
|
||||
options=folder_paths.get_input_subfolders(),
|
||||
tooltip="The folder to load images and text captions from.",
|
||||
tooltip="The folder to load images from.",
|
||||
)
|
||||
],
|
||||
outputs=[
|
||||
@ -210,10 +206,8 @@ class SaveImageDataSetToFolderNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveImageDataSetToFolder",
|
||||
search_aliases=["save folder", "save to folder", "save dataset", "save images", "export dataset"],
|
||||
display_name="Save Image (to Folder) (DEPRECATED)",
|
||||
category="image",
|
||||
description="Save a dataset of images to a specified folder. Supported formats: PNG.",
|
||||
display_name="Save Image Dataset to Folder",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
is_input_list=True, # Receive images as list
|
||||
@ -232,7 +226,6 @@ class SaveImageDataSetToFolderNode(io.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[],
|
||||
is_deprecated=True, # This node is redundant and superseded by existing Save Image nodes where the target folder can be specified in the filename_prefix
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -253,20 +246,14 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveImageTextDataSetToFolder",
|
||||
search_aliases=["save folder", "save to folder", "save dataset", "save images", "save text", "export dataset"],
|
||||
display_name="Save Image-Text (to Folder)",
|
||||
category="image",
|
||||
description="Save a dataset of pairs of images and text captions to a specified folder. Images are saved as PNG files and captions are saved as TXT files with the same filename_prefix.",
|
||||
display_name="Save Image and Text Dataset to Folder",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
is_input_list=True, # Receive both images and texts as lists
|
||||
inputs=[
|
||||
io.Image.Input("images", tooltip="List of images to save."),
|
||||
io.String.Input("texts",
|
||||
optional=True,
|
||||
force_input=True,
|
||||
tooltip="List of text captions to save."
|
||||
),
|
||||
io.String.Input("texts", tooltip="List of text captions to save."),
|
||||
io.String.Input(
|
||||
"folder_name",
|
||||
default="dataset",
|
||||
@ -283,7 +270,7 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, folder_name, filename_prefix, texts=None):
|
||||
def execute(cls, images, texts, folder_name, filename_prefix):
|
||||
# Extract scalar values
|
||||
folder_name = folder_name[0]
|
||||
filename_prefix = filename_prefix[0]
|
||||
@ -292,12 +279,11 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
|
||||
saved_files = save_images_to_folder(images, output_dir, filename_prefix)
|
||||
|
||||
# Save captions
|
||||
if texts:
|
||||
for idx, (filename, caption) in enumerate(zip(saved_files, texts)):
|
||||
caption_filename = filename.replace(".png", ".txt")
|
||||
caption_path = os.path.join(output_dir, caption_filename)
|
||||
with open(caption_path, "w", encoding="utf-8") as f:
|
||||
f.write(caption)
|
||||
for idx, (filename, caption) in enumerate(zip(saved_files, texts)):
|
||||
caption_filename = filename.replace(".png", ".txt")
|
||||
caption_path = os.path.join(output_dir, caption_filename)
|
||||
with open(caption_path, "w", encoding="utf-8") as f:
|
||||
f.write(caption)
|
||||
|
||||
logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.")
|
||||
return io.NodeOutput()
|
||||
@ -328,13 +314,11 @@ class ImageProcessingNode(io.ComfyNode):
|
||||
|
||||
Child classes should set:
|
||||
node_id: Unique node identifier (required)
|
||||
search_aliases: List of search aliases (optional)
|
||||
display_name: Display name (optional, defaults to node_id)
|
||||
description: Node description (optional)
|
||||
extra_inputs: List of additional io.Input objects beyond "images" (optional)
|
||||
is_group_process: None (auto-detect), True (group), or False (individual) (optional)
|
||||
is_output_list: True (list output) or False (single output) (optional, default True)
|
||||
is_deprecated: True if the node is deprecated (optional, default False)
|
||||
|
||||
Child classes must implement ONE of:
|
||||
_process(cls, image, **kwargs) -> tensor (for single-item processing)
|
||||
@ -342,13 +326,12 @@ class ImageProcessingNode(io.ComfyNode):
|
||||
"""
|
||||
|
||||
node_id = None
|
||||
search_aliases = []
|
||||
display_name = None
|
||||
description = None
|
||||
extra_inputs = []
|
||||
is_group_process = None # None = auto-detect, True/False = explicit
|
||||
is_output_list = None # None = auto-detect based on processing mode
|
||||
is_deprecated = False
|
||||
|
||||
@classmethod
|
||||
def _detect_processing_mode(cls):
|
||||
"""Detect whether this node uses group or individual processing.
|
||||
@ -419,10 +402,8 @@ class ImageProcessingNode(io.ComfyNode):
|
||||
|
||||
return io.Schema(
|
||||
node_id=cls.node_id,
|
||||
search_aliases=cls.search_aliases,
|
||||
display_name=cls.display_name or cls.node_id,
|
||||
category=cls.category,
|
||||
description=cls.description,
|
||||
category="dataset/image",
|
||||
is_experimental=True,
|
||||
is_input_list=is_group, # True for group, False for individual
|
||||
inputs=inputs,
|
||||
@ -491,13 +472,11 @@ class TextProcessingNode(io.ComfyNode):
|
||||
|
||||
Child classes should set:
|
||||
node_id: Unique node identifier (required)
|
||||
search_aliases: List of search aliases (optional)
|
||||
display_name: Display name (optional, defaults to node_id)
|
||||
description: Node description (optional)
|
||||
extra_inputs: List of additional io.Input objects beyond "texts" (optional)
|
||||
is_group_process: None (auto-detect), True (group), or False (individual) (optional)
|
||||
is_output_list: True (list output) or False (single output) (optional, default True)
|
||||
is_deprecated: True if the node is deprecated (optional, default False)
|
||||
|
||||
Child classes must implement ONE of:
|
||||
_process(cls, text, **kwargs) -> str (for single-item processing)
|
||||
@ -505,13 +484,12 @@ class TextProcessingNode(io.ComfyNode):
|
||||
"""
|
||||
|
||||
node_id = None
|
||||
search_aliases = []
|
||||
display_name = None
|
||||
description = None
|
||||
extra_inputs = []
|
||||
is_group_process = None # None = auto-detect, True/False = explicit
|
||||
is_output_list = None # None = auto-detect based on processing mode
|
||||
is_deprecated = False
|
||||
|
||||
@classmethod
|
||||
def _detect_processing_mode(cls):
|
||||
"""Detect whether this node uses group or individual processing.
|
||||
@ -649,17 +627,15 @@ class TextProcessingNode(io.ComfyNode):
|
||||
|
||||
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
|
||||
node_id = "ResizeImagesByShorterEdge"
|
||||
display_name = "Resize Images by Shorter Edge (DEPRECATED)"
|
||||
category = "image/transform"
|
||||
description = "Resize images so that the shorter edge matches the specified dimension while preserving aspect ratio."
|
||||
is_deprecated = True # This node is superseded by Resize Image/Mask with resize_type = scale shorter dimension
|
||||
display_name = "Resize Images by Shorter Edge"
|
||||
description = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio."
|
||||
extra_inputs = [
|
||||
io.Int.Input(
|
||||
"shorter_edge",
|
||||
default=512,
|
||||
min=1,
|
||||
max=8192,
|
||||
tooltip="Target dimension for the shorter edge.",
|
||||
tooltip="Target length for the shorter edge.",
|
||||
),
|
||||
]
|
||||
|
||||
@ -679,17 +655,15 @@ class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
|
||||
|
||||
class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
|
||||
node_id = "ResizeImagesByLongerEdge"
|
||||
display_name = "Resize Images by Longer Edge (DEPRECATED)"
|
||||
category = "image/transform"
|
||||
description = "Resize images so that the longer edge matches the specified dimension while preserving aspect ratio."
|
||||
is_deprecated = True # This node is superseded by Resize Image/Mask with resize_type = scale longer dimension
|
||||
display_name = "Resize Images by Longer Edge"
|
||||
description = "Resize images so that the longer edge matches the specified length while preserving aspect ratio."
|
||||
extra_inputs = [
|
||||
io.Int.Input(
|
||||
"longer_edge",
|
||||
default=1024,
|
||||
min=1,
|
||||
max=8192,
|
||||
tooltip="Target dimension for the longer edge.",
|
||||
tooltip="Target length for the longer edge.",
|
||||
),
|
||||
]
|
||||
|
||||
@ -712,10 +686,8 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
|
||||
|
||||
class CenterCropImagesNode(ImageProcessingNode):
|
||||
node_id = "CenterCropImages"
|
||||
search_aliases=["crop", "cut", "trim"]
|
||||
display_name="Crop Image (Center)"
|
||||
category="image/transform"
|
||||
description = "Center crop an image to the specified dimensions."
|
||||
display_name = "Center Crop Images"
|
||||
description = "Center crop all images to the specified dimensions."
|
||||
extra_inputs = [
|
||||
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
|
||||
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
|
||||
@ -734,11 +706,10 @@ class CenterCropImagesNode(ImageProcessingNode):
|
||||
|
||||
class RandomCropImagesNode(ImageProcessingNode):
|
||||
node_id = "RandomCropImages"
|
||||
search_aliases=["crop", "cut", "trim"]
|
||||
display_name = "Crop Image (Random)"
|
||||
category="image/transform"
|
||||
description = "Randomly crop an image to the specified dimensions."
|
||||
|
||||
display_name = "Random Crop Images"
|
||||
description = (
|
||||
"Randomly crop all images to the specified dimensions (for data augmentation)."
|
||||
)
|
||||
extra_inputs = [
|
||||
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
|
||||
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
|
||||
@ -763,9 +734,7 @@ class RandomCropImagesNode(ImageProcessingNode):
|
||||
|
||||
class NormalizeImagesNode(ImageProcessingNode):
|
||||
node_id = "NormalizeImages"
|
||||
search_aliases=["normalize", "normalize colors"]
|
||||
display_name = "Normalize Image Colors"
|
||||
category = "image/color"
|
||||
display_name = "Normalize Images"
|
||||
description = "Normalize images using mean and standard deviation."
|
||||
extra_inputs = [
|
||||
io.Float.Input(
|
||||
@ -793,10 +762,8 @@ class NormalizeImagesNode(ImageProcessingNode):
|
||||
|
||||
class AdjustBrightnessNode(ImageProcessingNode):
|
||||
node_id = "AdjustBrightness"
|
||||
search_aliases=["brightness"]
|
||||
display_name = "Adjust Brightness"
|
||||
category="image/adjustments"
|
||||
description = "Adjust the brightness of an image."
|
||||
description = "Adjust brightness of all images."
|
||||
extra_inputs = [
|
||||
io.Float.Input(
|
||||
"factor",
|
||||
@ -814,10 +781,8 @@ class AdjustBrightnessNode(ImageProcessingNode):
|
||||
|
||||
class AdjustContrastNode(ImageProcessingNode):
|
||||
node_id = "AdjustContrast"
|
||||
search_aliases=["contrast"]
|
||||
display_name = "Adjust Contrast"
|
||||
category="image/adjustments"
|
||||
description = "Adjust the contrast of an image."
|
||||
description = "Adjust contrast of all images."
|
||||
extra_inputs = [
|
||||
io.Float.Input(
|
||||
"factor",
|
||||
@ -835,10 +800,8 @@ class AdjustContrastNode(ImageProcessingNode):
|
||||
|
||||
class ShuffleDatasetNode(ImageProcessingNode):
|
||||
node_id = "ShuffleDataset"
|
||||
search_aliases=["shuffle", "randomize", "mix"]
|
||||
display_name = "Shuffle Images List"
|
||||
category = "image/batch"
|
||||
description = "Randomly shuffle the order of images in a list."
|
||||
display_name = "Shuffle Image Dataset"
|
||||
description = "Randomly shuffle the order of images in the dataset."
|
||||
is_group_process = True # Requires full list to shuffle
|
||||
extra_inputs = [
|
||||
io.Int.Input(
|
||||
@ -860,15 +823,13 @@ class ShuffleImageTextDatasetNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ShuffleImageTextDataset",
|
||||
search_aliases=["shuffle", "randomize", "mix"],
|
||||
display_name = "Shuffle Pairs of Image-Text",
|
||||
category = "image/batch",
|
||||
description = "Randomly shuffle the order of pairs of image-text in a list.",
|
||||
display_name="Shuffle Image-Text Dataset",
|
||||
category="dataset/image",
|
||||
is_experimental=True,
|
||||
is_input_list=True,
|
||||
inputs=[
|
||||
io.Image.Input("images", tooltip="List of images to shuffle."),
|
||||
io.String.Input("texts", tooltip="List of texts to shuffle.", force_input=True),
|
||||
io.String.Input("texts", tooltip="List of texts to shuffle."),
|
||||
io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
@ -904,11 +865,8 @@ class ShuffleImageTextDatasetNode(io.ComfyNode):
|
||||
|
||||
class TextToLowercaseNode(TextProcessingNode):
|
||||
node_id = "TextToLowercase"
|
||||
search_aliases=["lowercase"]
|
||||
display_name = "Convert Text to Lowercase (DEPRECATED)"
|
||||
category = "text"
|
||||
description = "Convert text to lowercase."
|
||||
is_deprecated = True # This node is superseded by the Convert Text Case node
|
||||
display_name = "Text to Lowercase"
|
||||
description = "Convert all texts to lowercase."
|
||||
|
||||
@classmethod
|
||||
def _process(cls, text):
|
||||
@ -917,11 +875,8 @@ class TextToLowercaseNode(TextProcessingNode):
|
||||
|
||||
class TextToUppercaseNode(TextProcessingNode):
|
||||
node_id = "TextToUppercase"
|
||||
search_aliases=["uppercase"]
|
||||
display_name = "Convert Text to Uppercase (DEPRECATED)"
|
||||
category = "text"
|
||||
description = "Convert text to uppercase."
|
||||
is_deprecated = True # This node is superseded by the Convert Text Case node
|
||||
display_name = "Text to Uppercase"
|
||||
description = "Convert all texts to uppercase."
|
||||
|
||||
@classmethod
|
||||
def _process(cls, text):
|
||||
@ -930,10 +885,8 @@ class TextToUppercaseNode(TextProcessingNode):
|
||||
|
||||
class TruncateTextNode(TextProcessingNode):
|
||||
node_id = "TruncateText"
|
||||
search_aliases=["truncate", "cut", "shorten"]
|
||||
display_name = "Truncate Text"
|
||||
category = "text"
|
||||
description = "Truncate text to a maximum length."
|
||||
description = "Truncate all texts to a maximum length."
|
||||
extra_inputs = [
|
||||
io.Int.Input(
|
||||
"max_length", default=77, min=1, max=10000, tooltip="Maximum text length."
|
||||
@ -947,10 +900,8 @@ class TruncateTextNode(TextProcessingNode):
|
||||
|
||||
class AddTextPrefixNode(TextProcessingNode):
|
||||
node_id = "AddTextPrefix"
|
||||
display_name = "Add Text Prefix (DEPRECATED)"
|
||||
category = "text"
|
||||
display_name = "Add Text Prefix"
|
||||
description = "Add a prefix to all texts."
|
||||
is_deprecated = True # This node is superseded by the Concatenate Text node
|
||||
extra_inputs = [
|
||||
io.String.Input("prefix", default="", tooltip="Prefix to add."),
|
||||
]
|
||||
@ -962,10 +913,8 @@ class AddTextPrefixNode(TextProcessingNode):
|
||||
|
||||
class AddTextSuffixNode(TextProcessingNode):
|
||||
node_id = "AddTextSuffix"
|
||||
display_name = "Add Text Suffix (DEPRECATED)"
|
||||
category = "text"
|
||||
display_name = "Add Text Suffix"
|
||||
description = "Add a suffix to all texts."
|
||||
is_deprecated = True # This node is superseded by the Concatenate Text node
|
||||
extra_inputs = [
|
||||
io.String.Input("suffix", default="", tooltip="Suffix to add."),
|
||||
]
|
||||
@ -977,10 +926,8 @@ class AddTextSuffixNode(TextProcessingNode):
|
||||
|
||||
class ReplaceTextNode(TextProcessingNode):
|
||||
node_id = "ReplaceText"
|
||||
display_name = "Replace Text (DEPRECATED)"
|
||||
category = "text"
|
||||
display_name = "Replace Text"
|
||||
description = "Replace text in all texts."
|
||||
is_deprecated = True # This node is superseded by the other Replace Text node
|
||||
extra_inputs = [
|
||||
io.String.Input("find", default="", tooltip="Text to find."),
|
||||
io.String.Input("replace", default="", tooltip="Text to replace with."),
|
||||
@ -993,10 +940,8 @@ class ReplaceTextNode(TextProcessingNode):
|
||||
|
||||
class StripWhitespaceNode(TextProcessingNode):
|
||||
node_id = "StripWhitespace"
|
||||
display_name = "Strip Whitespace (DEPRECATED)"
|
||||
category = "text"
|
||||
display_name = "Strip Whitespace"
|
||||
description = "Strip leading and trailing whitespace from all texts."
|
||||
is_deprecated = True # This node is superseded by the Trim Text node
|
||||
|
||||
@classmethod
|
||||
def _process(cls, text):
|
||||
@ -1007,13 +952,11 @@ class StripWhitespaceNode(TextProcessingNode):
|
||||
|
||||
|
||||
class ImageDeduplicationNode(ImageProcessingNode):
|
||||
"""Remove duplicate or very similar images from a list using perceptual hashing."""
|
||||
"""Remove duplicate or very similar images from the dataset using perceptual hashing."""
|
||||
|
||||
node_id = "ImageDeduplication"
|
||||
search_aliases=["deduplicate", "remove duplicates", "similarity filter"]
|
||||
display_name = "Deduplicate Images"
|
||||
category = "image/batch"
|
||||
description = "Remove duplicate or very similar images from a list."
|
||||
display_name = "Image Deduplication"
|
||||
description = "Remove duplicate or very similar images from the dataset."
|
||||
is_group_process = True # Requires full list to compare images
|
||||
extra_inputs = [
|
||||
io.Float.Input(
|
||||
@ -1083,9 +1026,7 @@ class ImageGridNode(ImageProcessingNode):
|
||||
"""Combine multiple images into a single grid/collage."""
|
||||
|
||||
node_id = "ImageGrid"
|
||||
search_aliases=["grid", "collage", "combine"]
|
||||
display_name = "Make Image Grid"
|
||||
category="image/batch"
|
||||
display_name = "Image Grid"
|
||||
description = "Arrange multiple images into a grid layout."
|
||||
is_group_process = True # Requires full list to create grid
|
||||
is_output_list = False # Outputs single grid image
|
||||
@ -1161,12 +1102,9 @@ class MergeImageListsNode(ImageProcessingNode):
|
||||
"""Merge multiple image lists into a single list."""
|
||||
|
||||
node_id = "MergeImageLists"
|
||||
search_aliases=["list", "merge list", "make list"]
|
||||
display_name = "Merge Image Lists (DEPRECATED)"
|
||||
category = "image/batch"
|
||||
display_name = "Merge Image Lists"
|
||||
description = "Concatenate multiple image lists into one."
|
||||
is_group_process = True # Receives images as list
|
||||
is_deprecated = True # This node is superseded by the Create List node
|
||||
|
||||
@classmethod
|
||||
def _group_process(cls, images):
|
||||
@ -1181,11 +1119,9 @@ class MergeTextListsNode(TextProcessingNode):
|
||||
"""Merge multiple text lists into a single list."""
|
||||
|
||||
node_id = "MergeTextLists"
|
||||
display_name = "Merge Text Lists (DEPRECATED)"
|
||||
category = "text"
|
||||
display_name = "Merge Text Lists"
|
||||
description = "Concatenate multiple text lists into one."
|
||||
is_group_process = True # Receives texts as list
|
||||
is_deprecated = True # This node is superseded by the Create List node
|
||||
|
||||
@classmethod
|
||||
def _group_process(cls, texts):
|
||||
@ -1206,10 +1142,8 @@ class ResolutionBucket(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ResolutionBucket",
|
||||
search_aliases=["bucket by resolution", "group by resolution", "batch by resolution"],
|
||||
display_name="Resolution Bucket",
|
||||
category="training",
|
||||
description="Group latents and conditionings into buckets",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
is_input_list=True,
|
||||
inputs=[
|
||||
@ -1302,8 +1236,7 @@ class MakeTrainingDataset(io.ComfyNode):
|
||||
node_id="MakeTrainingDataset",
|
||||
search_aliases=["encode dataset"],
|
||||
display_name="Make Training Dataset",
|
||||
category="training",
|
||||
description="Encode images with VAE and texts with CLIP to create a training dataset of latents and conditionings.",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
is_input_list=True, # images and texts as lists
|
||||
inputs=[
|
||||
@ -1318,7 +1251,6 @@ class MakeTrainingDataset(io.ComfyNode):
|
||||
"texts",
|
||||
optional=True,
|
||||
tooltip="List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).",
|
||||
force_input=True
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
@ -1388,10 +1320,9 @@ class SaveTrainingDataset(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveTrainingDataset",
|
||||
search_aliases=["export dataset", "save dataset"],
|
||||
search_aliases=["export training data"],
|
||||
display_name="Save Training Dataset",
|
||||
category="training",
|
||||
description="Save encoded training dataset (latents + conditioning) to disk for efficient loading during training.",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
is_input_list=True, # Receive lists
|
||||
@ -1493,8 +1424,7 @@ class LoadTrainingDataset(io.ComfyNode):
|
||||
node_id="LoadTrainingDataset",
|
||||
search_aliases=["import dataset", "training data"],
|
||||
display_name="Load Training Dataset",
|
||||
category="training",
|
||||
description="Load encoded training dataset (latents + conditioning) from disk for use in training.",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.String.Input(
|
||||
|
||||
@ -419,17 +419,15 @@ class VoxelToMeshBasic(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VoxelToMeshBasic",
|
||||
display_name="Voxel to Mesh (Basic) (DEPRECATED)",
|
||||
display_name="Voxel to Mesh (Basic)",
|
||||
category="3d",
|
||||
description="Converts a voxel grid to a mesh.",
|
||||
is_deprecated=True, # This node is superseded by the Voxel To Mesh node
|
||||
inputs=[
|
||||
IO.Voxel.Input("voxel"),
|
||||
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
IO.Mesh.Output(),
|
||||
],
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -455,10 +453,9 @@ class VoxelToMesh(IO.ComfyNode):
|
||||
node_id="VoxelToMesh",
|
||||
display_name="Voxel to Mesh",
|
||||
category="3d",
|
||||
description="Converts a voxel grid to a mesh.",
|
||||
inputs=[
|
||||
IO.Voxel.Input("voxel"),
|
||||
IO.Combo.Input("algorithm", options=["surface net", "basic"]),
|
||||
IO.Combo.Input("algorithm", options=["surface net", "basic"], advanced=True),
|
||||
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
|
||||
@ -55,10 +55,9 @@ class ImageCropV2(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageCropV2",
|
||||
search_aliases=["crop", "cut", "trim"],
|
||||
search_aliases=["trim"],
|
||||
display_name="Crop Image",
|
||||
category="image/transform",
|
||||
description = "Crop an image to the specified dimensions.",
|
||||
essentials_category="Image Tools",
|
||||
has_intermediate_output=True,
|
||||
inputs=[
|
||||
|
||||
@ -11,8 +11,8 @@ class LTXVAudioVAELoader(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAELoader",
|
||||
display_name="Load LTXV Audio VAE",
|
||||
category="loaders",
|
||||
display_name="LTXV Audio VAE Loader",
|
||||
category="audio",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"ckpt_name",
|
||||
@ -40,7 +40,7 @@ class LTXVAudioVAEEncode(VAEEncodeAudio):
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAEEncode",
|
||||
display_name="LTXV Audio VAE Encode",
|
||||
category="latent/audio",
|
||||
category="audio",
|
||||
inputs=[
|
||||
io.Audio.Input("audio", tooltip="The audio to be encoded."),
|
||||
io.Vae.Input(
|
||||
@ -63,7 +63,7 @@ class LTXVAudioVAEDecode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAEDecode",
|
||||
display_name="LTXV Audio VAE Decode",
|
||||
category="latent/audio",
|
||||
category="audio",
|
||||
inputs=[
|
||||
io.Latent.Input("samples", tooltip="The latent to be decoded."),
|
||||
io.Vae.Input(
|
||||
|
||||
@ -28,7 +28,7 @@ from comfy_extras.mediapipe.face_landmarker import FaceLandmarker
|
||||
from comfy_extras.mediapipe.face_geometry import transformation_matrix_from_detection
|
||||
|
||||
|
||||
FaceDetectionType = io.Custom("FACE_DETECTION_MODEL")
|
||||
FaceLandmarkerType = io.Custom("FACE_LANDMARKER")
|
||||
FaceLandmarksType = io.Custom("FACE_LANDMARKS")
|
||||
|
||||
_CANONICAL_KEYS = ("canonical_vertices", "procrustes_indices", "procrustes_weights")
|
||||
@ -204,19 +204,18 @@ class LoadMediaPipeFaceLandmarker(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadMediaPipeFaceLandmarker",
|
||||
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"],
|
||||
display_name="Load Face Detection Model (MediaPipe)",
|
||||
display_name="Load MediaPipe Face Landmarker",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
io.Combo.Input("model_name", options=folder_paths.get_filename_list("detection"),
|
||||
tooltip="Face detection model from models/detection/."),
|
||||
io.Combo.Input("model_name", options=folder_paths.get_filename_list("mediapipe"),
|
||||
tooltip="Face Landmarker safetensors from models/mediapipe/."),
|
||||
],
|
||||
outputs=[FaceDetectionType.Output()],
|
||||
outputs=[FaceLandmarkerType.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_name) -> io.NodeOutput:
|
||||
sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("detection", model_name), safe_load=True)
|
||||
sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("mediapipe", model_name), safe_load=True)
|
||||
wrapper = FaceLandmarkerModel(sd)
|
||||
return io.NodeOutput(wrapper)
|
||||
|
||||
@ -235,12 +234,10 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MediaPipeFaceLandmarker",
|
||||
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"],
|
||||
display_name="Detect Face Landmarks (MediaPipe)",
|
||||
display_name="MediaPipe Face Landmarker",
|
||||
category="image/detection",
|
||||
description="Detects facial landmarks using MediaPipe model.",
|
||||
inputs=[
|
||||
FaceDetectionType.Input("face_detection_model"),
|
||||
FaceLandmarkerType.Input("face_landmarker"),
|
||||
io.Image.Input("image"),
|
||||
io.Combo.Input("detector_variant", options=["short", "full", "both"], default="short",
|
||||
tooltip="Face detector range. 'short' is tuned for close-up faces "
|
||||
@ -264,9 +261,9 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, face_detection_model, image, detector_variant, num_faces, min_confidence,
|
||||
def execute(cls, face_landmarker, image, detector_variant, num_faces, min_confidence,
|
||||
missing_frame_fallback) -> io.NodeOutput:
|
||||
canonical = face_detection_model.canonical_data
|
||||
canonical = face_landmarker.canonical_data
|
||||
img_np = _image_to_uint8(image)
|
||||
B, H, W = img_np.shape[:3]
|
||||
chunk = 16
|
||||
@ -279,7 +276,7 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
|
||||
with tqdm(total=B, desc=f"MediaPipe Face Landmarker ({variant})") as tq:
|
||||
for i in range(0, B, chunk):
|
||||
end = min(i + chunk, B)
|
||||
res.extend(face_detection_model.detect_batch(
|
||||
res.extend(face_landmarker.detect_batch(
|
||||
[img_np[bi] for bi in range(i, end)],
|
||||
num_faces=int(num_faces),
|
||||
score_thresh=float(min_confidence),
|
||||
@ -309,7 +306,7 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
|
||||
per_bb.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1, "label": "face", "score": float(f["score"])})
|
||||
bboxes.append(per_bb)
|
||||
return io.NodeOutput({"frames": frames, "image_size": (H, W),
|
||||
"connection_sets": face_detection_model.connection_sets}, bboxes)
|
||||
"connection_sets": face_landmarker.connection_sets}, bboxes)
|
||||
|
||||
|
||||
# Topology keys unioned by the 'all' connections preset (contour parts + irises + nose).
|
||||
@ -335,10 +332,8 @@ class MediaPipeFaceMeshVisualize(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MediaPipeFaceMeshVisualize",
|
||||
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection", "visualize"],
|
||||
display_name="Visualize Face Landmarks (MediaPipe)",
|
||||
display_name="MediaPipe Face Mesh Visualize",
|
||||
category="image/detection",
|
||||
description="Draws face landmarks mesh on the input image.",
|
||||
inputs=[
|
||||
FaceLandmarksType.Input("face_landmarks"),
|
||||
io.Image.Input("image", optional=True, tooltip="If not connected, a black canvas will be used."),
|
||||
@ -448,10 +443,8 @@ class MediaPipeFaceMask(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MediaPipeFaceMask",
|
||||
search_aliases=["face", "facial", "mediapipe", "face mask", "blazeface", "face detection", "visualize"],
|
||||
display_name="Draw Face Mask (MediaPipe)",
|
||||
display_name="MediaPipe Face Mask",
|
||||
category="image/detection",
|
||||
description="Draws a mask from face landmarks.",
|
||||
inputs=[
|
||||
FaceLandmarksType.Input("face_landmarks"),
|
||||
io.DynamicCombo.Input(
|
||||
|
||||
@ -1,845 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
import copy
|
||||
import comfy.utils
|
||||
import logging
|
||||
import scipy
|
||||
|
||||
def get_mesh_batch_item(mesh, index):
|
||||
if hasattr(mesh, "vertex_counts") and mesh.vertex_counts is not None:
|
||||
vertex_count = int(mesh.vertex_counts[index].item())
|
||||
face_count = int(mesh.face_counts[index].item())
|
||||
vertices = mesh.vertices[index, :vertex_count]
|
||||
faces = mesh.faces[index, :face_count]
|
||||
colors = None
|
||||
if hasattr(mesh, "colors") and mesh.colors is not None:
|
||||
if hasattr(mesh, "color_counts") and mesh.color_counts is not None:
|
||||
color_count = int(mesh.color_counts[index].item())
|
||||
colors = mesh.colors[index, :color_count]
|
||||
else:
|
||||
colors = mesh.colors[index, :vertex_count]
|
||||
return vertices, faces, colors
|
||||
|
||||
colors = None
|
||||
if hasattr(mesh, "colors") and mesh.colors is not None:
|
||||
colors = mesh.colors[index]
|
||||
return mesh.vertices[index], mesh.faces[index], colors
|
||||
|
||||
def pack_variable_mesh_batch(vertices, faces, colors=None):
|
||||
batch_size = len(vertices)
|
||||
max_vertices = max(v.shape[0] for v in vertices)
|
||||
max_faces = max(f.shape[0] for f in faces)
|
||||
|
||||
packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1]))
|
||||
packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1]))
|
||||
vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64)
|
||||
face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64)
|
||||
|
||||
for i, (v, f) in enumerate(zip(vertices, faces)):
|
||||
packed_vertices[i, :v.shape[0]] = v
|
||||
packed_faces[i, :f.shape[0]] = f
|
||||
|
||||
mesh = Types.MESH(packed_vertices, packed_faces)
|
||||
mesh.vertex_counts = vertex_counts
|
||||
mesh.face_counts = face_counts
|
||||
|
||||
if colors is not None:
|
||||
max_colors = max(c.shape[0] for c in colors)
|
||||
packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1]))
|
||||
color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64)
|
||||
for i, c in enumerate(colors):
|
||||
packed_colors[i, :c.shape[0]] = c
|
||||
mesh.vertex_colors = packed_colors
|
||||
mesh.color_counts = color_counts
|
||||
|
||||
return mesh
|
||||
|
||||
|
||||
def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution):
|
||||
"""
|
||||
Generic function to paint a mesh using nearest-neighbor colors from a sparse voxel field.
|
||||
"""
|
||||
device = comfy.model_management.vae_offload_device()
|
||||
|
||||
origin = torch.tensor([-0.5, -0.5, -0.5], device=device)
|
||||
voxel_size = 1.0 / resolution
|
||||
|
||||
# map voxels
|
||||
voxel_pos = voxel_coords.to(device).float() * voxel_size + origin
|
||||
verts = mesh.vertices.to(device).squeeze(0)
|
||||
voxel_colors = voxel_colors.to(device)
|
||||
|
||||
voxel_pos_np = voxel_pos.numpy()
|
||||
verts_np = verts.numpy()
|
||||
|
||||
tree = scipy.spatial.cKDTree(voxel_pos_np)
|
||||
|
||||
# nearest neighbour k=1
|
||||
_, nearest_idx_np = tree.query(verts_np, k=1, workers=-1)
|
||||
|
||||
nearest_idx = torch.from_numpy(nearest_idx_np).long()
|
||||
v_colors = voxel_colors[nearest_idx]
|
||||
|
||||
# to [0, 1]
|
||||
srgb_colors = v_colors.clamp(0, 1)#(v_colors * 0.5 + 0.5).clamp(0, 1)
|
||||
|
||||
# to Linear RGB (required for GLTF)
|
||||
linear_colors = torch.pow(srgb_colors, 2.2)
|
||||
|
||||
final_colors = linear_colors.unsqueeze(0)
|
||||
|
||||
out_mesh = copy.deepcopy(mesh)
|
||||
out_mesh.vertex_colors = final_colors
|
||||
|
||||
return out_mesh
|
||||
|
||||
class PaintMesh(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="PaintMesh",
|
||||
display_name="Paint Mesh",
|
||||
category="latent/3d",
|
||||
description=(
|
||||
"Paints the mesh using colors from the input voxel field by matching each vertex "
|
||||
"to the nearest voxel color."
|
||||
),
|
||||
inputs=[
|
||||
IO.Mesh.Input("mesh"),
|
||||
IO.Voxel.Input("voxel_colors")
|
||||
],
|
||||
outputs=[
|
||||
IO.Mesh.Output("mesh"),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mesh, voxel_colors):
|
||||
voxels = voxel_colors
|
||||
coords = voxels.data
|
||||
colors = voxels.voxel_colors
|
||||
resolution = voxels.resolution
|
||||
|
||||
if coords.shape[0] == 0:
|
||||
return IO.NodeOutput(paint_mesh_default_colors(mesh))
|
||||
|
||||
mesh_batch_size = mesh.vertices.shape[0]
|
||||
|
||||
if coords.shape[-1] == 4 and mesh_batch_size > 1:
|
||||
batch_idx = coords[:, 0].long()
|
||||
voxel_coords = coords[:, 1:]
|
||||
mesh_batch_size = mesh.vertices.shape[0]
|
||||
|
||||
out_verts, out_faces, out_colors = [], [], []
|
||||
for i in range(mesh_batch_size):
|
||||
sel = batch_idx == i
|
||||
item_coords = voxel_coords[sel]
|
||||
item_colors = colors[sel]
|
||||
item_vertices, item_faces, _ = get_mesh_batch_item(mesh, i)
|
||||
item_mesh = Types.MESH(vertices=item_vertices.unsqueeze(0), faces=item_faces.unsqueeze(0))
|
||||
|
||||
if item_coords.shape[0] == 0:
|
||||
painted = paint_mesh_default_colors(item_mesh)
|
||||
else:
|
||||
painted = paint_mesh_with_voxels(item_mesh, item_coords, item_colors, resolution=resolution)
|
||||
|
||||
out_verts.append(painted.vertices.squeeze(0))
|
||||
out_faces.append(painted.faces.squeeze(0))
|
||||
out_colors.append(painted.vertex_colors.squeeze(0))
|
||||
|
||||
out_mesh = pack_variable_mesh_batch(out_verts, out_faces, out_colors)
|
||||
return IO.NodeOutput(out_mesh)
|
||||
|
||||
if coords.shape[-1] == 4:
|
||||
coords = coords[:, 1:]
|
||||
|
||||
out_mesh = paint_mesh_with_voxels(mesh, coords, colors, resolution=resolution)
|
||||
return IO.NodeOutput(out_mesh)
|
||||
|
||||
def paint_mesh_default_colors(mesh):
|
||||
out_mesh = copy.copy(mesh)
|
||||
vertex_count = mesh.vertices.shape[1]
|
||||
out_mesh.vertex_colors = mesh.vertices.new_zeros((1, vertex_count, 3))
|
||||
return out_mesh
|
||||
|
||||
|
||||
def fill_holes_fn(vertices, faces, max_perimeter=0.03):
|
||||
is_batched = vertices.ndim == 3
|
||||
if is_batched:
|
||||
v_list, f_list = [], []
|
||||
for i in range(vertices.shape[0]):
|
||||
v_i, f_i = fill_holes_fn(vertices[i], faces[i], max_perimeter)
|
||||
v_list.append(v_i)
|
||||
f_list.append(f_i)
|
||||
max_v = max(v.shape[0] for v in v_list)
|
||||
for i in range(len(v_list)):
|
||||
if v_list[i].shape[0] < max_v:
|
||||
pad = torch.zeros(max_v - v_list[i].shape[0], 3, device=v_list[i].device, dtype=v_list[i].dtype)
|
||||
v_list[i] = torch.cat([v_list[i], pad], dim=0)
|
||||
return torch.stack(v_list), torch.stack(f_list)
|
||||
|
||||
device = vertices.device
|
||||
v = vertices
|
||||
f = faces
|
||||
|
||||
if f.numel() == 0:
|
||||
return v, f
|
||||
|
||||
edges = torch.cat([f[:, [0, 1]], f[:, [1, 2]], f[:, [2, 0]]], dim=0)
|
||||
edges_sorted, _ = torch.sort(edges, dim=1)
|
||||
max_v = v.shape[0]
|
||||
packed_undirected = edges_sorted[:, 0].long() * max_v + edges_sorted[:, 1].long()
|
||||
unique_packed, counts = torch.unique(packed_undirected, return_counts=True)
|
||||
boundary_packed = unique_packed[counts == 1]
|
||||
|
||||
if boundary_packed.numel() == 0:
|
||||
return v, f
|
||||
|
||||
boundary_mask = torch.isin(packed_undirected, boundary_packed)
|
||||
b_edges = edges_sorted[boundary_mask]
|
||||
|
||||
adj = {}
|
||||
for i in range(b_edges.shape[0]):
|
||||
a = b_edges[i, 0].item()
|
||||
b = b_edges[i, 1].item()
|
||||
adj.setdefault(a, []).append(b)
|
||||
adj.setdefault(b, []).append(a)
|
||||
|
||||
# Trace all boundary loops
|
||||
loops = []
|
||||
visited = set()
|
||||
for start_node in adj.keys():
|
||||
if start_node in visited:
|
||||
continue
|
||||
curr = start_node
|
||||
prev = -1
|
||||
loop = []
|
||||
while curr not in visited:
|
||||
visited.add(curr)
|
||||
loop.append(curr)
|
||||
neighbors = adj[curr]
|
||||
candidates = [n for n in neighbors if n != prev]
|
||||
if not candidates:
|
||||
loop = []
|
||||
break
|
||||
next_node = candidates[0]
|
||||
prev, curr = curr, next_node
|
||||
if curr == start_node:
|
||||
loops.append(loop)
|
||||
break
|
||||
|
||||
if not loops:
|
||||
return v, f
|
||||
|
||||
# Mesh normal for winding orientation only
|
||||
face_normals = torch.linalg.cross(
|
||||
v[f[:, 1]] - v[f[:, 0]],
|
||||
v[f[:, 2]] - v[f[:, 0]],
|
||||
dim=-1
|
||||
)
|
||||
mesh_normal = face_normals.mean(dim=0)
|
||||
mesh_normal = mesh_normal / (torch.norm(mesh_normal) + 1e-8)
|
||||
|
||||
# === FIX: Fill ALL boundary loops below perimeter threshold ===
|
||||
new_verts = []
|
||||
new_faces = []
|
||||
v_idx = v.shape[0]
|
||||
|
||||
for loop in loops:
|
||||
loop_t = torch.tensor(loop, device=device, dtype=torch.long)
|
||||
loop_v = v[loop_t]
|
||||
|
||||
# Perimeter check
|
||||
next_v = torch.roll(loop_v, -1, dims=0)
|
||||
diffs = loop_v - next_v
|
||||
perimeter = torch.norm(diffs, dim=1).sum().item()
|
||||
|
||||
if perimeter > max_perimeter:
|
||||
continue
|
||||
|
||||
# Ensure CCW winding consistent with mesh
|
||||
cross = torch.linalg.cross(loop_v, next_v, dim=-1)
|
||||
loop_normal = cross.sum(dim=0)
|
||||
loop_normal = loop_normal / (torch.norm(loop_normal) + 1e-8)
|
||||
if torch.dot(loop_normal, mesh_normal) < 0:
|
||||
loop = loop[::-1]
|
||||
loop_t = torch.tensor(loop, device=device, dtype=torch.long)
|
||||
loop_v = v[loop_t]
|
||||
|
||||
if len(loop) == 3:
|
||||
new_faces.append([loop[0], loop[1], loop[2]])
|
||||
else:
|
||||
centroid = loop_v.mean(dim=0)
|
||||
new_verts.append(centroid)
|
||||
for i in range(len(loop)):
|
||||
new_faces.append([loop[i], loop[(i + 1) % len(loop)], v_idx])
|
||||
v_idx += 1
|
||||
|
||||
if new_verts:
|
||||
v = torch.cat([v, torch.stack(new_verts)], dim=0)
|
||||
if new_faces:
|
||||
f = torch.cat([f, torch.tensor(new_faces, device=device, dtype=torch.long)], dim=0)
|
||||
|
||||
return v, f
|
||||
|
||||
def _cleanup_mesh(verts, faces, min_angle_deg=0.5, max_aspect=100.0):
|
||||
if faces.numel() == 0:
|
||||
return verts, faces
|
||||
|
||||
v0 = verts[faces[:, 0]]
|
||||
v1 = verts[faces[:, 1]]
|
||||
v2 = verts[faces[:, 2]]
|
||||
e0 = v1 - v0
|
||||
e1 = v2 - v1
|
||||
e2 = v0 - v2
|
||||
l0 = torch.norm(e0, dim=-1)
|
||||
l1 = torch.norm(e1, dim=-1)
|
||||
l2 = torch.norm(e2, dim=-1)
|
||||
n = torch.cross(e0, e2, dim=-1)
|
||||
area = torch.norm(n, dim=-1)
|
||||
|
||||
max_edge = torch.max(torch.max(l0, l1), l2)
|
||||
aspect = max_edge * max_edge / (2.0 * area + 1e-12)
|
||||
|
||||
cos_a = (l1 * l1 + l2 * l2 - l0 * l0) / (2 * l1 * l2 + 1e-12)
|
||||
cos_b = (l0 * l0 + l2 * l2 - l1 * l1) / (2 * l0 * l2 + 1e-12)
|
||||
cos_c = (l0 * l0 + l1 * l1 - l2 * l2) / (2 * l0 * l1 + 1e-12)
|
||||
cos_all = torch.stack([cos_a, cos_b, cos_c], dim=-1)
|
||||
angles = torch.acos(torch.clamp(cos_all, -1, 1)) * 180 / np.pi
|
||||
|
||||
good = (aspect < max_aspect) & (angles.min(dim=1)[0] > min_angle_deg) & (area > 1e-12)
|
||||
faces = faces[good]
|
||||
|
||||
if faces.numel() == 0:
|
||||
return verts, faces
|
||||
|
||||
used = torch.zeros(verts.shape[0], dtype=torch.bool, device=verts.device)
|
||||
used[faces[:, 0]] = True
|
||||
used[faces[:, 1]] = True
|
||||
used[faces[:, 2]] = True
|
||||
|
||||
remap = torch.full((verts.shape[0],), -1, dtype=torch.int64, device=verts.device)
|
||||
remap[used] = torch.arange(used.sum().item(), device=verts.device)
|
||||
verts = verts[used]
|
||||
faces = remap[faces]
|
||||
return verts, faces
|
||||
|
||||
def _pytorch_edge_errors_fast(verts, Q, edges, stabilizer, max_edge_length_sq, mesh_scale_sq):
|
||||
n_edges = edges.shape[0]
|
||||
dtype = verts.dtype
|
||||
if n_edges == 0:
|
||||
return (torch.empty((0, 3), dtype=dtype, device=verts.device),
|
||||
torch.empty((0,), dtype=dtype, device=verts.device),
|
||||
torch.zeros((0,), dtype=torch.bool, device=verts.device))
|
||||
|
||||
device = verts.device
|
||||
mesh_scale = (mesh_scale_sq) ** 0.5
|
||||
|
||||
va = edges[:, 0]
|
||||
vb = edges[:, 1]
|
||||
Q0 = Q[va]
|
||||
Q1 = Q[vb]
|
||||
Qe = Q0 + Q1
|
||||
|
||||
A = Qe[:, :3, :3] + torch.eye(3, device=device, dtype=dtype).unsqueeze(0) * stabilizer
|
||||
b = -Qe[:, :3, 3].unsqueeze(-1)
|
||||
|
||||
dets = torch.det(A)
|
||||
good = dets.abs() > 1e-12
|
||||
opt = torch.zeros((n_edges, 3), dtype=dtype, device=device)
|
||||
|
||||
if good.any():
|
||||
try:
|
||||
sol = torch.linalg.solve(A[good], b[good])
|
||||
opt[good] = sol.squeeze(-1)
|
||||
except Exception:
|
||||
good = torch.zeros_like(good)
|
||||
|
||||
if (~good).any():
|
||||
bad_idx = torch.nonzero(~good, as_tuple=True)[0]
|
||||
opt[bad_idx] = (verts[va[bad_idx]] + verts[vb[bad_idx]]) * 0.5
|
||||
|
||||
pa = verts[va]
|
||||
pb = verts[vb]
|
||||
el = torch.norm(pb - pa, dim=-1)
|
||||
dist_a = torch.norm(opt - pa, dim=-1)
|
||||
dist_b = torch.norm(opt - pb, dim=-1)
|
||||
wander_bad = (dist_a > 4.0 * el) | (dist_b > 4.0 * el)
|
||||
|
||||
if wander_bad.any():
|
||||
bad_idx = torch.nonzero(wander_bad, as_tuple=True)[0]
|
||||
opt[bad_idx] = (verts[va[bad_idx]] + verts[vb[bad_idx]]) * 0.5
|
||||
|
||||
v4 = torch.cat([opt, torch.ones((n_edges, 1), device=device, dtype=dtype)], dim=1)
|
||||
err = torch.abs(torch.einsum("ei,eij,ej->e", v4, Qe, v4))
|
||||
|
||||
length_ok = el > mesh_scale * 1e-5
|
||||
error_ok = err < max_edge_length_sq
|
||||
nan_ok = ~torch.isnan(opt).any(dim=-1) & ~torch.isnan(err)
|
||||
valid = length_ok & error_ok & nan_ok
|
||||
|
||||
return opt, err, valid
|
||||
|
||||
|
||||
def _build_quadrics_fast(verts, faces):
|
||||
v0 = verts[faces[:, 0]]
|
||||
v1 = verts[faces[:, 1]]
|
||||
v2 = verts[faces[:, 2]]
|
||||
e1 = v1 - v0
|
||||
e2 = v2 - v0
|
||||
n = torch.cross(e1, e2, dim=-1)
|
||||
area = torch.norm(n, dim=-1)
|
||||
mask = area > 1e-12
|
||||
n_norm = torch.zeros_like(n)
|
||||
n_norm[mask] = n[mask] / area[mask].unsqueeze(-1)
|
||||
d = -(n_norm * v0).sum(dim=-1, keepdim=True)
|
||||
p = torch.cat([n_norm, d], dim=-1)
|
||||
K = torch.einsum("fi,fj->fij", p, p)
|
||||
K = K * area[:, None, None]
|
||||
V = verts.shape[0]
|
||||
Q = torch.zeros((V, 4, 4), dtype=verts.dtype, device=verts.device)
|
||||
K_flat = K.reshape(-1, 16)
|
||||
Q_flat = Q.reshape(V, 16)
|
||||
for corner in range(3):
|
||||
idx = faces[:, corner].unsqueeze(1).expand(-1, 16)
|
||||
Q_flat.scatter_add_(0, idx, K_flat)
|
||||
return Q_flat.reshape(V, 4, 4)
|
||||
|
||||
|
||||
def _gpu_greedy_matching_fast(edges, err, v_alive, max_select):
|
||||
"""Vectorized greedy matching.
|
||||
|
||||
Selects an independent set of edges (no two share a vertex) preferring
|
||||
lowest error. Replaces _gpu_greedy_sampled's Python per-edge loop with
|
||||
two scatter_reduce calls.
|
||||
"""
|
||||
device = edges.device
|
||||
n_edges = edges.shape[0]
|
||||
if n_edges == 0:
|
||||
return torch.empty(0, dtype=torch.int64, device=device)
|
||||
|
||||
va = edges[:, 0]
|
||||
vb = edges[:, 1]
|
||||
num_verts = v_alive.shape[0]
|
||||
|
||||
# Pack (error_bits, edge_idx) into one int64 so amin gives a unique winner.
|
||||
# err is non-negative finite float32 -> IEEE bits are monotonic.
|
||||
err32 = err.to(torch.float32).clamp(min=0).contiguous()
|
||||
err_bits = err32.view(torch.int32).to(torch.int64) & 0xFFFFFFFF
|
||||
edge_idx = torch.arange(n_edges, device=device, dtype=torch.int64)
|
||||
key = (err_bits << 32) | edge_idx
|
||||
|
||||
INT64_MAX = torch.iinfo(torch.int64).max
|
||||
best_key = torch.full((num_verts,), INT64_MAX, dtype=torch.int64, device=device)
|
||||
best_key.scatter_reduce_(0, va, key, reduce='amin', include_self=True)
|
||||
best_key.scatter_reduce_(0, vb, key, reduce='amin', include_self=True)
|
||||
|
||||
# An edge wins iff it is the min-key edge incident to BOTH its endpoints
|
||||
# AND both endpoints are still alive.
|
||||
is_winner = (key == best_key[va]) & (key == best_key[vb]) & v_alive[va] & v_alive[vb]
|
||||
|
||||
sel = torch.nonzero(is_winner, as_tuple=True)[0]
|
||||
|
||||
if sel.numel() > max_select:
|
||||
sel_err = err[sel]
|
||||
top = torch.topk(sel_err, max_select, largest=False).indices
|
||||
sel = sel[top]
|
||||
|
||||
return sel
|
||||
|
||||
|
||||
def _qem_simplify_fast(vertices, faces_in, colors_in, normals_in, target_faces, device, max_edge_length=None):
|
||||
# Use float32 instead of float64. RTX-class consumer GPUs run FP32 ~32-64x
|
||||
# faster than FP64, and QEM only needs the stabilizer for conditioning.
|
||||
# Always copy=True so we can safely mutate verts/colors/normals in-place.
|
||||
verts = vertices.detach().to(device=device, dtype=torch.float32, copy=True)
|
||||
faces = faces_in.detach().to(device=device, dtype=torch.int64)
|
||||
colors = (
|
||||
colors_in.detach().to(device=device, dtype=torch.float32, copy=True)
|
||||
if colors_in is not None
|
||||
else None
|
||||
)
|
||||
# ADDED: Initialize normals
|
||||
normals = (
|
||||
normals_in.detach().to(device=device, dtype=torch.float32, copy=True)
|
||||
if normals_in is not None
|
||||
else None
|
||||
)
|
||||
|
||||
num_verts = verts.shape[0]
|
||||
num_faces = faces.shape[0]
|
||||
|
||||
logging.debug(f"[QEM-fast] Input: {num_verts} verts, {num_faces} faces, target={target_faces}")
|
||||
|
||||
v_alive = torch.ones(num_verts, dtype=torch.bool, device=device)
|
||||
f_alive = torch.ones(num_faces, dtype=torch.bool, device=device)
|
||||
|
||||
Q = _build_quadrics_fast(verts, faces)
|
||||
|
||||
bbox = verts.max(dim=0)[0] - verts.min(dim=0)[0]
|
||||
mesh_scale = torch.norm(bbox).item()
|
||||
|
||||
if max_edge_length is None or max_edge_length <= 0:
|
||||
max_edge_length = mesh_scale * 2.0
|
||||
|
||||
if max_edge_length < 1e-6:
|
||||
max_edge_length = 1.0
|
||||
|
||||
stabilizer = mesh_scale * mesh_scale * 0.001
|
||||
max_edge_length_sq = max_edge_length * max_edge_length
|
||||
mesh_scale_sq = mesh_scale * mesh_scale
|
||||
|
||||
iteration = 0
|
||||
total_collapses = 0
|
||||
last_faces = num_faces
|
||||
|
||||
while True:
|
||||
n_faces = int(f_alive.sum().item())
|
||||
|
||||
if n_faces <= target_faces:
|
||||
break
|
||||
|
||||
alive_v = torch.nonzero(v_alive, as_tuple=True)[0]
|
||||
alive_f = torch.nonzero(f_alive, as_tuple=True)[0]
|
||||
|
||||
if alive_v.numel() <= 4 or alive_f.numel() == 0:
|
||||
break
|
||||
|
||||
# Compact active mesh
|
||||
vmap = torch.full((num_verts,), -1, dtype=torch.int64, device=device)
|
||||
vmap[alive_v] = torch.arange(alive_v.numel(), device=device)
|
||||
|
||||
active_faces = faces[alive_f]
|
||||
remapped = vmap[active_faces]
|
||||
|
||||
# Extract edges
|
||||
e0 = remapped[:, [0, 1]]
|
||||
e1 = remapped[:, [1, 2]]
|
||||
e2 = remapped[:, [2, 0]]
|
||||
edges = torch.cat([e0, e1, e2], dim=0)
|
||||
edges = torch.sort(edges, dim=1)[0]
|
||||
edges = edges[(edges >= 0).all(dim=1)]
|
||||
edges = edges[edges[:, 0] != edges[:, 1]]
|
||||
|
||||
if edges.shape[0] == 0:
|
||||
break
|
||||
|
||||
# Deduplicate edges
|
||||
num_compact = alive_v.numel()
|
||||
packed = edges[:, 0].long() * num_compact + edges[:, 1].long()
|
||||
packed = torch.unique(packed)
|
||||
edges = torch.stack([packed // num_compact, packed % num_compact], dim=1)
|
||||
|
||||
edges_orig = alive_v[edges]
|
||||
|
||||
# Filter by edge length
|
||||
pa = verts[edges_orig[:, 0]]
|
||||
pb = verts[edges_orig[:, 1]]
|
||||
el = torch.norm(pb - pa, dim=-1)
|
||||
short_enough = el < max_edge_length
|
||||
|
||||
if not short_enough.any():
|
||||
max_edge_length = el.max().item() * 2.0
|
||||
max_edge_length_sq = max_edge_length * max_edge_length
|
||||
short_enough = el < max_edge_length
|
||||
if not short_enough.any():
|
||||
break
|
||||
|
||||
edges_orig = edges_orig[short_enough]
|
||||
if edges_orig.shape[0] == 0:
|
||||
break
|
||||
|
||||
# Sample edges for processing
|
||||
n_edges_total = edges_orig.shape[0]
|
||||
max_edges_to_process = 10_000_000
|
||||
|
||||
if n_edges_total > max_edges_to_process:
|
||||
perm = torch.randint(0, n_edges_total, (max_edges_to_process,), device=device)
|
||||
edges_orig = edges_orig[perm]
|
||||
n_edges = max_edges_to_process
|
||||
else:
|
||||
n_edges = n_edges_total
|
||||
|
||||
optimal, err, valid = _pytorch_edge_errors_fast(
|
||||
verts, Q, edges_orig, stabilizer, max_edge_length_sq, mesh_scale_sq
|
||||
)
|
||||
|
||||
if not valid.any():
|
||||
valid = torch.ones(n_edges, dtype=torch.bool, device=device)
|
||||
|
||||
valid_idx = torch.nonzero(valid, as_tuple=True)[0]
|
||||
edges_orig = edges_orig[valid_idx]
|
||||
optimal = optimal[valid_idx]
|
||||
err = err[valid_idx]
|
||||
|
||||
faces_to_remove = n_faces - target_faces
|
||||
max_collapses = min(1_000_000, max(10_000, faces_to_remove // 4))
|
||||
|
||||
sel = _gpu_greedy_matching_fast(edges_orig, err, v_alive, max_collapses)
|
||||
|
||||
if sel.numel() == 0:
|
||||
break
|
||||
|
||||
v_a = edges_orig[sel, 0]
|
||||
v_b = edges_orig[sel, 1]
|
||||
|
||||
# Apply collapses
|
||||
verts[v_a] = optimal[sel]
|
||||
v_alive[v_b] = False
|
||||
Q[v_a] += Q[v_b]
|
||||
|
||||
if colors is not None:
|
||||
colors[v_a] = (colors[v_a] + colors[v_b]) * 0.5
|
||||
|
||||
if normals is not None:
|
||||
normals[v_a] = (normals[v_a] + normals[v_b]) * 0.5
|
||||
|
||||
merge_map = torch.arange(num_verts, device=device)
|
||||
merge_map[v_b] = v_a
|
||||
faces = merge_map[faces]
|
||||
|
||||
bad = (
|
||||
(faces[:, 0] == faces[:, 1])
|
||||
| (faces[:, 1] == faces[:, 2])
|
||||
| (faces[:, 2] == faces[:, 0])
|
||||
)
|
||||
f_alive &= ~bad
|
||||
|
||||
total_collapses += v_a.numel()
|
||||
iteration += 1
|
||||
|
||||
if iteration % 50 == 0 or n_faces < last_faces * 0.9:
|
||||
logging.debug(f"[QEM-fast] Iter {iteration}: {total_collapses} collapses, {int(f_alive.sum().item())} faces, applied {v_a.numel()}")
|
||||
last_faces = n_faces
|
||||
|
||||
if iteration % 5 == 0 and int(f_alive.sum().item()) < num_faces * 0.5:
|
||||
faces = faces[f_alive]
|
||||
f_alive = torch.ones(faces.shape[0], dtype=torch.bool, device=device)
|
||||
num_faces = faces.shape[0]
|
||||
|
||||
if iteration > 5000:
|
||||
break
|
||||
|
||||
# Finalize
|
||||
final_v = verts[v_alive]
|
||||
final_c = colors[v_alive] if colors is not None else None
|
||||
final_n = normals[v_alive] if normals is not None else None
|
||||
|
||||
remap = torch.full((num_verts,), -1, dtype=torch.int64, device=device)
|
||||
remap[v_alive] = torch.arange(int(v_alive.sum().item()), device=device)
|
||||
|
||||
final_f_raw = faces[f_alive]
|
||||
alive_mask = v_alive[final_f_raw].all(dim=1)
|
||||
final_f_raw = final_f_raw[alive_mask]
|
||||
final_f = remap[final_f_raw]
|
||||
valid_faces = (final_f >= 0).all(dim=1)
|
||||
final_f = final_f[valid_faces]
|
||||
|
||||
if final_f.numel() > 0:
|
||||
final_f = torch.unique(torch.sort(final_f, dim=1)[0], dim=0)
|
||||
|
||||
if final_n is not None and final_f.numel() > 0:
|
||||
v0, v1, v2 = final_v[final_f[:, 0]], final_v[final_f[:, 1]], final_v[final_f[:, 2]]
|
||||
|
||||
# calculate the actual normal of the simplified faces
|
||||
face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
|
||||
|
||||
# Get the average reference normal for each face
|
||||
n0, n1, n2 = final_n[final_f[:, 0]], final_n[final_f[:, 1]], final_n[final_f[:, 2]]
|
||||
ref_face_normals = (n0 + n1 + n2) / 3.0
|
||||
|
||||
# Dot product to check if they point in the same direction
|
||||
dot_products = (face_normals * ref_face_normals).sum(dim=-1)
|
||||
|
||||
# Flip the indices of ONLY the incorrect faces (swap vertex 1 and 2)
|
||||
wrong_way_mask = dot_products < 0
|
||||
final_f[wrong_way_mask] = final_f[wrong_way_mask][:, [0, 2, 1]]
|
||||
|
||||
final_v, final_f = _cleanup_mesh(final_v, final_f, min_angle_deg=0.5, max_aspect=100.0)
|
||||
|
||||
return final_v, final_f, final_c, final_n
|
||||
|
||||
|
||||
def simplify_fn_fast(vertices, faces, colors=None, normals=None, target=100000, max_edge_length=None):
|
||||
if vertices.ndim == 3:
|
||||
v_list, f_list, c_list, n_list = [], [], [], []
|
||||
for i in range(vertices.shape[0]):
|
||||
c_in = colors[i] if colors is not None else None
|
||||
n_in = normals[i] if normals is not None else None
|
||||
v_i, f_i, c_i, n_i = simplify_fn_fast(vertices[i], faces[i], c_in, n_in, target, max_edge_length)
|
||||
v_list.append(v_i)
|
||||
f_list.append(f_i)
|
||||
if c_i is not None:
|
||||
c_list.append(c_i)
|
||||
if n_i is not None:
|
||||
n_list.append(n_i)
|
||||
|
||||
c_out = torch.stack(c_list) if len(c_list) > 0 else None
|
||||
n_out = torch.stack(n_list) if len(n_list) > 0 else None
|
||||
return torch.stack(v_list), torch.stack(f_list), c_out, n_out
|
||||
|
||||
if faces.shape[0] <= target:
|
||||
return vertices, faces, colors, normals
|
||||
|
||||
device = vertices.device
|
||||
dtype = vertices.dtype
|
||||
face_dtype = faces.dtype
|
||||
color_dtype = colors.dtype if colors is not None else None
|
||||
# ADDED: Normal dtype
|
||||
normal_dtype = normals.dtype if normals is not None else None
|
||||
|
||||
# Pass tensors directly; _qem_simplify_fast handles dtype/device + copy.
|
||||
out_v, out_f, out_c, out_n = _qem_simplify_fast(
|
||||
vertices, faces, colors, normals, target, device, max_edge_length
|
||||
)
|
||||
|
||||
final_v = out_v.to(device=device, dtype=dtype)
|
||||
final_f = out_f.to(device=device, dtype=face_dtype)
|
||||
final_c = (
|
||||
out_c.to(device=device, dtype=color_dtype)
|
||||
if out_c is not None
|
||||
else None
|
||||
)
|
||||
final_n = (
|
||||
out_n.to(device=device, dtype=normal_dtype)
|
||||
if out_n is not None
|
||||
else None
|
||||
)
|
||||
return final_v, final_f, final_c, final_n
|
||||
|
||||
def compute_vertex_normals(verts, faces):
|
||||
"""Computes area-weighted vertex normals."""
|
||||
# QUICK FIX: Ensure indices are int64 for scatter_add_
|
||||
faces_long = faces.to(torch.int64)
|
||||
|
||||
i0, i1, i2 = faces_long[:, 0], faces_long[:, 1], faces_long[:, 2]
|
||||
v0, v1, v2 = verts[i0], verts[i1], verts[i2]
|
||||
|
||||
# calculate unnormalized face normals (magnitude is proportional to area)
|
||||
face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
|
||||
|
||||
# accumulate face normals to vertices
|
||||
vertex_normals = torch.zeros_like(verts)
|
||||
vertex_normals.scatter_add_(0, i0.unsqueeze(-1).expand_as(face_normals), face_normals)
|
||||
vertex_normals.scatter_add_(0, i1.unsqueeze(-1).expand_as(face_normals), face_normals)
|
||||
vertex_normals.scatter_add_(0, i2.unsqueeze(-1).expand_as(face_normals), face_normals)
|
||||
|
||||
return torch.nn.functional.normalize(vertex_normals, p=2, dim=-1, eps=1e-6)
|
||||
|
||||
def _process_mesh_batch(mesh, per_item_fn):
|
||||
"""Handles list/batched/single mesh dispatching, color extraction, and stacking."""
|
||||
mesh = copy.deepcopy(mesh)
|
||||
|
||||
def process_single(v, f, c, bar):
|
||||
v, f, c = per_item_fn(v, f, c)
|
||||
bar.update(1)
|
||||
return v, f, c
|
||||
|
||||
is_list = isinstance(mesh.vertices, list)
|
||||
is_batched_tensor = not is_list and mesh.vertices.ndim == 3
|
||||
|
||||
if is_list or is_batched_tensor:
|
||||
out_v, out_f, out_c = [], [], []
|
||||
bsz = len(mesh.vertices) if is_list else mesh.vertices.shape[0]
|
||||
bar = comfy.utils.ProgressBar(bsz)
|
||||
|
||||
for i in range(bsz):
|
||||
v_i = mesh.vertices[i]
|
||||
f_i = mesh.faces[i]
|
||||
c_i = None
|
||||
if hasattr(mesh, 'vertex_colors') and mesh.vertex_colors is not None:
|
||||
c_i = mesh.vertex_colors[i] if (isinstance(mesh.vertex_colors, list) or mesh.vertex_colors.ndim == 3) else mesh.vertex_colors
|
||||
|
||||
v_i, f_i, c_i = process_single(v_i, f_i, c_i, bar)
|
||||
|
||||
out_v.append(v_i)
|
||||
out_f.append(f_i)
|
||||
if c_i is not None:
|
||||
out_c.append(c_i)
|
||||
|
||||
if all(v.shape == out_v[0].shape for v in out_v) and all(f.shape == out_f[0].shape for f in out_f):
|
||||
mesh.vertices = torch.stack(out_v)
|
||||
mesh.faces = torch.stack(out_f)
|
||||
if out_c:
|
||||
mesh.vertex_colors = torch.stack(out_c)
|
||||
else:
|
||||
mesh.vertices = out_v
|
||||
mesh.faces = out_f
|
||||
if out_c:
|
||||
mesh.vertex_colors = out_c
|
||||
else:
|
||||
c = mesh.vertex_colors if hasattr(mesh, 'vertex_colors') and mesh.vertex_colors is not None else None
|
||||
bar = comfy.utils.ProgressBar(1)
|
||||
v, f, c = process_single(mesh.vertices, mesh.faces, c, bar)
|
||||
mesh.vertices = v
|
||||
mesh.faces = f
|
||||
if c is not None:
|
||||
mesh.vertex_colors = c
|
||||
|
||||
return IO.NodeOutput(mesh)
|
||||
|
||||
|
||||
class DecimateMesh(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="DecimateMesh",
|
||||
display_name="Decimate Mesh",
|
||||
category="latent/3d",
|
||||
description="Simplifies a mesh to a target face count using QEM.",
|
||||
inputs=[
|
||||
IO.Mesh.Input("mesh"),
|
||||
IO.Int.Input("target_face_count", default=200_000, min=0, max=50_000_000,
|
||||
tooltip="Target maximum number of faces. Set to 0 to disable."),
|
||||
],
|
||||
outputs=[IO.Mesh.Output("mesh")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mesh, target_face_count):
|
||||
def _fn(v, f, c):
|
||||
if target_face_count > 0 and f.shape[0] > target_face_count:
|
||||
n = compute_vertex_normals(v, f)
|
||||
v, f, c, _ = simplify_fn_fast(v, f, colors=c, normals=n, target=target_face_count)
|
||||
return v, f, c
|
||||
return _process_mesh_batch(mesh, _fn)
|
||||
|
||||
|
||||
class FillHoles(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="FillHoles",
|
||||
display_name="Fill Holes",
|
||||
category="latent/3d",
|
||||
description="Fills holes in a mesh up to a maximum perimeter threshold.",
|
||||
inputs=[
|
||||
IO.Mesh.Input("mesh"),
|
||||
IO.Float.Input("max_perimeter", default=0.03, min=0.0, step=0.0001,
|
||||
tooltip="Maximum hole perimeter to fill. Set to 0 to disable."),
|
||||
],
|
||||
outputs=[IO.Mesh.Output("mesh")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mesh, max_perimeter):
|
||||
def _fn(v, f, c):
|
||||
if max_perimeter > 0:
|
||||
v, f = fill_holes_fn(v, f, max_perimeter=max_perimeter)
|
||||
return v, f, c
|
||||
return _process_mesh_batch(mesh, _fn)
|
||||
|
||||
class PostProcessMeshExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
FillHoles,
|
||||
DecimateMesh,
|
||||
PaintMesh
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> PostProcessMeshExtension:
|
||||
return PostProcessMeshExtension()
|
||||
@ -103,10 +103,8 @@ class MoGePanoramaInference(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGePanoramaInference",
|
||||
search_aliases=["moge", "panorama", "depth", "geometry", "depth estimation", "geometry estimation"],
|
||||
display_name="Run MoGe Panorama Inference",
|
||||
display_name="MoGe Panorama Inference",
|
||||
category="image/geometry_estimation",
|
||||
description="Run MoGe on an equirectangular panorama by splitting it into 12 perspective views, running inference on each, and merging the results into a single depth map.",
|
||||
inputs=[
|
||||
MoGeModelType.Input("moge_model"),
|
||||
io.Image.Input("image", tooltip="Equirectangular panorama (any aspect)."),
|
||||
@ -224,9 +222,7 @@ class MoGeInference(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGeInference",
|
||||
search_aliases=["moge", "depth", "geometry", "depth estimation", "geometry estimation"],
|
||||
display_name="Run MoGe Inference",
|
||||
description="Run MoGe on a single image to estimate depth and geometry.",
|
||||
display_name="MoGe Inference",
|
||||
category="image/geometry_estimation",
|
||||
inputs=[
|
||||
MoGeModelType.Input("moge_model"),
|
||||
@ -281,9 +277,7 @@ class MoGeRender(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGeRender",
|
||||
search_aliases=["moge", "render", "geometry", "depth", "normal"],
|
||||
display_name="Render MoGe Geometry",
|
||||
description="Render a depth map or normal map from geometry data",
|
||||
display_name="MoGe Render",
|
||||
category="image/geometry_estimation",
|
||||
inputs=[
|
||||
MoGeGeometry.Input("moge_geometry"),
|
||||
@ -348,9 +342,7 @@ class MoGePointMapToMesh(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGePointMapToMesh",
|
||||
search_aliases=["moge", "mesh", "geometry", "point map"],
|
||||
display_name="Convert MoGe Point Map to Mesh",
|
||||
description="Convert a MoGe point map into a 3D mesh.",
|
||||
display_name="MoGe Point Map to Mesh",
|
||||
category="image/geometry_estimation",
|
||||
inputs=[
|
||||
MoGeGeometry.Input("moge_geometry"),
|
||||
|
||||
@ -234,12 +234,6 @@ def save_glb(vertices, faces, filepath, metadata=None,
|
||||
textures = []
|
||||
samplers = []
|
||||
materials = []
|
||||
pbr = {
|
||||
"metallicFactor": 0.0,
|
||||
"roughnessFactor": 0.5,
|
||||
"baseColorFactor": [0.22, 0.22, 0.22, 1.0],
|
||||
}
|
||||
|
||||
if texture_png_bytes is not None and "TEXCOORD_0" in primitive_attributes:
|
||||
buffer_views.append({
|
||||
"buffer": 0,
|
||||
@ -249,13 +243,15 @@ def save_glb(vertices, faces, filepath, metadata=None,
|
||||
images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"})
|
||||
samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071})
|
||||
textures.append({"source": 0, "sampler": 0})
|
||||
pbr["baseColorTexture"] = {"index": 0, "texCoord": 0}
|
||||
|
||||
materials.append({
|
||||
"pbrMetallicRoughness": pbr,
|
||||
"doubleSided": True,
|
||||
})
|
||||
primitive["material"] = 0
|
||||
materials.append({
|
||||
"pbrMetallicRoughness": {
|
||||
"baseColorTexture": {"index": 0, "texCoord": 0},
|
||||
"metallicFactor": 0.0,
|
||||
"roughnessFactor": 1.0,
|
||||
},
|
||||
"doubleSided": True,
|
||||
})
|
||||
primitive["material"] = 0
|
||||
|
||||
gltf = {
|
||||
"asset": {"version": "2.0", "generator": "ComfyUI"},
|
||||
@ -377,14 +373,10 @@ class SaveGLB(IO.ComfyNode):
|
||||
continue
|
||||
tex_img = Image.fromarray(texture_np[i], mode="RGB") if texture_np is not None else None
|
||||
f = f"{filename}_{counter:05}_.glb"
|
||||
save_glb(
|
||||
vertices_i, faces_i,
|
||||
os.path.join(full_output_folder, f),
|
||||
metadata,
|
||||
uvs=uvs_i,
|
||||
vertex_colors=v_colors,
|
||||
texture_image=tex_img,
|
||||
)
|
||||
save_glb(vertices_i, faces_i, os.path.join(full_output_folder, f), metadata,
|
||||
uvs=uvs_i,
|
||||
vertex_colors=v_colors,
|
||||
texture_image=tex_img)
|
||||
results.append({
|
||||
"filename": f,
|
||||
"subfolder": subfolder,
|
||||
|
||||
@ -1,667 +0,0 @@
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO, Types, io
|
||||
from comfy.ldm.trellis2.vae import SparseTensor
|
||||
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
|
||||
import comfy.model_management
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES")
|
||||
|
||||
def prepare_trellis_vae_for_decode(vae, sample_shape):
|
||||
memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
|
||||
if len(sample_shape) == 5:
|
||||
memory_required *= max(1, int(sample_shape[4]))
|
||||
memory_required = max(1, int(memory_required))
|
||||
device = comfy.model_management.get_torch_device()
|
||||
comfy.model_management.load_models_gpu(
|
||||
[vae.patcher],
|
||||
memory_required=memory_required,
|
||||
force_full_load=getattr(vae, "disable_offload", False),
|
||||
)
|
||||
free_memory = vae.patcher.get_free_memory(device)
|
||||
batch_number = max(1, int(free_memory / memory_required))
|
||||
return batch_number
|
||||
|
||||
shape_slat_normalization = {
|
||||
"mean": torch.tensor([
|
||||
0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
|
||||
-0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944,
|
||||
0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667,
|
||||
-0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665
|
||||
])[None],
|
||||
"std": torch.tensor([
|
||||
5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004,
|
||||
5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578,
|
||||
4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194,
|
||||
5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691
|
||||
])[None]
|
||||
}
|
||||
|
||||
tex_slat_normalization = {
|
||||
"mean": torch.tensor([
|
||||
3.501659, 2.212398, 2.226094, 0.251093, -0.026248, -0.687364, 0.439898, -0.928075,
|
||||
0.029398, -0.339596, -0.869527, 1.038479, -0.972385, 0.126042, -1.129303, 0.455149,
|
||||
-1.209521, 2.069067, 0.544735, 2.569128, -0.323407, 2.293000, -1.925608, -1.217717,
|
||||
1.213905, 0.971588, -0.023631, 0.106750, 2.021786, 0.250524, -0.662387, -0.768862
|
||||
])[None],
|
||||
"std": torch.tensor([
|
||||
2.665652, 2.743913, 2.765121, 2.595319, 3.037293, 2.291316, 2.144656, 2.911822,
|
||||
2.969419, 2.501689, 2.154811, 3.163343, 2.621215, 2.381943, 3.186697, 3.021588,
|
||||
2.295916, 3.234985, 3.233086, 2.260140, 2.874801, 2.810596, 3.292720, 2.674999,
|
||||
2.680878, 2.372054, 2.451546, 2.353556, 2.995195, 2.379849, 2.786195, 2.775190
|
||||
])[None]
|
||||
}
|
||||
|
||||
def shape_norm(shape_latent, coords):
|
||||
std = shape_slat_normalization["std"].to(shape_latent)
|
||||
mean = shape_slat_normalization["mean"].to(shape_latent)
|
||||
samples = SparseTensor(feats = shape_latent, coords=coords)
|
||||
samples = samples * std + mean
|
||||
return samples
|
||||
|
||||
|
||||
def infer_batched_coord_layout(coords):
|
||||
if coords.ndim != 2 or coords.shape[1] != 4:
|
||||
raise ValueError(f"Expected Trellis2 coords with shape [N, 4], got {tuple(coords.shape)}")
|
||||
|
||||
if coords.shape[0] == 0:
|
||||
raise ValueError("Trellis2 coords can't be empty")
|
||||
|
||||
batch_ids = coords[:, 0].to(torch.int64)
|
||||
if (batch_ids < 0).any():
|
||||
raise ValueError(f"Trellis2 batch ids must be non-negative, got {batch_ids.unique(sorted=True).tolist()}")
|
||||
batch_size = int(batch_ids.max().item()) + 1
|
||||
counts = torch.bincount(batch_ids, minlength=batch_size)
|
||||
|
||||
if (counts == 0).any():
|
||||
raise ValueError(f"Non-contiguous Trellis2 batch ids in coords: {batch_ids.unique(sorted=True).tolist()}")
|
||||
|
||||
max_tokens = int(counts.max().item())
|
||||
return batch_size, counts, max_tokens
|
||||
|
||||
|
||||
def split_batched_coords(coords, coord_counts):
|
||||
if coord_counts.ndim != 1:
|
||||
raise ValueError(f"Trellis2 coord_counts must be 1D, got shape {tuple(coord_counts.shape)}")
|
||||
if (coord_counts < 0).any():
|
||||
raise ValueError(f"Trellis2 coord_counts must be non-negative, got {coord_counts.tolist()}")
|
||||
if int(coord_counts.sum().item()) != coords.shape[0]:
|
||||
raise ValueError(
|
||||
f"Trellis2 coord_counts total {int(coord_counts.sum().item())} does not match coords rows {coords.shape[0]}"
|
||||
)
|
||||
|
||||
batch_ids = coords[:, 0].to(torch.int64)
|
||||
order = torch.argsort(batch_ids, stable=True)
|
||||
sorted_coords = coords.index_select(0, order)
|
||||
sorted_batch_ids = batch_ids.index_select(0, order)
|
||||
|
||||
offsets = coord_counts.cumsum(0) - coord_counts
|
||||
items = []
|
||||
for i in range(coord_counts.shape[0]):
|
||||
count = int(coord_counts[i].item())
|
||||
start = int(offsets[i].item())
|
||||
coords_i = sorted_coords[start:start + count]
|
||||
ids_i = sorted_batch_ids[start:start + count]
|
||||
if coords_i.shape[0] != count or not torch.all(ids_i == i):
|
||||
raise ValueError(f"Trellis2 coords rows for batch {i} expected {count}, got {coords_i.shape[0]}")
|
||||
items.append(coords_i)
|
||||
return items
|
||||
|
||||
def flatten_batched_sparse_latent(samples, coords, coord_counts):
|
||||
samples = samples.squeeze(-1).transpose(1, 2)
|
||||
if coord_counts is None:
|
||||
return samples.reshape(-1, samples.shape[-1]), coords
|
||||
|
||||
coords_items = split_batched_coords(coords, coord_counts)
|
||||
feat_list = []
|
||||
coord_list = []
|
||||
for i, coords_i in enumerate(coords_items):
|
||||
count = int(coord_counts[i].item())
|
||||
feat_list.append(samples[i, :count])
|
||||
coord_list.append(coords_i)
|
||||
|
||||
return torch.cat(feat_list, dim=0), torch.cat(coord_list, dim=0)
|
||||
|
||||
|
||||
def split_batched_sparse_latent(samples, coords, coord_counts):
|
||||
samples = samples.squeeze(-1).transpose(1, 2)
|
||||
if coord_counts is None:
|
||||
return [(samples.reshape(-1, samples.shape[-1]), coords)]
|
||||
|
||||
coords_items = split_batched_coords(coords, coord_counts)
|
||||
items = []
|
||||
for i, coords_i in enumerate(coords_items):
|
||||
count = int(coord_counts[i].item())
|
||||
items.append((samples[i, :count], coords_i))
|
||||
return items
|
||||
|
||||
class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VaeDecodeShapeTrellis",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Mesh.Output("mesh"),
|
||||
ShapeSubdivides.Output(display_name = "shape_subdivides"),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, vae):
|
||||
|
||||
resolution = int(vae.first_stage_model.resolution.item())
|
||||
sample_tensor = samples["samples"]
|
||||
device = comfy.model_management.get_torch_device()
|
||||
coords = samples["coords"]
|
||||
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
||||
trellis_vae = vae.first_stage_model
|
||||
coord_counts = samples.get("coord_counts")
|
||||
|
||||
samples = samples["samples"]
|
||||
if coord_counts is None:
|
||||
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
|
||||
samples = shape_norm(samples.to(device), coords.to(device))
|
||||
mesh, subs = trellis_vae.decode_shape_slat(samples, resolution)
|
||||
else:
|
||||
split_items = split_batched_sparse_latent(samples, coords, coord_counts)
|
||||
mesh = []
|
||||
subs_per_sample = []
|
||||
for feats_i, coords_i in split_items:
|
||||
coords_i = coords_i.to(device).clone()
|
||||
coords_i[:, 0] = 0
|
||||
sample_i = shape_norm(feats_i.to(device), coords_i)
|
||||
mesh_i, subs_i = trellis_vae.decode_shape_slat(sample_i, resolution)
|
||||
mesh.append(mesh_i[0])
|
||||
subs_per_sample.append(subs_i)
|
||||
|
||||
subs = []
|
||||
for stage_index in range(len(subs_per_sample[0])):
|
||||
stage_tensors = [sample_subs[stage_index] for sample_subs in subs_per_sample]
|
||||
feats_list = [stage_tensor.feats for stage_tensor in stage_tensors]
|
||||
coords_list = [stage_tensor.coords for stage_tensor in stage_tensors]
|
||||
subs.append(SparseTensor.from_tensor_list(feats_list, coords_list))
|
||||
|
||||
face_list = [m.faces for m in mesh]
|
||||
vert_list = [m.vertices for m in mesh]
|
||||
if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list):
|
||||
mesh = Types.MESH(vertices=torch.stack(vert_list), faces=torch.stack(face_list))
|
||||
else:
|
||||
mesh = pack_variable_mesh_batch(vert_list, face_list)
|
||||
return IO.NodeOutput(mesh, subs)
|
||||
|
||||
class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VaeDecodeTextureTrellis",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
ShapeSubdivides.Input("shape_subdivides",
|
||||
tooltip=(
|
||||
"Shape information used to guide higher-detail reconstruction during decoding. "
|
||||
"Helps preserve structure consistency at higher resolutions."
|
||||
)),
|
||||
],
|
||||
outputs=[
|
||||
IO.Voxel.Output("voxel_colors"),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, vae, shape_subdivides):
|
||||
sample_tensor = samples["samples"]
|
||||
device = comfy.model_management.get_torch_device()
|
||||
coords = samples["coords"]
|
||||
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
||||
trellis_vae = vae.first_stage_model
|
||||
coord_counts = samples.get("coord_counts")
|
||||
|
||||
samples = samples["samples"]
|
||||
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
|
||||
samples = samples.to(device)
|
||||
std = tex_slat_normalization["std"].to(samples)
|
||||
mean = tex_slat_normalization["mean"].to(samples)
|
||||
samples = SparseTensor(feats = samples, coords=coords.to(device))
|
||||
samples = samples * std + mean
|
||||
|
||||
voxel = trellis_vae.decode_tex_slat(samples, shape_subdivides)
|
||||
color_feats = voxel.feats[:, :3]
|
||||
voxel_coords = voxel.coords#[:, 1:]
|
||||
|
||||
voxel = Types.VOXEL(voxel_coords, color_feats, 1024)
|
||||
return IO.NodeOutput(voxel)
|
||||
|
||||
class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VaeDecodeStructureTrellis2",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.Combo.Input("resolution", options=["32", "64"], default="32")
|
||||
],
|
||||
outputs=[
|
||||
IO.Voxel.Output("voxel"),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, vae, resolution):
|
||||
resolution = int(resolution)
|
||||
sample_tensor = samples["samples"]
|
||||
sample_tensor = sample_tensor[:, :8]
|
||||
batch_number = prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
||||
decoder = vae.first_stage_model.struct_dec
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
decoded_batches = []
|
||||
for start in range(0, sample_tensor.shape[0], batch_number):
|
||||
sample_chunk = sample_tensor[start:start + batch_number].to(load_device)
|
||||
decoded_batches.append(decoder(sample_chunk) > 0)
|
||||
decoded = torch.cat(decoded_batches, dim=0)
|
||||
current_res = decoded.shape[2]
|
||||
|
||||
if current_res != resolution:
|
||||
ratio = current_res // resolution
|
||||
decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5
|
||||
out = Types.VOXEL(decoded.squeeze(1).float())
|
||||
return IO.NodeOutput(out)
|
||||
|
||||
class Trellis2UpsampleCascade(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Trellis2UpsampleCascade",
|
||||
category="latent/3d",
|
||||
display_name="Trellis2 Upsample Cascade",
|
||||
description="Upsamples low-resolution Trellis2 shape latents into higher resolution coordinates while respecting the maximum token budget.",
|
||||
inputs=[
|
||||
IO.Latent.Input("shape_latent"),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.Combo.Input("target_resolution", options=["1024", "1536"], default="1024", tooltip="Controls output detail level for upsampling."),
|
||||
IO.Int.Input("max_tokens", default=49152, min=1024, max=100000,
|
||||
tooltip=(
|
||||
"Maximum number of output elements (coordinates) allowed after upsampling. "
|
||||
"Used to limit memory usage and control mesh density."
|
||||
))
|
||||
],
|
||||
outputs=[
|
||||
IO.Voxel.Output(
|
||||
"high_res_voxel",
|
||||
tooltip=(
|
||||
"High-resolution sparse coordinates produced after cascade upsampling. "
|
||||
"Represents the refined 3D structure at target resolution."
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, shape_latent, vae, target_resolution, max_tokens):
|
||||
shape_latent_512 = shape_latent
|
||||
device = comfy.model_management.get_torch_device()
|
||||
prepare_trellis_vae_for_decode(vae, shape_latent_512["samples"].shape)
|
||||
|
||||
coord_counts = shape_latent_512.get("coord_counts")
|
||||
decoder = vae.first_stage_model.shape_dec
|
||||
lr_resolution = 512
|
||||
target_resolution = int(target_resolution)
|
||||
|
||||
if coord_counts is None:
|
||||
feats, coords_512 = flatten_batched_sparse_latent(
|
||||
shape_latent_512["samples"],
|
||||
shape_latent_512["coords"],
|
||||
coord_counts,
|
||||
)
|
||||
feats = feats.to(device)
|
||||
coords_512 = coords_512.to(device)
|
||||
slat = shape_norm(feats, coords_512)
|
||||
slat.feats = slat.feats.to(next(decoder.parameters()).dtype)
|
||||
hr_coords = decoder.upsample(slat, upsample_times=4)
|
||||
|
||||
hr_resolution = target_resolution
|
||||
while True:
|
||||
quant_coords = torch.cat([
|
||||
hr_coords[:, :1],
|
||||
((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(),
|
||||
], dim=1)
|
||||
final_coords = quant_coords.unique(dim=0)
|
||||
num_tokens = final_coords.shape[0]
|
||||
|
||||
if num_tokens < max_tokens or hr_resolution <= 1024:
|
||||
break
|
||||
hr_resolution -= 128
|
||||
|
||||
return IO.NodeOutput(final_coords,)
|
||||
|
||||
items = split_batched_sparse_latent(
|
||||
shape_latent_512["samples"],
|
||||
shape_latent_512["coords"],
|
||||
coord_counts,
|
||||
)
|
||||
decoder_dtype = next(decoder.parameters()).dtype
|
||||
|
||||
sample_hr_coords = []
|
||||
for feats_i, coords_i in items:
|
||||
feats_i = feats_i.to(device)
|
||||
coords_i = coords_i.to(device).clone()
|
||||
coords_i[:, 0] = 0
|
||||
slat_i = shape_norm(feats_i, coords_i)
|
||||
slat_i.feats = slat_i.feats.to(decoder_dtype)
|
||||
sample_hr_coords.append(decoder.upsample(slat_i, upsample_times=4))
|
||||
|
||||
hr_resolution = target_resolution
|
||||
while True:
|
||||
exceeds_limit = False
|
||||
for hr_coords_i in sample_hr_coords:
|
||||
quant_coords_i = torch.cat([
|
||||
hr_coords_i[:, :1],
|
||||
((hr_coords_i[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(),
|
||||
], dim=1)
|
||||
if quant_coords_i.unique(dim=0).shape[0] >= max_tokens:
|
||||
exceeds_limit = True
|
||||
break
|
||||
if not exceeds_limit or hr_resolution <= 1024:
|
||||
break
|
||||
hr_resolution -= 128
|
||||
|
||||
final_coords_list = []
|
||||
output_coord_counts = []
|
||||
for sample_offset, hr_coords_i in enumerate(sample_hr_coords):
|
||||
quant_coords_i = torch.cat([
|
||||
hr_coords_i[:, :1],
|
||||
((hr_coords_i[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(),
|
||||
], dim=1)
|
||||
final_coords_i = quant_coords_i.unique(dim=0)
|
||||
final_coords_i = final_coords_i.clone()
|
||||
final_coords_i[:, 0] = sample_offset
|
||||
final_coords_list.append(final_coords_i)
|
||||
output_coord_counts.append(int(final_coords_i.shape[0]))
|
||||
|
||||
coords = torch.cat(final_coords_list, dim=0)
|
||||
output = Types.VOXEL(coords)
|
||||
output.coord_counts = torch.tensor(output_coord_counts, dtype=torch.int64)
|
||||
output.resolutions = torch.full((len(final_coords_list),), int(hr_resolution), dtype=torch.int64)
|
||||
output.upsampled = True
|
||||
|
||||
return IO.NodeOutput(output,)
|
||||
|
||||
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
|
||||
dino_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
|
||||
|
||||
def run_conditioning(model, cropped_img_tensor, include_1024=True):
|
||||
model_internal = model.model
|
||||
device = comfy.model_management.intermediate_device()
|
||||
torch_device = comfy.model_management.get_torch_device()
|
||||
|
||||
def prepare_tensor(pil_img, size):
|
||||
resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS)
|
||||
img_np = np.array(resized_pil).astype(np.float32) / 255.0
|
||||
img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device)
|
||||
return (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device)
|
||||
|
||||
model_internal.image_size = 512
|
||||
input_512 = prepare_tensor(cropped_img_tensor, 512)
|
||||
cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0]
|
||||
|
||||
cond_1024 = None
|
||||
if include_1024:
|
||||
model_internal.image_size = 1024
|
||||
input_1024 = prepare_tensor(cropped_img_tensor, 1024)
|
||||
cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0]
|
||||
|
||||
conditioning = {
|
||||
'cond_512': cond_512.to(device),
|
||||
'neg_cond': torch.zeros_like(cond_512).to(device),
|
||||
}
|
||||
if cond_1024 is not None:
|
||||
conditioning['cond_1024'] = cond_1024.to(device)
|
||||
|
||||
return conditioning
|
||||
class Trellis2Conditioning(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Trellis2Conditioning",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
IO.ClipVision.Input("clip_vision_model"),
|
||||
IO.Image.Input("image"),
|
||||
IO.Mask.Input("mask"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Conditioning.Output(display_name="positive"),
|
||||
IO.Conditioning.Output(display_name="negative"),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip_vision_model, image, mask) -> IO.NodeOutput:
|
||||
# Normalize to batched form so per-image conditioning loop below is uniform.
|
||||
if image.ndim == 3:
|
||||
image = image.unsqueeze(0)
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
batch_size = image.shape[0]
|
||||
if mask.shape[0] == 1 and batch_size > 1:
|
||||
mask = mask.expand(batch_size, -1, -1)
|
||||
elif mask.shape[0] != batch_size:
|
||||
raise ValueError(f"Trellis2Conditioning mask batch {mask.shape[0]} does not match image batch {batch_size}")
|
||||
|
||||
cond_512_list = []
|
||||
cond_1024_list = []
|
||||
|
||||
for b in range(batch_size):
|
||||
item_image = image[b]
|
||||
item_mask = mask[b] if mask.size(0) > 1 else mask[0]
|
||||
|
||||
img_np = (item_image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
||||
mask_np = (item_mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
||||
|
||||
pil_img = Image.fromarray(img_np)
|
||||
pil_mask = Image.fromarray(mask_np)
|
||||
|
||||
max_size = max(pil_img.size)
|
||||
scale = min(1.0, 1024 / max_size)
|
||||
if scale < 1.0:
|
||||
new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale)
|
||||
pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
||||
pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST)
|
||||
|
||||
rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8)
|
||||
rgba_np[:, :, :3] = np.array(pil_img)
|
||||
rgba_np[:, :, 3] = np.array(pil_mask)
|
||||
|
||||
alpha = rgba_np[:, :, 3]
|
||||
bbox_coords = np.argwhere(alpha > 0.8 * 255)
|
||||
|
||||
if len(bbox_coords) > 0:
|
||||
y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1])
|
||||
y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1])
|
||||
|
||||
center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0
|
||||
size = max(y_max - y_min, x_max - x_min)
|
||||
|
||||
crop_x1 = int(center_x - size // 2)
|
||||
crop_y1 = int(center_y - size // 2)
|
||||
crop_x2 = int(center_x + size // 2)
|
||||
crop_y2 = int(center_y + size // 2)
|
||||
|
||||
rgba_pil = Image.fromarray(rgba_np)
|
||||
cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2))
|
||||
cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0
|
||||
else:
|
||||
import logging
|
||||
logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.")
|
||||
cropped_np = rgba_np.astype(np.float32) / 255.0
|
||||
|
||||
bg_rgb = np.array([0.0, 0.0, 0.0], dtype=np.float32)
|
||||
|
||||
fg = cropped_np[:, :, :3]
|
||||
alpha_float = cropped_np[:, :, 3:4]
|
||||
composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float)
|
||||
|
||||
# to match trellis2 code (quantize -> dequantize)
|
||||
composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8)
|
||||
|
||||
cropped_pil = Image.fromarray(composite_uint8)
|
||||
|
||||
item_conditioning = run_conditioning(clip_vision_model, cropped_pil, include_1024=True)
|
||||
cond_512_list.append(item_conditioning["cond_512"])
|
||||
cond_1024_list.append(item_conditioning["cond_1024"])
|
||||
|
||||
cond_512_batched = torch.cat(cond_512_list, dim=0)
|
||||
cond_1024_batched = torch.cat(cond_1024_list, dim=0)
|
||||
neg_cond_batched = torch.zeros_like(cond_512_batched)
|
||||
neg_embeds_batched = torch.zeros_like(cond_1024_batched)
|
||||
|
||||
positive = [[cond_512_batched, {"embeds": cond_1024_batched}]]
|
||||
negative = [[neg_cond_batched, {"embeds": neg_embeds_batched}]]
|
||||
return IO.NodeOutput(positive, negative)
|
||||
|
||||
class EmptyTrellis2ShapeLatent(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="EmptyTrellis2ShapeLatent",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Voxel.Input(
|
||||
"voxel",
|
||||
tooltip=(
|
||||
"Shape structure input. Accepts either a voxel structure "
|
||||
"or upsampled voxel coordinates from a previous cascade stage."
|
||||
)
|
||||
)
|
||||
],
|
||||
outputs=[
|
||||
IO.Latent.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, voxel):
|
||||
# to accept the upscaled coords
|
||||
is_512_pass = False
|
||||
upsampled = hasattr(voxel, "upsampled")
|
||||
if upsampled:
|
||||
voxel = voxel.data
|
||||
|
||||
if not upsampled:
|
||||
decoded = voxel.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
is_512_pass = True
|
||||
|
||||
else:
|
||||
coords = voxel.int()
|
||||
is_512_pass = False
|
||||
|
||||
batch_size, counts, max_tokens = infer_batched_coord_layout(coords)
|
||||
in_channels = 32
|
||||
# image like format
|
||||
latent = torch.zeros(batch_size, in_channels, max_tokens, 1)
|
||||
|
||||
if is_512_pass:
|
||||
generation_mode = "shape_generation_512"
|
||||
else:
|
||||
generation_mode = "shape_generation"
|
||||
return IO.NodeOutput({"samples": latent, "coords": coords, "coord_counts": counts, "type": "trellis2",
|
||||
"model_options": {"generation_mode": generation_mode, "coords": coords, "coord_counts": counts}})
|
||||
|
||||
class EmptyTrellis2LatentTexture(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="EmptyTrellis2LatentTexture",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Voxel.Input(
|
||||
"voxel",
|
||||
tooltip=(
|
||||
"Shape structure input. Accepts either a voxel structure "
|
||||
"or upsampled voxel coordinates from a previous cascade stage."
|
||||
)
|
||||
),
|
||||
IO.Latent.Input("shape_latent"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Latent.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, voxel, shape_latent):
|
||||
channels = 32
|
||||
upsampled = hasattr(voxel, "upsampled")
|
||||
if upsampled:
|
||||
voxel = voxel.data
|
||||
|
||||
if not upsampled:
|
||||
decoded = voxel.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
else:
|
||||
coords = voxel.int()
|
||||
|
||||
batch_size, counts, max_tokens = infer_batched_coord_layout(coords)
|
||||
|
||||
shape_latent = shape_latent["samples"]
|
||||
if shape_latent.ndim == 4:
|
||||
shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels)
|
||||
|
||||
latent = torch.zeros(batch_size, channels, max_tokens, 1)
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2", "coords": coords, "coord_counts": counts,
|
||||
"model_options": {"generation_mode": "texture_generation",
|
||||
"coords": coords, "coord_counts": counts, "shape_slat": shape_latent}})
|
||||
|
||||
|
||||
class EmptyTrellis2LatentStructure(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="EmptyTrellis2LatentStructure",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
|
||||
],
|
||||
outputs=[
|
||||
IO.Latent.Output(),
|
||||
]
|
||||
)
|
||||
@classmethod
|
||||
def execute(cls, batch_size):
|
||||
in_channels = 8
|
||||
resolution = 16
|
||||
latent = torch.zeros(batch_size, in_channels, resolution, resolution, resolution)
|
||||
output = {
|
||||
"samples": latent,
|
||||
"type": "trellis2",
|
||||
}
|
||||
return IO.NodeOutput(output)
|
||||
|
||||
class Trellis2Extension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
Trellis2Conditioning,
|
||||
EmptyTrellis2ShapeLatent,
|
||||
EmptyTrellis2LatentStructure,
|
||||
EmptyTrellis2LatentTexture,
|
||||
VaeDecodeTextureTrellis,
|
||||
VaeDecodeShapeTrellis,
|
||||
VaeDecodeStructureTrellis2,
|
||||
Trellis2UpsampleCascade,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> Trellis2Extension:
|
||||
return Trellis2Extension()
|
||||
@ -60,7 +60,7 @@ folder_names_and_paths["geometry_estimation"] = ([os.path.join(models_dir, "geom
|
||||
|
||||
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["detection"] = ([os.path.join(models_dir, "detection")], supported_pt_extensions)
|
||||
folder_names_and_paths["mediapipe"] = ([os.path.join(models_dir, "mediapipe")], supported_pt_extensions)
|
||||
|
||||
output_directory = os.path.join(base_path, "output")
|
||||
temp_directory = os.path.join(base_path, "temp")
|
||||
|
||||
6
nodes.py
6
nodes.py
@ -1537,10 +1537,6 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
||||
if "noise_mask" in latent:
|
||||
noise_mask = latent["noise_mask"]
|
||||
|
||||
if "model_options" in latent:
|
||||
inner = model.model.diffusion_model
|
||||
inner.meta = latent["model_options"]
|
||||
|
||||
callback = latent_preview.prepare_callback(model, steps)
|
||||
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
|
||||
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
|
||||
@ -2434,8 +2430,6 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_toolkit.py",
|
||||
"nodes_replacements.py",
|
||||
"nodes_nag.py",
|
||||
"nodes_trellis2.py",
|
||||
"nodes_mesh_postprocess.py",
|
||||
"nodes_sdpose.py",
|
||||
"nodes_math.py",
|
||||
"nodes_number_convert.py",
|
||||
|
||||
673
openapi.yaml
673
openapi.yaml
@ -1556,6 +1556,12 @@ paths:
|
||||
type: string
|
||||
enum: [asc, desc]
|
||||
description: Sort direction
|
||||
- name: job_ids
|
||||
in: query
|
||||
schema:
|
||||
type: string
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Comma-separated UUIDs to filter assets by associated job."
|
||||
- name: include_public
|
||||
in: query
|
||||
schema:
|
||||
@ -2508,25 +2514,37 @@ paths:
|
||||
|
||||
/api/assets/import:
|
||||
post:
|
||||
operationId: importPublishedAssets
|
||||
operationId: importAssets
|
||||
tags: [assets]
|
||||
summary: "[cloud-only] Import published assets into the caller's library"
|
||||
description: |
|
||||
[cloud-only] Imports the specified published assets into the caller's asset library. New DB records reference the same storage objects; no file copying occurs. Assets the caller already owns (by hash) are deduplicated. The `id` field on each returned `AssetInfo` is the caller's newly-created private asset ID, not the published asset ID supplied in the request.
|
||||
summary: Import assets from external URLs
|
||||
description: "[cloud-only] Imports one or more assets from external URLs into the cloud asset store."
|
||||
x-runtime: [cloud]
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ImportPublishedAssetsRequest"
|
||||
type: object
|
||||
required:
|
||||
- imports
|
||||
properties:
|
||||
imports:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/AssetImportRequest"
|
||||
description: Assets to import
|
||||
responses:
|
||||
"200":
|
||||
description: Successfully imported assets
|
||||
description: Import initiated
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ImportPublishedAssetsResponse"
|
||||
type: object
|
||||
properties:
|
||||
assets:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/Asset"
|
||||
"400":
|
||||
description: Bad request
|
||||
content:
|
||||
@ -3772,295 +3790,6 @@ paths:
|
||||
schema:
|
||||
$ref: "#/components/schemas/JwksResponse"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OAuth 2.1 / RFC 7591 Dynamic Client Registration (cloud)
|
||||
# ---------------------------------------------------------------------------
|
||||
/.well-known/oauth-authorization-server:
|
||||
get:
|
||||
operationId: getOAuthAuthorizationServer
|
||||
tags: [auth]
|
||||
summary: "[cloud-only] OAuth 2.1 authorization-server metadata (RFC 8414)"
|
||||
description: "[cloud-only] Public metadata document for OAuth 2.1 clients. Cached 5 minutes."
|
||||
x-runtime: [cloud]
|
||||
security: []
|
||||
responses:
|
||||
"200":
|
||||
description: Authorization-server metadata
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthAuthorizationServerMetadata"
|
||||
"404":
|
||||
description: OAuth disabled
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
|
||||
/.well-known/oauth-protected-resource:
|
||||
get:
|
||||
operationId: getOAuthProtectedResource
|
||||
tags: [auth]
|
||||
summary: "[cloud-only] OAuth 2.1 protected-resource metadata (RFC 9728)"
|
||||
description: "[cloud-only] Public metadata describing the currently advertised protected resource. Cached 5 minutes."
|
||||
x-runtime: [cloud]
|
||||
security: []
|
||||
responses:
|
||||
"200":
|
||||
description: Protected-resource metadata
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthProtectedResourceMetadata"
|
||||
"404":
|
||||
description: OAuth disabled or no active resource configured
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
|
||||
/oauth/authorize:
|
||||
get:
|
||||
operationId: getOAuthAuthorize
|
||||
tags: [auth]
|
||||
summary: "[cloud-only] Begin or resume an OAuth 2.1 authorization request"
|
||||
description: |
|
||||
[cloud-only] Two modes:
|
||||
- **Initial entry** (OAuth params present): validates client/redirect/resource/scopes, persists a server-side authorization-request row, and either redirects (no session / unverified email) to the configured frontend login URL carrying only the opaque `oauth_request_id`, or returns the JSON consent challenge for the frontend to render.
|
||||
- **Resume** (`oauth_request_id` present): loads the server-side row, fails closed if expired/consumed/unknown, returns the JSON consent challenge. Browser-replayed OAuth params are intentionally ignored.
|
||||
|
||||
The frontend renders the consent UI from the JSON payload and POSTs the user's decision back to this endpoint.
|
||||
x-runtime: [cloud]
|
||||
security: []
|
||||
parameters:
|
||||
- { name: response_type, in: query, required: false, schema: { type: string } }
|
||||
- { name: client_id, in: query, required: false, schema: { type: string } }
|
||||
- { name: redirect_uri, in: query, required: false, schema: { type: string } }
|
||||
- { name: scope, in: query, required: false, schema: { type: string } }
|
||||
- name: state
|
||||
in: query
|
||||
required: false
|
||||
schema: { type: string }
|
||||
description: |
|
||||
RFC 6749 §10.12 marks `state` as RECOMMENDED. Cloud hardening makes it REQUIRED on the initial-entry path (omitted only on the resume path where `oauth_request_id` is supplied instead). This parameter is `required: false` at the spec level only because the operation is dual-mode (initial entry vs. resume); the runtime rejects empty `state` on the initial-entry path with a stable `invalid_request` 400.
|
||||
- { name: code_challenge, in: query, required: false, schema: { type: string } }
|
||||
- { name: code_challenge_method, in: query, required: false, schema: { type: string } }
|
||||
- { name: resource, in: query, required: false, schema: { type: string } }
|
||||
- { name: oauth_request_id, in: query, required: false, schema: { type: string } }
|
||||
responses:
|
||||
"200":
|
||||
description: Consent challenge payload (session present, email verified). Frontend renders the consent UI from this payload and POSTs back to /oauth/authorize.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthConsentChallenge"
|
||||
"302":
|
||||
description: Redirect to login (no session / unverified email) or to registered redirect_uri (pre-validated client error)
|
||||
headers:
|
||||
Location:
|
||||
schema:
|
||||
type: string
|
||||
"400":
|
||||
description: Invalid authorize request (pre-redirect failure — unknown client, redirect mismatch, malformed params)
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
"404":
|
||||
description: OAuth disabled
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
post:
|
||||
operationId: postOAuthAuthorize
|
||||
tags: [auth]
|
||||
summary: "[cloud-only] Submit OAuth consent decision"
|
||||
description: |
|
||||
[cloud-only] JSON-only consent submission. The handler verifies the per-row CSRF token, atomically marks the authorization request consumed (single-use covers both allow and deny paths), then returns the redirect URL the browser must navigate to. The URL contains either `code` + original `state` for allow, or the RFC 6749 §5.2 error and `state` for deny.
|
||||
|
||||
Workspace membership is re-checked at submission time. Consent is persisted keyed by `(user_id, client_id, resource_id, workspace_id)`; broadening the previously approved scope set requires a fresh consent flow.
|
||||
x-runtime: [cloud]
|
||||
security: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
required: [oauth_request_id, csrf_token, decision, workspace_id]
|
||||
properties:
|
||||
oauth_request_id: { type: string, format: uuid }
|
||||
csrf_token: { type: string }
|
||||
decision: { type: string, enum: [allow, deny] }
|
||||
workspace_id: { type: string }
|
||||
responses:
|
||||
"200":
|
||||
description: Redirect URL for the frontend to navigate to (allow → with code+state; deny → with error+state)
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthAuthorizeRedirectResponse"
|
||||
"400":
|
||||
description: Bad request (CSRF mismatch, expired/consumed request, inaccessible workspace)
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
"403":
|
||||
description: Scope broadening on consent re-grant — fresh consent flow required
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
"404":
|
||||
description: OAuth disabled
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
|
||||
/oauth/token:
|
||||
post:
|
||||
operationId: postOAuthToken
|
||||
tags: [auth]
|
||||
summary: "[cloud-only] Exchange authorization code or refresh token for a resource-bound access token"
|
||||
description: |
|
||||
[cloud-only] OAuth 2.1 token endpoint (RFC 6749 §3.2). Public clients only — `client_secret` is rejected.
|
||||
|
||||
Two grant types are supported:
|
||||
- `authorization_code` — exchanges the code minted by `/oauth/authorize` (with PKCE verifier) for an access token + first refresh token. Single-use; reuse fails closed.
|
||||
- `refresh_token` — rotates the refresh token. Old token immediately invalid; presenting an already-rotated token revokes the entire token family and emits a security metric.
|
||||
|
||||
Both grant types re-validate canonical user state, current workspace membership, and the resource's active flag at every mint. A code or refresh token bound to a deactivated resource fails closed.
|
||||
|
||||
Errors follow RFC 6749 §5.2. Logs never contain raw codes, refresh tokens, or minted tokens.
|
||||
|
||||
Per RFC 6749 §5.1, every 200 and 400 response carries `Cache-Control: no-store` and `Pragma: no-cache` so intermediaries cannot cache token-bearing or state-change-reason responses.
|
||||
x-runtime: [cloud]
|
||||
security: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/x-www-form-urlencoded:
|
||||
schema:
|
||||
type: object
|
||||
required: [grant_type, client_id]
|
||||
properties:
|
||||
grant_type: { type: string, enum: [authorization_code, refresh_token] }
|
||||
client_id: { type: string }
|
||||
code: { type: string }
|
||||
redirect_uri: { type: string }
|
||||
code_verifier: { type: string }
|
||||
refresh_token: { type: string }
|
||||
scope: { type: string }
|
||||
client_secret: { type: string }
|
||||
responses:
|
||||
"200":
|
||||
description: New token pair
|
||||
headers:
|
||||
Cache-Control:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-store" per RFC 6749 §5.1'
|
||||
Pragma:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-cache" per RFC 6749 §5.1'
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthTokenResponse"
|
||||
"400":
|
||||
description: RFC 6749 §5.2 error
|
||||
headers:
|
||||
Cache-Control:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-store" per RFC 6749 §5.1'
|
||||
Pragma:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-cache" per RFC 6749 §5.1'
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthTokenError"
|
||||
"404":
|
||||
description: OAuth disabled
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
|
||||
/oauth/register:
|
||||
post:
|
||||
operationId: postOAuthRegister
|
||||
tags: [auth]
|
||||
summary: "[cloud-only] Dynamic Client Registration (RFC 7591)"
|
||||
description: |
|
||||
[cloud-only] Public, unauthenticated, insert-only RFC 7591 §3.1 client registration. Used by MCP-spec-compliant clients to self-register a public OAuth client without operator involvement.
|
||||
|
||||
Policy:
|
||||
|
||||
- Public clients only — `token_endpoint_auth_method` is forced to `none`. Confidential-client registration is out of scope this phase.
|
||||
- Server-owned `resource_grants`. Caller-supplied `scope` or `resource_grants` is rejected as `invalid_client_metadata` (would be a privilege-escalation surface). Dynamic clients receive the same scopes the active resource publishes.
|
||||
- Application-type-aware redirect URI policy. `application_type=native` accepts loopback (`127.0.0.1`, `::1`, `localhost`) and reverse-DNS-shaped custom schemes; `application_type=web` accepts HTTPS to hosts in an operator-controlled allowlist only. `application_type` is REQUIRED on the request — missing or empty rejects with `invalid_client_metadata`.
|
||||
- Anti-impersonation: reserved client names are rejected from third parties via NFKC-folded compare.
|
||||
- Generated `client_id` carries a stable prefix to distinguish dynamic from seeded clients in audit logs.
|
||||
- Cache-Control: `no-store` on every 201 and 400 response (the response carries fresh credentials and rejection reasons).
|
||||
x-runtime: [cloud]
|
||||
security: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthRegisterRequest"
|
||||
responses:
|
||||
"201":
|
||||
description: Registered. Body echoes the metadata RFC 7591 §3.2.1 requires.
|
||||
headers:
|
||||
Cache-Control:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-store"'
|
||||
Pragma:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-cache"'
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthRegisterResponse"
|
||||
"400":
|
||||
description: RFC 7591 §3.2.2 invalid client metadata
|
||||
headers:
|
||||
Cache-Control:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-store"'
|
||||
Pragma:
|
||||
schema:
|
||||
type: string
|
||||
description: 'Always "no-cache"'
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/OAuthRegisterError"
|
||||
"404":
|
||||
description: OAuth disabled
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
"503":
|
||||
description: No active resource is configured — DCR cannot mint a usable client until an active resource row is seeded.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CloudError"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Billing (cloud)
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -7361,35 +7090,24 @@ components:
|
||||
type: string
|
||||
description: Target path on the runtime filesystem
|
||||
|
||||
ImportPublishedAssetsRequest:
|
||||
AssetImportRequest:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Request body for importing published assets into the caller's library."
|
||||
description: "[cloud-only] A single asset to import from an external URL."
|
||||
required:
|
||||
- published_asset_ids
|
||||
- url
|
||||
properties:
|
||||
published_asset_ids:
|
||||
url:
|
||||
type: string
|
||||
format: uri
|
||||
description: URL of the asset to import
|
||||
name:
|
||||
type: string
|
||||
description: Display name for the imported asset
|
||||
tags:
|
||||
type: array
|
||||
description: IDs of published assets (inputs and models) to import.
|
||||
items:
|
||||
type: string
|
||||
share_id:
|
||||
type: string
|
||||
nullable: true
|
||||
description: |
|
||||
Optional. Share ID of the published workflow these assets belong to. When provided (non-null, non-empty): all `published_asset_ids` must belong to this share's workflow version; returns 400 if the share is not found or any asset does not belong to it. When omitted, null, or empty string: no share-scoped validation is performed and the assets are validated only against global rules (preserved for clients that have not yet adopted `share_id`).
|
||||
|
||||
ImportPublishedAssetsResponse:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Response after importing published assets. Each returned `AssetInfo.id` is the caller's newly-created private asset ID, not the published asset ID supplied in the request."
|
||||
required:
|
||||
- assets
|
||||
properties:
|
||||
assets:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/AssetInfo"
|
||||
|
||||
RemoteAssetMetadata:
|
||||
type: object
|
||||
@ -7706,325 +7424,6 @@ components:
|
||||
description: RSA exponent (base64url)
|
||||
additionalProperties: true
|
||||
|
||||
OAuthAuthorizationServerMetadata:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] OAuth 2.1 authorization-server metadata (RFC 8414)."
|
||||
required:
|
||||
- issuer
|
||||
- authorization_endpoint
|
||||
- token_endpoint
|
||||
- jwks_uri
|
||||
- response_types_supported
|
||||
- grant_types_supported
|
||||
- code_challenge_methods_supported
|
||||
- token_endpoint_auth_methods_supported
|
||||
properties:
|
||||
issuer:
|
||||
type: string
|
||||
format: uri
|
||||
authorization_endpoint:
|
||||
type: string
|
||||
format: uri
|
||||
token_endpoint:
|
||||
type: string
|
||||
format: uri
|
||||
jwks_uri:
|
||||
type: string
|
||||
format: uri
|
||||
registration_endpoint:
|
||||
type: string
|
||||
format: uri
|
||||
description: "[cloud-only] RFC 7591 §3.1 Dynamic Client Registration endpoint. Advertised so MCP-spec-compliant clients can auto-discover and self-register without operator involvement. Present only when DCR is enabled."
|
||||
response_types_supported:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
grant_types_supported:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
code_challenge_methods_supported:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
token_endpoint_auth_methods_supported:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
scopes_supported:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
|
||||
OAuthProtectedResourceMetadata:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] OAuth 2.1 protected-resource metadata (RFC 9728)."
|
||||
required:
|
||||
- resource
|
||||
- authorization_servers
|
||||
- scopes_supported
|
||||
properties:
|
||||
resource:
|
||||
type: string
|
||||
format: uri
|
||||
authorization_servers:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
format: uri
|
||||
scopes_supported:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
bearer_methods_supported:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
|
||||
OAuthConsentChallenge:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Server-side state describing the OAuth consent decision the user is being asked to make. Returned by GET /oauth/authorize when a valid session exists; the frontend renders the consent UI from this payload and POSTs the decision back. Browser never sees the original OAuth params on resume."
|
||||
required:
|
||||
- oauth_request_id
|
||||
- csrf_token
|
||||
- client_display_name
|
||||
- resource_display_name
|
||||
- scopes
|
||||
- workspaces
|
||||
properties:
|
||||
oauth_request_id:
|
||||
type: string
|
||||
format: uuid
|
||||
description: Opaque server-side identifier for the authorization-request row. Carried back unchanged in the consent submission.
|
||||
csrf_token:
|
||||
type: string
|
||||
description: Per-row CSRF token bound to this authorization request (not to the session). Must be echoed back on POST.
|
||||
client_display_name:
|
||||
type: string
|
||||
description: Human-readable name of the OAuth client requesting authorization.
|
||||
resource_display_name:
|
||||
type: string
|
||||
description: Human-readable name of the protected resource.
|
||||
scopes:
|
||||
type: array
|
||||
description: Scopes the client is requesting for this resource. The frontend should present these for the user to approve.
|
||||
items:
|
||||
type: string
|
||||
workspaces:
|
||||
type: array
|
||||
description: Workspaces the user can select from. Membership is re-checked on POST.
|
||||
items:
|
||||
$ref: "#/components/schemas/OAuthConsentChallengeWorkspace"
|
||||
|
||||
OAuthConsentChallengeWorkspace:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] One workspace option presented in the OAuth consent challenge."
|
||||
required: [id, name, type, role]
|
||||
properties:
|
||||
id: { type: string }
|
||||
name: { type: string }
|
||||
type: { type: string, enum: [personal, team] }
|
||||
role: { type: string, enum: [owner, member] }
|
||||
|
||||
OAuthAuthorizeRedirectResponse:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Redirect target produced after a JSON consent submission. The frontend must navigate the browser to this URL so custom-scheme client callbacks work without relying on fetch-visible 302 headers."
|
||||
required:
|
||||
- redirect_url
|
||||
properties:
|
||||
redirect_url:
|
||||
type: string
|
||||
format: uri
|
||||
description: OAuth client redirect URI with either code+state for allow, or error+state for deny.
|
||||
|
||||
OAuthTokenResponse:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] RFC 6749 §5.1 successful token response."
|
||||
required: [access_token, token_type, expires_in, refresh_token, scope]
|
||||
properties:
|
||||
access_token:
|
||||
type: string
|
||||
description: Resource-bound access token (audience matches the protected resource).
|
||||
token_type:
|
||||
type: string
|
||||
enum: [Bearer]
|
||||
expires_in:
|
||||
type: integer
|
||||
description: Access token lifetime in seconds.
|
||||
refresh_token:
|
||||
type: string
|
||||
description: Opaque refresh token. Rotates on every successful refresh; presenting an already-rotated token revokes the entire family.
|
||||
scope:
|
||||
type: string
|
||||
description: Space-delimited scopes granted with this token.
|
||||
|
||||
OAuthTokenError:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] RFC 6749 §5.2 error response."
|
||||
required: [error]
|
||||
properties:
|
||||
error:
|
||||
type: string
|
||||
description: 'RFC 6749 §5.2 error code: invalid_request, invalid_client, invalid_grant, unauthorized_client, unsupported_grant_type, invalid_scope.'
|
||||
error_description:
|
||||
type: string
|
||||
description: Human-readable, no leak of internal storage state.
|
||||
|
||||
OAuthRegisterRequest:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
additionalProperties: false
|
||||
description: "[cloud-only] RFC 7591 §2 client metadata document. Only the fields the server honors are listed; presence of `scope` or `resource_grants` in the request is rejected (`invalid_client_metadata`) because those are server-owned for dynamic clients."
|
||||
required:
|
||||
- redirect_uris
|
||||
- application_type
|
||||
properties:
|
||||
redirect_uris:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
minItems: 1
|
||||
maxItems: 5
|
||||
description: 1–5 redirect URIs. Validated against `application_type` policy.
|
||||
client_name:
|
||||
type: string
|
||||
maxLength: 100
|
||||
description: Human-readable name shown in the consent UI. Reserved-name list rejects impersonation of major clients.
|
||||
application_type:
|
||||
type: string
|
||||
enum: [native, web]
|
||||
description: |
|
||||
RFC 7591 §2 application_type. **REQUIRED** — clients MUST declare intent; the server does not default this field. `native` for desktop / CLI / MCP-spec-strict clients (loopback redirects); `web` for hosted clients (HTTPS only, host must be allowlisted). A missing or explicitly empty `application_type` rejects with `invalid_client_metadata`.
|
||||
token_endpoint_auth_method:
|
||||
type: string
|
||||
enum: [none]
|
||||
description: 'Public clients only this phase — must be `none` if present. The server forces `none` regardless.'
|
||||
grant_types:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
enum: [authorization_code, refresh_token]
|
||||
description: Optional. Defaults to `["authorization_code","refresh_token"]`.
|
||||
response_types:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
enum: [code]
|
||||
description: Optional. Defaults to `["code"]`.
|
||||
scope:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Dynamic clients do not pick scopes — the server assigns scopes from the active resource's published list. Sending `scope` in the registration body is treated as a privilege-escalation attempt and returns `invalid_client_metadata`."
|
||||
resource_grants:
|
||||
type: object
|
||||
nullable: true
|
||||
additionalProperties:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: "**REJECTED IF PRESENT.** Same reason as `scope`. The set of resources and scopes a dynamic client may request is server-policy, not request-driven."
|
||||
client_uri:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
logo_uri:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
tos_uri:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
policy_uri:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
software_id:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
software_version:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
contacts:
|
||||
type: array
|
||||
nullable: true
|
||||
items:
|
||||
type: string
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
jwks:
|
||||
type: object
|
||||
nullable: true
|
||||
additionalProperties: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
jwks_uri:
|
||||
type: string
|
||||
nullable: true
|
||||
description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase."
|
||||
|
||||
OAuthRegisterResponse:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] RFC 7591 §3.2.1 successful registration response."
|
||||
required:
|
||||
- client_id
|
||||
- client_id_issued_at
|
||||
- redirect_uris
|
||||
- grant_types
|
||||
- response_types
|
||||
- token_endpoint_auth_method
|
||||
- application_type
|
||||
properties:
|
||||
client_id:
|
||||
type: string
|
||||
description: Server-generated client_id.
|
||||
client_id_issued_at:
|
||||
type: integer
|
||||
format: int64
|
||||
description: Unix timestamp (seconds) when the client was registered.
|
||||
client_name:
|
||||
type: string
|
||||
redirect_uris:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
grant_types:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
response_types:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
token_endpoint_auth_method:
|
||||
type: string
|
||||
enum: [none]
|
||||
application_type:
|
||||
type: string
|
||||
enum: [native, web]
|
||||
|
||||
OAuthRegisterError:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] RFC 7591 §3.2.2 error response."
|
||||
required:
|
||||
- error
|
||||
properties:
|
||||
error:
|
||||
type: string
|
||||
enum: [invalid_redirect_uri, invalid_client_metadata]
|
||||
error_description:
|
||||
type: string
|
||||
nullable: true
|
||||
|
||||
BillingBalance:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
|
||||
@ -21,6 +21,7 @@ from app.assets.database.queries import (
|
||||
get_reference_ids_by_ids,
|
||||
ensure_tags_exist,
|
||||
add_tags_to_reference,
|
||||
set_reference_tags,
|
||||
)
|
||||
from app.assets.helpers import get_utc_now
|
||||
|
||||
@ -159,6 +160,153 @@ class TestListReferencesPage:
|
||||
assert refs[0].name == "large"
|
||||
|
||||
|
||||
class TestTagRetrievalOrder:
|
||||
"""End-to-end check: tags written through the public write paths come
|
||||
back from the public read paths in insertion order rather than the
|
||||
composite-PK alphabetical order SQLite would otherwise impose.
|
||||
|
||||
Each test deliberately picks tag names that would sort differently
|
||||
under alphabetical vs insertion order, so an alphabetical regression
|
||||
fails loudly.
|
||||
"""
|
||||
|
||||
def _make_ref(self, session: Session) -> AssetReference:
|
||||
asset = _make_asset(session, "h1")
|
||||
return _make_reference(session, asset, name="x.bin")
|
||||
|
||||
def test_set_reference_tags_preserves_input_order_in_list(self, session: Session):
|
||||
ref = self._make_ref(session)
|
||||
# "checkpoints" < "models" alphabetically; if added_at stagger
|
||||
# works, list_references_page returns insertion order.
|
||||
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
|
||||
session.commit()
|
||||
|
||||
_, tag_map, _ = list_references_page(session)
|
||||
assert tag_map[ref.id] == ["models", "checkpoints"]
|
||||
|
||||
def test_set_reference_tags_preserves_input_order_in_fetch(self, session: Session):
|
||||
ref = self._make_ref(session)
|
||||
# Subpath tag sorts before "models" alphabetically.
|
||||
set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["models", "diffusers/kolors/text_encoder"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
result = fetch_reference_asset_and_tags(session, ref.id)
|
||||
assert result is not None
|
||||
_, _, tags = result
|
||||
# Bucket-prefix expansion appends the standalone `diffusers` token
|
||||
# at path-tier (microsecond stagger) so FE set-membership filters
|
||||
# match nested category paths.
|
||||
assert tags == ["models", "diffusers/kolors/text_encoder", "diffusers"]
|
||||
|
||||
def test_add_tags_to_reference_lands_after_path_tags(self, session: Session):
|
||||
ref = self._make_ref(session)
|
||||
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
|
||||
session.commit()
|
||||
|
||||
# "aaa-..." sorts before both path tags alphabetically. If added_at
|
||||
# stagger is missing, alphabetic tiebreak would hoist it to tags[0].
|
||||
add_tags_to_reference(
|
||||
session, reference_id=ref.id, tags=["aaa-user-tag"], origin="manual"
|
||||
)
|
||||
session.commit()
|
||||
|
||||
_, tag_map, _ = list_references_page(session)
|
||||
assert tag_map[ref.id] == ["models", "checkpoints", "aaa-user-tag"]
|
||||
|
||||
def test_multi_tag_batch_lands_after_path_tags(self, session: Session):
|
||||
ref = self._make_ref(session)
|
||||
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
|
||||
session.commit()
|
||||
|
||||
# Three user tags inserted in non-alphabetical input order. Per-tag
|
||||
# microsecond stagger should preserve at least the "user batch is
|
||||
# after path tags" property; within the user batch insertion order
|
||||
# is also preserved.
|
||||
add_tags_to_reference(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["zzz-z", "favorite", "experiment-q4"],
|
||||
origin="manual",
|
||||
)
|
||||
session.commit()
|
||||
|
||||
_, tag_map, _ = list_references_page(session)
|
||||
tags = tag_map[ref.id]
|
||||
assert tags[0:2] == ["models", "checkpoints"]
|
||||
assert set(tags[2:]) == {"zzz-z", "favorite", "experiment-q4"}
|
||||
|
||||
def test_user_batch_lands_after_path_batch_under_clock_collision(
|
||||
self, session: Session, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""Windows-specific race: when two back-to-back commits share the
|
||||
same datetime.now() microsecond, the path-tier and user-tier
|
||||
added_at values used to collide and alphabetic tiebreak would
|
||||
hoist user tags ahead of path tags. The fix reads
|
||||
max(existing_added_at) for the reference and seeds the next batch
|
||||
past it, deterministically restoring insertion order.
|
||||
|
||||
This test simulates the collision by pinning get_utc_now() so the
|
||||
platform-dependent race becomes a platform-independent failure.
|
||||
"""
|
||||
ref = self._make_ref(session)
|
||||
|
||||
from datetime import datetime
|
||||
from app.assets.database import queries as queries_pkg
|
||||
from app.assets.database.queries import tags as tags_module
|
||||
|
||||
frozen = datetime(2026, 1, 1, 0, 0, 0)
|
||||
monkeypatch.setattr(tags_module, "get_utc_now", lambda: frozen)
|
||||
monkeypatch.setattr(queries_pkg, "get_utc_now", lambda: frozen, raising=False)
|
||||
|
||||
set_reference_tags(session, reference_id=ref.id, tags=["models", "checkpoints"])
|
||||
session.commit()
|
||||
|
||||
# Same frozen timestamp — without the max(existing) seed, the
|
||||
# user batch would share added_at with the path batch and
|
||||
# `aaa-user-tag` would sort to position 0 via the alphabetic
|
||||
# tiebreaker.
|
||||
add_tags_to_reference(
|
||||
session, reference_id=ref.id, tags=["aaa-user-tag"], origin="manual"
|
||||
)
|
||||
session.commit()
|
||||
|
||||
_, tag_map, _ = list_references_page(session)
|
||||
assert tag_map[ref.id] == ["models", "checkpoints", "aaa-user-tag"]
|
||||
|
||||
def test_remove_then_add_does_not_disrupt_path_tag_positions(
|
||||
self, session: Session
|
||||
):
|
||||
ref = self._make_ref(session)
|
||||
set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["models", "loras/my/custom/path"],
|
||||
)
|
||||
session.commit()
|
||||
add_tags_to_reference(session, reference_id=ref.id, tags=["temp-tag"])
|
||||
session.commit()
|
||||
from app.assets.database.queries import remove_tags_from_reference
|
||||
|
||||
remove_tags_from_reference(session, reference_id=ref.id, tags=["temp-tag"])
|
||||
session.commit()
|
||||
add_tags_to_reference(session, reference_id=ref.id, tags=["second-tag"])
|
||||
session.commit()
|
||||
|
||||
_, tag_map, _ = list_references_page(session)
|
||||
# `loras` is expanded from the nested category path; user-added
|
||||
# tags trail behind it via the microsecond stagger.
|
||||
assert tag_map[ref.id] == [
|
||||
"models",
|
||||
"loras/my/custom/path",
|
||||
"loras",
|
||||
"second-tag",
|
||||
]
|
||||
|
||||
|
||||
class TestFetchReferenceAssetAndTags:
|
||||
def test_returns_none_for_nonexistent(self, session: Session):
|
||||
result = fetch_reference_asset_and_tags(session, "nonexistent")
|
||||
|
||||
@ -160,6 +160,120 @@ class TestAddTagsToReference:
|
||||
add_tags_to_reference(session, reference_id="nonexistent", tags=["x"])
|
||||
|
||||
|
||||
class TestBucketPrefixExpansion:
|
||||
"""The standalone bucket token must appear in the asset's tag set for
|
||||
nested category paths so FE filters like
|
||||
`include_tags=models,checkpoints` continue to match.
|
||||
"""
|
||||
|
||||
def test_set_reference_tags_inserts_bucket_for_nested_path(
|
||||
self, session: Session
|
||||
):
|
||||
asset = _make_asset(session, "hash-nested")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
result = set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["models", "checkpoints/flux"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert set(result.total) == {"models", "checkpoints/flux", "checkpoints"}
|
||||
stored = get_reference_tags(session, reference_id=ref.id)
|
||||
# tag[1] keeps the slash-joined positional contract; the standalone
|
||||
# bucket lands after it via path-tier microsecond stagger so user
|
||||
# tags remain at the tail.
|
||||
assert stored[:3] == ["models", "checkpoints/flux", "checkpoints"]
|
||||
|
||||
def test_set_reference_tags_idempotent_on_replay(self, session: Session):
|
||||
asset = _make_asset(session, "hash-replay")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["models", "checkpoints/flux"],
|
||||
)
|
||||
# Replay with the same caller-supplied set; expansion is already
|
||||
# baked in, so nothing should be added or removed.
|
||||
result = set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["models", "checkpoints/flux"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert result.added == []
|
||||
assert result.removed == []
|
||||
assert set(result.total) == {"models", "checkpoints/flux", "checkpoints"}
|
||||
|
||||
def test_add_tags_to_reference_expands_bucket(self, session: Session):
|
||||
asset = _make_asset(session, "hash-add")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
result = add_tags_to_reference(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["loras/style/v2"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert set(result.added) == {"loras/style/v2", "loras"}
|
||||
stored = get_reference_tags(session, reference_id=ref.id)
|
||||
assert "loras" in stored
|
||||
assert "loras/style/v2" in stored
|
||||
|
||||
def test_add_tags_does_not_duplicate_existing_bucket(self, session: Session):
|
||||
asset = _make_asset(session, "hash-dedupe")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
add_tags_to_reference(
|
||||
session, reference_id=ref.id, tags=["models", "checkpoints"]
|
||||
)
|
||||
result = add_tags_to_reference(
|
||||
session, reference_id=ref.id, tags=["checkpoints/flux"]
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# `checkpoints` was already there from the first add; only the
|
||||
# slash-joined token is genuinely new.
|
||||
assert result.added == ["checkpoints/flux"]
|
||||
assert "checkpoints" in result.already_present
|
||||
|
||||
def test_flat_category_is_unaffected(self, session: Session):
|
||||
asset = _make_asset(session, "hash-flat")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
result = set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["models", "checkpoints"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert set(result.total) == {"models", "checkpoints"}
|
||||
assert get_reference_tags(session, reference_id=ref.id) == [
|
||||
"models",
|
||||
"checkpoints",
|
||||
]
|
||||
|
||||
def test_unknown_prefix_passes_through(self, session: Session):
|
||||
asset = _make_asset(session, "hash-user")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
# `my-org` isn't a registered bucket — the slash-joined user tag
|
||||
# should not trigger bucket expansion.
|
||||
result = set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=["my-org/team-a"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert result.total == ["my-org/team-a"]
|
||||
|
||||
|
||||
class TestRemoveTagsFromReference:
|
||||
def test_removes_tags(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
|
||||
@ -4,7 +4,7 @@ from pathlib import Path
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference
|
||||
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag
|
||||
from app.assets.services.bulk_ingest import SeedAssetSpec, batch_insert_seed_assets
|
||||
|
||||
|
||||
@ -102,6 +102,82 @@ class TestBatchInsertSeedAssets:
|
||||
assert asset.mime_type == expected_mime, f"Expected {expected_mime} for {filename}, got {asset.mime_type}"
|
||||
|
||||
|
||||
class TestBucketPrefixExpansionOnIngest:
|
||||
"""Path-scanning ingest must persist the standalone bucket token for
|
||||
nested category paths so the FE set-membership filter
|
||||
(`include_tags=models,checkpoints`) matches assets organized into
|
||||
subfolders (`models/checkpoints/flux/foo.safetensors`).
|
||||
"""
|
||||
|
||||
def test_nested_path_inserts_standalone_bucket(
|
||||
self, session: Session, temp_dir: Path
|
||||
):
|
||||
file_path = temp_dir / "flux.safetensors"
|
||||
file_path.write_bytes(b"content")
|
||||
|
||||
specs: list[SeedAssetSpec] = [
|
||||
{
|
||||
"abs_path": str(file_path),
|
||||
"size_bytes": 7,
|
||||
"mtime_ns": 1234567890000000000,
|
||||
"info_name": "flux",
|
||||
# Shape emitted by get_name_and_tags_from_asset_path for a
|
||||
# nested model path.
|
||||
"tags": ["models", "checkpoints/flux"],
|
||||
"fname": "flux.safetensors",
|
||||
"metadata": None,
|
||||
"hash": None,
|
||||
"mime_type": "application/safetensors",
|
||||
}
|
||||
]
|
||||
|
||||
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
|
||||
|
||||
assert result.inserted_refs == 1
|
||||
ref = session.query(AssetReference).filter_by(name="flux").one()
|
||||
stored = [
|
||||
row.tag_name
|
||||
for row in session.query(AssetReferenceTag)
|
||||
.filter_by(asset_reference_id=ref.id)
|
||||
.order_by(AssetReferenceTag.added_at.asc())
|
||||
.all()
|
||||
]
|
||||
assert stored == ["models", "checkpoints/flux", "checkpoints"]
|
||||
|
||||
def test_flat_path_remains_two_tags(
|
||||
self, session: Session, temp_dir: Path
|
||||
):
|
||||
file_path = temp_dir / "vanilla.safetensors"
|
||||
file_path.write_bytes(b"content")
|
||||
|
||||
specs: list[SeedAssetSpec] = [
|
||||
{
|
||||
"abs_path": str(file_path),
|
||||
"size_bytes": 7,
|
||||
"mtime_ns": 1234567890000000000,
|
||||
"info_name": "vanilla",
|
||||
"tags": ["models", "checkpoints"],
|
||||
"fname": "vanilla.safetensors",
|
||||
"metadata": None,
|
||||
"hash": None,
|
||||
"mime_type": "application/safetensors",
|
||||
}
|
||||
]
|
||||
|
||||
batch_insert_seed_assets(session, specs=specs, owner_id="")
|
||||
|
||||
ref = session.query(AssetReference).filter_by(name="vanilla").one()
|
||||
stored = {
|
||||
row.tag_name
|
||||
for row in session.query(AssetReferenceTag)
|
||||
.filter_by(asset_reference_id=ref.id)
|
||||
.all()
|
||||
}
|
||||
# Dedupe means flat layouts don't pick up a redundant `checkpoints`
|
||||
# row — tag[1] already serves both positional and set-membership.
|
||||
assert stored == {"models", "checkpoints"}
|
||||
|
||||
|
||||
class TestMetadataExtraction:
|
||||
def test_extracts_mime_type_for_model_files(self, temp_dir: Path):
|
||||
"""Verify metadata extraction returns correct mime_type for model files."""
|
||||
|
||||
@ -6,7 +6,11 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.assets.services.path_utils import get_asset_category_and_relative_path
|
||||
from app.assets.services.path_utils import (
|
||||
get_asset_category_and_relative_path,
|
||||
get_name_and_tags_from_asset_path,
|
||||
resolve_destination_from_tags,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -38,6 +42,50 @@ def fake_dirs():
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_dirs_multi_bucket():
|
||||
"""Variant fixture with multiple model buckets (checkpoints + diffusers + loras)."""
|
||||
with tempfile.TemporaryDirectory() as root:
|
||||
root_path = Path(root)
|
||||
input_dir = root_path / "input"
|
||||
output_dir = root_path / "output"
|
||||
temp_dir = root_path / "temp"
|
||||
checkpoints_dir = root_path / "models" / "checkpoints"
|
||||
diffusers_dir = root_path / "models" / "diffusers"
|
||||
loras_dir = root_path / "models" / "loras"
|
||||
for d in (
|
||||
input_dir,
|
||||
output_dir,
|
||||
temp_dir,
|
||||
checkpoints_dir,
|
||||
diffusers_dir,
|
||||
loras_dir,
|
||||
):
|
||||
d.mkdir(parents=True)
|
||||
|
||||
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
|
||||
mock_fp.get_input_directory.return_value = str(input_dir)
|
||||
mock_fp.get_output_directory.return_value = str(output_dir)
|
||||
mock_fp.get_temp_directory.return_value = str(temp_dir)
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[
|
||||
("checkpoints", [str(checkpoints_dir)]),
|
||||
("diffusers", [str(diffusers_dir)]),
|
||||
("loras", [str(loras_dir)]),
|
||||
],
|
||||
):
|
||||
yield {
|
||||
"input": input_dir,
|
||||
"output": output_dir,
|
||||
"temp": temp_dir,
|
||||
"checkpoints": checkpoints_dir,
|
||||
"diffusers": diffusers_dir,
|
||||
"loras": loras_dir,
|
||||
}
|
||||
|
||||
|
||||
class TestGetAssetCategoryAndRelativePath:
|
||||
def test_input_file(self, fake_dirs):
|
||||
f = fake_dirs["input"] / "photo.png"
|
||||
@ -79,3 +127,161 @@ class TestGetAssetCategoryAndRelativePath:
|
||||
def test_unknown_path_raises(self, fake_dirs):
|
||||
with pytest.raises(ValueError, match="not within"):
|
||||
get_asset_category_and_relative_path("/some/random/path.png")
|
||||
|
||||
|
||||
class TestGetNameAndTagsFromAssetPath:
|
||||
"""tags collapse the parent subpath into a single slash-joined tag.
|
||||
|
||||
Consumers should be able to read ``tags[1]`` as a stable category
|
||||
identifier regardless of how deep the file lives in the bucket.
|
||||
"""
|
||||
|
||||
def test_flat_input(self, fake_dirs_multi_bucket):
|
||||
f = fake_dirs_multi_bucket["input"] / "photo.png"
|
||||
f.touch()
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "photo.png"
|
||||
assert tags == ["input"]
|
||||
|
||||
def test_flat_output(self, fake_dirs_multi_bucket):
|
||||
f = fake_dirs_multi_bucket["output"] / "result_00001.png"
|
||||
f.touch()
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "result_00001.png"
|
||||
assert tags == ["output"]
|
||||
|
||||
def test_flat_models_checkpoint(self, fake_dirs_multi_bucket):
|
||||
f = fake_dirs_multi_bucket["checkpoints"] / "flux.safetensors"
|
||||
f.touch()
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "flux.safetensors"
|
||||
assert tags == ["models", "checkpoints"]
|
||||
|
||||
def test_diffusers_nested_subpath_slash_joined(self, fake_dirs_multi_bucket):
|
||||
"""Diffusers components live in nested directories — the full subpath
|
||||
must collapse into one tag so consumers can look up the model category
|
||||
via tags[1] regardless of nesting depth.
|
||||
|
||||
The subpath is lowercased to match the canonicalization
|
||||
:func:`ensure_tags_exist` applies on the write side; without that,
|
||||
the asset_reference_tags.tag_name FK to tags.name would fail for
|
||||
any path containing uppercase letters.
|
||||
"""
|
||||
nested = (
|
||||
fake_dirs_multi_bucket["diffusers"]
|
||||
/ "Kolors"
|
||||
/ "text_encoder"
|
||||
)
|
||||
nested.mkdir(parents=True)
|
||||
f = nested / "model.safetensors"
|
||||
f.touch()
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "model.safetensors"
|
||||
assert tags == ["models", "diffusers/kolors/text_encoder"]
|
||||
|
||||
def test_deep_lora_user_subpath_slash_joined(self, fake_dirs_multi_bucket):
|
||||
"""User-created subdirectories under a model bucket also collapse to a
|
||||
single tag rather than one tag per directory."""
|
||||
nested = (
|
||||
fake_dirs_multi_bucket["loras"]
|
||||
/ "my"
|
||||
/ "custom"
|
||||
/ "path"
|
||||
)
|
||||
nested.mkdir(parents=True)
|
||||
f = nested / "v0001.safetensors"
|
||||
f.touch()
|
||||
name, tags = get_name_and_tags_from_asset_path(str(f))
|
||||
assert name == "v0001.safetensors"
|
||||
assert tags == ["models", "loras/my/custom/path"]
|
||||
|
||||
|
||||
class TestResolveDestinationFromTags:
|
||||
"""resolve_destination_from_tags must accept both the legacy
|
||||
one-tag-per-directory shape and the new slash-joined shape so that an
|
||||
upload using the tags it just read back from /api/assets round-trips
|
||||
to the right on-disk destination.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def resolve_dirs(self):
|
||||
with tempfile.TemporaryDirectory() as root:
|
||||
root_path = Path(root)
|
||||
input_dir = root_path / "input"
|
||||
output_dir = root_path / "output"
|
||||
checkpoints_dir = root_path / "models" / "checkpoints"
|
||||
diffusers_dir = root_path / "models" / "diffusers"
|
||||
loras_dir = root_path / "models" / "loras"
|
||||
for d in (input_dir, output_dir, checkpoints_dir, diffusers_dir, loras_dir):
|
||||
d.mkdir(parents=True)
|
||||
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
|
||||
mock_fp.get_input_directory.return_value = str(input_dir)
|
||||
mock_fp.get_output_directory.return_value = str(output_dir)
|
||||
mock_fp.folder_names_and_paths = {
|
||||
"checkpoints": ([str(checkpoints_dir)], None),
|
||||
"diffusers": ([str(diffusers_dir)], None),
|
||||
"loras": ([str(loras_dir)], None),
|
||||
}
|
||||
yield {
|
||||
"input": input_dir,
|
||||
"output": output_dir,
|
||||
"checkpoints": checkpoints_dir,
|
||||
"diffusers": diffusers_dir,
|
||||
"loras": loras_dir,
|
||||
}
|
||||
|
||||
def test_models_flat_category(self, resolve_dirs):
|
||||
base, subdirs = resolve_destination_from_tags(["models", "checkpoints"])
|
||||
assert base == str(resolve_dirs["checkpoints"])
|
||||
assert subdirs == []
|
||||
|
||||
def test_models_slash_joined_new_shape(self, resolve_dirs):
|
||||
# The shape get_name_and_tags_from_asset_path now emits.
|
||||
base, subdirs = resolve_destination_from_tags(
|
||||
["models", "diffusers/kolors/text_encoder"]
|
||||
)
|
||||
assert base == str(resolve_dirs["diffusers"])
|
||||
assert subdirs == ["kolors", "text_encoder"]
|
||||
|
||||
def test_models_legacy_one_tag_per_dir(self, resolve_dirs):
|
||||
# The legacy shape must still resolve identically.
|
||||
base, subdirs = resolve_destination_from_tags(
|
||||
["models", "diffusers", "kolors", "text_encoder"]
|
||||
)
|
||||
assert base == str(resolve_dirs["diffusers"])
|
||||
assert subdirs == ["kolors", "text_encoder"]
|
||||
|
||||
def test_models_loras_slash_joined(self, resolve_dirs):
|
||||
base, subdirs = resolve_destination_from_tags(
|
||||
["models", "loras/my/custom/path"]
|
||||
)
|
||||
assert base == str(resolve_dirs["loras"])
|
||||
assert subdirs == ["my", "custom", "path"]
|
||||
|
||||
def test_input_no_subdir(self, resolve_dirs):
|
||||
base, subdirs = resolve_destination_from_tags(["input"])
|
||||
assert base == str(resolve_dirs["input"])
|
||||
assert subdirs == []
|
||||
|
||||
def test_input_slash_joined_subdir(self, resolve_dirs):
|
||||
base, subdirs = resolve_destination_from_tags(["input", "portraits/2026"])
|
||||
assert base == str(resolve_dirs["input"])
|
||||
assert subdirs == ["portraits", "2026"]
|
||||
|
||||
def test_output_slash_joined_subdir(self, resolve_dirs):
|
||||
base, subdirs = resolve_destination_from_tags(["output", "runs/abc"])
|
||||
assert base == str(resolve_dirs["output"])
|
||||
assert subdirs == ["runs", "abc"]
|
||||
|
||||
def test_unknown_category_rejected(self, resolve_dirs):
|
||||
with pytest.raises(ValueError, match="unknown model category"):
|
||||
resolve_destination_from_tags(["models", "not_a_real_category"])
|
||||
|
||||
def test_unknown_category_via_slash_joined(self, resolve_dirs):
|
||||
# First segment of a slash-joined tag must still match a registered category.
|
||||
with pytest.raises(ValueError, match="unknown model category 'bogus'"):
|
||||
resolve_destination_from_tags(["models", "bogus/sub/path"])
|
||||
|
||||
def test_traversal_in_subdir_rejected(self, resolve_dirs):
|
||||
with pytest.raises(ValueError, match="invalid path component"):
|
||||
resolve_destination_from_tags(["models", "checkpoints/..", "evil"])
|
||||
|
||||
@ -32,7 +32,7 @@ def test_seed_asset_removed_when_file_is_deleted(
|
||||
# Verify it is visible via API and carries no hash (seed)
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
|
||||
params={"include_tags": "unit-tests/syncseed", "name_contains": name},
|
||||
timeout=120,
|
||||
)
|
||||
body1 = r1.json()
|
||||
@ -52,7 +52,7 @@ def test_seed_asset_removed_when_file_is_deleted(
|
||||
# It should disappear (AssetInfo and seed Asset gone)
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
|
||||
params={"include_tags": "unit-tests/syncseed", "name_contains": name},
|
||||
timeout=120,
|
||||
)
|
||||
body2 = r2.json()
|
||||
@ -332,7 +332,7 @@ def test_fastpass_removes_stale_state_row_no_missing(
|
||||
|
||||
rl = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": f"unit-tests,{scope}"},
|
||||
params={"include_tags": f"unit-tests/{scope}"},
|
||||
timeout=120,
|
||||
)
|
||||
bl = rl.json()
|
||||
|
||||
@ -280,9 +280,15 @@ def test_metadata_filename_is_set_for_seed_asset_without_hash(
|
||||
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# Scanner emits tags as ``[root, "<dir1>/<dir2>/..."]`` — the second tag
|
||||
# is the slash-joined parent subpath. For ``<root>/unit-tests/<scope>/a/b/<name>``
|
||||
# the second tag is ``"unit-tests/<scope>/a/b"``.
|
||||
r1 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": f"unit-tests,{scope}", "name_contains": name},
|
||||
params={
|
||||
"include_tags": f"unit-tests/{scope}/a/b",
|
||||
"name_contains": name,
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
body = r1.json()
|
||||
|
||||
69
tests-unit/assets_test/test_helpers.py
Normal file
69
tests-unit/assets_test/test_helpers.py
Normal file
@ -0,0 +1,69 @@
|
||||
"""Unit tests for app.assets.helpers."""
|
||||
|
||||
from app.assets.helpers import expand_bucket_prefixes
|
||||
|
||||
|
||||
class TestExpandBucketPrefixes:
|
||||
def test_flat_category_unchanged(self):
|
||||
# `checkpoints` is already a standalone token, no expansion needed.
|
||||
assert expand_bucket_prefixes(["models", "checkpoints"]) == [
|
||||
"models",
|
||||
"checkpoints",
|
||||
]
|
||||
|
||||
def test_nested_category_inserts_bucket(self):
|
||||
# Path-derived shape for `models/checkpoints/flux/foo.safetensors` —
|
||||
# the standalone bucket has to be present so the FE set-membership
|
||||
# filter (`include_tags=models,checkpoints`) matches the asset.
|
||||
assert expand_bucket_prefixes(["models", "checkpoints/flux"]) == [
|
||||
"models",
|
||||
"checkpoints/flux",
|
||||
"checkpoints",
|
||||
]
|
||||
|
||||
def test_deeply_nested_only_first_segment_expands(self):
|
||||
# Only the FIRST slash segment ever gets emitted as a standalone —
|
||||
# intermediate path segments don't have routing significance.
|
||||
assert expand_bucket_prefixes(
|
||||
["models", "diffusers/kolors/text_encoder"]
|
||||
) == ["models", "diffusers/kolors/text_encoder", "diffusers"]
|
||||
|
||||
def test_unknown_prefix_does_not_expand(self):
|
||||
# Free-form user labels with slashes whose first segment is not a
|
||||
# registered bucket pass through opaquely.
|
||||
assert expand_bucket_prefixes(["models", "my-org/team-a"]) == [
|
||||
"models",
|
||||
"my-org/team-a",
|
||||
]
|
||||
|
||||
def test_idempotent(self):
|
||||
# Re-applying the helper is a no-op once the bucket is in the set.
|
||||
expanded = expand_bucket_prefixes(["models", "checkpoints/flux"])
|
||||
assert expand_bucket_prefixes(expanded) == expanded
|
||||
|
||||
def test_does_not_duplicate_existing_bucket(self):
|
||||
# If the caller already supplied the standalone bucket, don't add a
|
||||
# second copy.
|
||||
assert expand_bucket_prefixes(
|
||||
["models", "checkpoints/flux", "checkpoints"]
|
||||
) == ["models", "checkpoints/flux", "checkpoints"]
|
||||
|
||||
def test_preserves_caller_order(self):
|
||||
# User tags after path tags must stay after; the inserted bucket
|
||||
# token slots in immediately after its slash-joined parent so the
|
||||
# microsecond stagger lands it at path-tier before user-tier.
|
||||
assert expand_bucket_prefixes(
|
||||
["models", "loras/style", "favorite", "v2"]
|
||||
) == ["models", "loras/style", "loras", "favorite", "v2"]
|
||||
|
||||
def test_empty_input(self):
|
||||
assert expand_bucket_prefixes([]) == []
|
||||
|
||||
def test_input_root_with_subpath_no_expansion(self):
|
||||
# `portraits` isn't a registered model category, so the input
|
||||
# subpath stays opaque (FE filter doesn't have a checkpoint-loader
|
||||
# analogue for input subfolders).
|
||||
assert expand_bucket_prefixes(["input", "portraits/2026"]) == [
|
||||
"input",
|
||||
"portraits/2026",
|
||||
]
|
||||
@ -29,7 +29,10 @@ def create_seed_file(comfy_tmp_base_dir: Path):
|
||||
def find_asset(http: requests.Session, api_base: str):
|
||||
"""Query API for assets matching scope and optional name."""
|
||||
def _find(scope: str, name: str | None = None) -> list[dict]:
|
||||
params = {"include_tags": f"unit-tests,{scope}"}
|
||||
# Scanner now emits tags as ``[root, "<dir1>/<dir2>/..."]`` rather than
|
||||
# one tag per directory. For files at ``<root>/unit-tests/<scope>/...``
|
||||
# the second tag is exactly ``"unit-tests/<scope>"``.
|
||||
params = {"include_tags": f"unit-tests/{scope}"}
|
||||
if name:
|
||||
params["name_contains"] = name
|
||||
r = http.get(f"{api_base}/api/assets", params=params, timeout=120)
|
||||
@ -138,4 +141,7 @@ def test_special_chars_in_path_escaped_correctly(
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
assert find_asset(scope.split("/")[0], fp.name), "Asset with special chars should survive"
|
||||
# Scanner emits the full parent subpath as a single slash-joined tag, so
|
||||
# the lookup tag is ``unit-tests/<scope>`` even when <scope> itself
|
||||
# contains a slash (parent + special-char dirname).
|
||||
assert find_asset(scope, fp.name), "Asset with special chars should survive"
|
||||
|
||||
135
tests-unit/assets_test/test_user_tag_http_smoke.py
Normal file
135
tests-unit/assets_test/test_user_tag_http_smoke.py
Normal file
@ -0,0 +1,135 @@
|
||||
"""HTTP-layer smoke test: user-added tags via POST /api/assets/{id}/tags
|
||||
land after path tags when read back via GET /api/assets.
|
||||
|
||||
Exercises the full route handler -> service -> query path that the unit
|
||||
tests at tests-unit/assets_test/queries/test_asset_info.py only cover at
|
||||
the service layer.
|
||||
"""
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def smoke_asset(http: requests.Session, api_base: str):
|
||||
"""Upload a single asset into models/checkpoints/unit-tests/smoke
|
||||
and delete it on teardown."""
|
||||
name = "smoke_user_tag.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "smoke"]
|
||||
files = {"file": (name, b"S" * 4096, "application/octet-stream")}
|
||||
form_data = {
|
||||
"tags": json.dumps(tags),
|
||||
"name": name,
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
r = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
|
||||
assert r.status_code == 201, r.text
|
||||
body = r.json()
|
||||
yield body
|
||||
http.delete(
|
||||
f"{api_base}/api/assets/{body['id']}?delete_content=true", timeout=30
|
||||
)
|
||||
|
||||
|
||||
def _fetch_asset_tags(http, api_base, ref_id):
|
||||
r = http.get(f"{api_base}/api/assets/{ref_id}", timeout=30)
|
||||
assert r.status_code == 200, r.text
|
||||
return r.json()["tags"]
|
||||
|
||||
|
||||
def test_user_tag_lands_after_path_tags_via_http(
|
||||
http: requests.Session, api_base: str, smoke_asset: dict
|
||||
):
|
||||
ref_id = smoke_asset["id"]
|
||||
|
||||
initial_tags = _fetch_asset_tags(http, api_base, ref_id)
|
||||
# Path tags should already be at the front in upload order.
|
||||
assert initial_tags[:2] == ["models", "checkpoints"]
|
||||
|
||||
# Add a user tag that would jump to position 0 under alphabetical sort.
|
||||
r = http.post(
|
||||
f"{api_base}/api/assets/{ref_id}/tags",
|
||||
json={"tags": ["aaa-user-tag"]},
|
||||
timeout=30,
|
||||
)
|
||||
assert r.status_code in (200, 201), r.text
|
||||
|
||||
tags_after = _fetch_asset_tags(http, api_base, ref_id)
|
||||
# Path tags must still be at the front; user tag goes to the end.
|
||||
assert tags_after[0] == "models"
|
||||
assert tags_after[1] == "checkpoints"
|
||||
assert "aaa-user-tag" in tags_after
|
||||
assert tags_after[-1] == "aaa-user-tag"
|
||||
|
||||
|
||||
def test_user_tag_batch_lands_after_path_tags_via_http(
|
||||
http: requests.Session, api_base: str, smoke_asset: dict
|
||||
):
|
||||
ref_id = smoke_asset["id"]
|
||||
|
||||
# Add three user tags in a single request, in non-alphabetical input
|
||||
# order. They should all land after the path tags (microsecond stagger
|
||||
# in set_reference_tags / add_tags_to_reference is what makes this
|
||||
# work — without it, "aaa" would jump to position 0).
|
||||
r = http.post(
|
||||
f"{api_base}/api/assets/{ref_id}/tags",
|
||||
json={"tags": ["zzz-z", "favorite", "aaa-experiment"]},
|
||||
timeout=30,
|
||||
)
|
||||
assert r.status_code in (200, 201), r.text
|
||||
|
||||
tags_after = _fetch_asset_tags(http, api_base, ref_id)
|
||||
assert tags_after[0] == "models"
|
||||
assert tags_after[1] == "checkpoints"
|
||||
user_tail = tags_after[len({"models", "checkpoints", "unit-tests", "smoke"}):]
|
||||
assert set(user_tail) >= {"zzz-z", "favorite", "aaa-experiment"}
|
||||
# Critically: alphabetical sort would put 'aaa-experiment' at position 0.
|
||||
assert tags_after.index("aaa-experiment") > tags_after.index("models")
|
||||
assert tags_after.index("aaa-experiment") > tags_after.index("checkpoints")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def nested_checkpoint_asset(http: requests.Session, api_base: str):
|
||||
"""Upload a checkpoint at the slash-joined path shape cloud emits
|
||||
(`models/checkpoints/flux/...`), then delete it on teardown.
|
||||
"""
|
||||
name = "nested_checkpoint.safetensors"
|
||||
tags = ["models", "checkpoints/flux"]
|
||||
files = {"file": (name, b"S" * 4096, "application/octet-stream")}
|
||||
form_data = {
|
||||
"tags": json.dumps(tags),
|
||||
"name": name,
|
||||
"user_metadata": json.dumps({}),
|
||||
}
|
||||
r = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
|
||||
assert r.status_code == 201, r.text
|
||||
body = r.json()
|
||||
yield body
|
||||
http.delete(
|
||||
f"{api_base}/api/assets/{body['id']}?delete_content=true", timeout=30
|
||||
)
|
||||
|
||||
|
||||
def test_nested_checkpoint_satisfies_fe_set_filter(
|
||||
http: requests.Session, api_base: str, nested_checkpoint_asset: dict
|
||||
):
|
||||
"""The case Simon flagged: a nested-path checkpoint must still match
|
||||
`include_tags=models,checkpoints` — the FE combo-widget filter.
|
||||
"""
|
||||
ref_id = nested_checkpoint_asset["id"]
|
||||
|
||||
stored = _fetch_asset_tags(http, api_base, ref_id)
|
||||
# tag[1] keeps cloud's slash-joined positional contract; tag[2] holds
|
||||
# the standalone bucket the FE filter looks for.
|
||||
assert stored[:3] == ["models", "checkpoints/flux", "checkpoints"]
|
||||
|
||||
# The actual FE query — exact set-membership across both tokens.
|
||||
r = http.get(
|
||||
f"{api_base}/api/assets",
|
||||
params=[("include_tags", "models"), ("include_tags", "checkpoints")],
|
||||
timeout=30,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
returned_ids = {a["id"] for a in r.json()["assets"]}
|
||||
assert ref_id in returned_ids
|
||||
Reference in New Issue
Block a user