Compare commits


1 Commit

Author SHA1 Message Date
e4f3d335dc feat: Add VideoSlice node with lazy operations on VideoInput
- Add VideoOp base class and SliceOp in _input/video_types.py
- Add sliced() method to VideoInput that returns a copy with operation appended
- Each subclass applies operations in get_components() and get_frame_count()
- After materialization, VideoFromFile delegates to internal VideoFromComponents
- Add VideoSlice node that uses video.sliced(start_frame, frame_count)
- Add tests for SliceOp, sliced() behavior, and materialization
2026-01-23 20:52:15 -08:00
58 changed files with 471 additions and 2661 deletions

.gitignore vendored (1 change)

@@ -21,7 +21,6 @@ venv/
*.log
web_custom_versions/
.DS_Store
*:Zone.Identifier
openapi.yaml
filtered-openapi.yaml
uv.lock

PLAN.md (139 deletions)

@@ -1,139 +0,0 @@
# Plan: Align Local Asset/Tag Endpoints with Cloud
## Endpoint Comparison
| Endpoint | Cloud (openapi.yaml) | Local (routes.py) |
|----------|---------------------|-------------------|
| `GET /api/assets` | ✅ + `include_public` param | ✅ |
| `POST /api/assets` | ✅ multipart + JSON URL upload | ✅ multipart only |
| `GET /api/assets/{id}` | ✅ | ✅ |
| `PUT /api/assets/{id}` | ✅ (`name`, `mime_type`, `preview_id`, `user_metadata`) | ✅ (`name`, `tags`, `user_metadata`) |
| `DELETE /api/assets/{id}` | ✅ | ✅ |
| `GET /api/assets/{id}/content` | ❌ | ✅ |
| `POST /api/assets/{id}/tags` | ✅ | ✅ |
| `DELETE /api/assets/{id}/tags` | ✅ | ✅ |
| `PUT /api/assets/{id}/preview` | ❌ | ✅ |
| `POST /api/assets/from-hash` | ✅ | ✅ |
| `HEAD /api/assets/hash/{hash}` | ✅ | ✅ |
| `GET /api/assets/remote-metadata` | ✅ | ❌ |
| `POST /api/assets/download` | ✅ (background download) | ❌ |
| `GET /api/assets/tags/refine` | ✅ (tag histogram) | ❌ |
| `GET /api/tags` | ✅ + `include_public` param | ✅ |
| `POST /api/assets/scan/seed` | ❌ | ✅ (local only) |
---
## Phase 1: Add Missing Cloud Endpoints to Local
### 1.1 `GET /api/assets/remote-metadata` *(deferred)*
Fetch metadata from remote URLs (CivitAI, HuggingFace) without downloading the file.
**Status:** Not supported yet. Add stub/placeholder that returns 501 Not Implemented.
**Parameters:**
- `url` (required): Download URL to retrieve metadata from
**Returns:** Asset metadata (name, size, hash if available, etc.)
### 1.2 `POST /api/assets/download` *(deferred)*
Initiate background download job for large files from HuggingFace or CivitAI.
**Status:** Not supported yet. Add stub/placeholder that returns 501 Not Implemented.
**Request body:**
- `source_url` (required): URL to download from
- `tags`: Optional tags for the asset
- `user_metadata`: Optional metadata
- `preview_id`: Optional preview asset ID
**Returns:**
- 200 if file already exists (returns asset immediately)
- 202 with `task_id` for background download tracking via `GET /api/tasks/{task_id}`
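For illustration, a hypothetical request body matching the fields above (URL and values are placeholders; the exact wire format follows the cloud spec):
```json
{
  "source_url": "https://huggingface.co/org/repo/resolve/main/model.safetensors",
  "tags": ["models", "checkpoints"],
  "user_metadata": {},
  "preview_id": "uuid"
}
```
A 202 response would then carry the tracking id, e.g. `{"task_id": "uuid"}`.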
### 1.3 `GET /api/assets/tags/refine`
Get tag histogram for filtered assets (useful for search refinement UI).
**Parameters:**
- `include_tags`: Filter assets with ALL these tags
- `exclude_tags`: Exclude assets with ANY of these tags
- `name_contains`: Filter by name substring
- `metadata_filter`: JSON filter for metadata fields
- `limit`: Max tags to return (default 100)
- `include_public`: Include public/shared assets
**Returns:** List of tags with counts for matching assets
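As a sketch, the response could look like the following (shape assumed from the description above; the local `TagHistogramResponse` schema uses the same `name`/`count` entries):
```json
{
  "tags": [
    {"name": "checkpoints", "count": 42},
    {"name": "loras", "count": 17}
  ]
}
```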
---
## Phase 2: Update Existing Endpoints for Parity
### 2.1 `GET /api/assets`
- Add `include_public` query parameter (boolean, default true)
### 2.2 `POST /api/assets`
- Add JSON body upload path for URL-based uploads:
```json
{
"url": "https://...",
"name": "model.safetensors",
"tags": ["models", "checkpoints"],
"user_metadata": {},
"preview_id": "uuid"
}
```
- Keep existing multipart upload support
### 2.3 `PUT /api/assets/{id}`
- Add `mime_type` field support
- Add `preview_id` field support
- Remove direct `tags` field (recommend using dedicated `POST/DELETE /api/assets/{id}/tags` endpoints instead)
### 2.4 `GET /api/tags`
- Add `include_public` query parameter (boolean, default true)
---
## Phase 3: Local-Only Endpoints
These endpoints exist locally but not in cloud.
### 3.1 `GET /api/assets/{id}/content`
Download asset file content. Cloud uses signed URLs instead. **Keep for local.**
### 3.2 `PUT /api/assets/{id}/preview`
**Remove this endpoint.** Merge functionality into `PUT /api/assets/{id}` by adding `preview_id` field support (aligns with cloud).
### 3.3 `POST /api/assets/scan/seed`
Filesystem seeding/scanning for local asset discovery. Not applicable to cloud. **Keep as local-only.**
---
## Phase 4: Testing
Add tests for all new and modified endpoints to ensure functionality matches cloud behavior.
### 4.1 New Endpoint Tests
- `GET /api/assets/remote-metadata`: Test with valid/invalid URLs, various sources (CivitAI, HuggingFace)
- `POST /api/assets/download`: Test background download initiation, existing-file detection, task tracking
- `GET /api/assets/tags/refine`: Test histogram generation with various filter combinations
### 4.2 Updated Endpoint Tests
- `GET /api/assets`: Test `include_public` param filtering
- `POST /api/assets`: Test JSON URL upload path alongside existing multipart tests
- `PUT /api/assets/{id}`: Test `mime_type` and `preview_id` field updates
- `GET /api/tags`: Test `include_public` param filtering
### 4.3 Removed Endpoint Tests
- Remove tests for `PUT /api/assets/{id}/preview`
- Add tests for `preview_id` in `PUT /api/assets/{id}` to cover the merged functionality
---
## Implementation Order
1. Phase 2.1, 2.4: Add `include_public` params (low effort, high compatibility)
2. Phase 2.3: Update PUT endpoint fields + remove preview endpoint
3. Phase 2.2: Add JSON URL upload to POST
4. Phase 1.3: Add tags/refine endpoint
5. Phase 1.1, 1.2: Add stub endpoints returning 501 (deferred implementation)
6. Phase 4: Add tests for each phase as implemented

routes.py

@@ -1,20 +1,14 @@
import logging
import uuid
import urllib.parse
import os
import contextlib
from aiohttp import web
from pydantic import ValidationError
import app.assets.manager as manager
import app.assets.scanner as scanner
from app import user_manager
from app.assets.api import schemas_in
from app.assets.helpers import get_query_dict
import folder_paths
ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None
@@ -34,18 +28,6 @@ def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
@ROUTES.head("/api/assets/hash/{hash}")
async def head_asset_by_hash(request: web.Request) -> web.Response:
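# HEAD semantics: no response body either way; answer 200 if the canonical "blake3:<hex>" digest is known, 404 otherwise.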
hash_str = request.match_info.get("hash", "").strip().lower()
if not hash_str or ":" not in hash_str:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = hash_str.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
exists = manager.asset_exists(asset_hash=hash_str)
return web.Response(status=200 if exists else 404)
@ROUTES.get("/api/assets")
async def list_assets(request: web.Request) -> web.Response:
"""
@@ -67,29 +49,10 @@ async def list_assets(request: web.Request) -> web.Response:
sort=q.sort,
order=q.order,
owner_id=USER_MANAGER.get_request_user_id(request),
include_public=q.include_public,
)
return web.json_response(payload.model_dump(mode="json"))
@ROUTES.get("/api/assets/remote-metadata")
async def get_remote_asset_metadata(request: web.Request) -> web.Response:
"""
Fetch metadata from remote URLs (CivitAI, HuggingFace) without downloading.
Status: Not implemented yet.
"""
return _error_response(501, "NOT_IMPLEMENTED", "Remote metadata fetching is not yet supported.")
@ROUTES.post("/api/assets/download")
async def create_asset_download(request: web.Request) -> web.Response:
"""
Initiate background download job for large files from HuggingFace or CivitAI.
Status: Not implemented yet.
"""
return _error_response(501, "NOT_IMPLEMENTED", "Background asset download is not yet supported.")
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
async def get_asset(request: web.Request) -> web.Response:
"""
@@ -113,306 +76,6 @@ async def get_asset(request: web.Request) -> web.Response:
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
async def download_asset_content(request: web.Request) -> web.Response:
# question: do we need disposition? could we just stick with one of these?
disposition = request.query.get("disposition", "attachment").lower().strip()
if disposition not in {"inline", "attachment"}:
disposition = "attachment"
try:
abs_path, content_type, filename = manager.resolve_asset_content_for_download(
asset_info_id=str(uuid.UUID(request.match_info["id"])),
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
except NotImplementedError as nie:
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
except FileNotFoundError:
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
resp = web.FileResponse(abs_path)
resp.content_type = content_type
resp.headers["Content-Disposition"] = cd
return resp
@ROUTES.post("/api/assets/from-hash")
async def create_asset_from_hash(request: web.Request) -> web.Response:
try:
payload = await request.json()
body = schemas_in.CreateFromHashBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
result = manager.create_asset_from_hash(
hash_str=body.hash,
name=body.name,
tags=body.tags,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
)
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
return web.json_response(result.model_dump(mode="json"), status=201)
@ROUTES.post("/api/assets")
async def upload_asset(request: web.Request) -> web.Response:
"""Asset upload endpoint supporting multipart/form-data (file upload) or application/json (URL-based)."""
content_type = (request.content_type or "").lower()
if content_type.startswith("application/json"):
try:
payload = await request.json()
schemas_in.UploadAssetFromUrlBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
return _error_response(501, "NOT_IMPLEMENTED", "URL-based asset upload is not yet supported. Use multipart/form-data file upload.")
if not content_type.startswith("multipart/"):
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads or application/json for URL-based uploads.")
reader = await request.multipart()
file_present = False
file_client_name: str | None = None
tags_raw: list[str] = []
provided_name: str | None = None
user_metadata_raw: str | None = None
provided_hash: str | None = None
provided_hash_exists: bool | None = None
file_written = 0
tmp_path: str | None = None
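# Stream multipart fields in arrival order. A "hash" field naming content we
# already have lets the "file" part be drained without writing a temp file
# (the fast path below then creates the AssetInfo from the existing content).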
while True:
field = await reader.next()
if field is None:
break
fname = getattr(field, "name", "") or ""
if fname == "hash":
try:
s = ((await field.text()) or "").strip().lower()
except Exception:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
if s:
if ":" not in s:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
provided_hash = f"{algo}:{digest}"
try:
provided_hash_exists = manager.asset_exists(asset_hash=provided_hash)
except Exception:
provided_hash_exists = None # do not fail the whole request here
elif fname == "file":
file_present = True
file_client_name = (field.filename or "").strip()
if provided_hash and provided_hash_exists is True:
# If client supplied a hash that we know exists, drain but do not write to disk
try:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
file_written += len(chunk)
except Exception:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
continue # Do not create temp file; we will create AssetInfo from the existing content
# Otherwise, store to temp for hashing/ingest
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
os.makedirs(unique_dir, exist_ok=True)
tmp_path = os.path.join(unique_dir, ".upload.part")
try:
with open(tmp_path, "wb") as f:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
f.write(chunk)
file_written += len(chunk)
except Exception:
try:
if os.path.exists(tmp_path or ""):
os.remove(tmp_path)
finally:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
elif fname == "tags":
tags_raw.append((await field.text()) or "")
elif fname == "name":
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
# If client did not send file, and we are not doing a from-hash fast path -> error
if not file_present and not (provided_hash and provided_hash_exists):
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
# Empty upload is only acceptable if we are fast-pathing from existing hash
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
try:
spec = schemas_in.UploadAssetSpec.model_validate({
"tags": tags_raw,
"name": provided_name,
"user_metadata": user_metadata_raw,
"hash": provided_hash,
})
except ValidationError as ve:
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _validation_error_response("INVALID_BODY", ve)
# Validate models category against configured folders (consistent with previous behavior)
if spec.tags and spec.tags[0] == "models":
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
return _error_response(
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
)
owner_id = USER_MANAGER.get_request_user_id(request)
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
if spec.hash and provided_hash_exists is True:
try:
result = manager.create_asset_from_hash(
hash_str=spec.hash,
name=spec.name or (spec.hash.split(":", 1)[1]),
tags=spec.tags,
user_metadata=spec.user_metadata or {},
owner_id=owner_id,
)
except Exception:
logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
# Drain temp if we accidentally saved (e.g., hash field came after file)
if tmp_path and os.path.exists(tmp_path):
with contextlib.suppress(Exception):
os.remove(tmp_path)
status = 200 if (not result.created_new) else 201
return web.json_response(result.model_dump(mode="json"), status=status)
# Otherwise, we must have a temp file path to ingest
if not tmp_path or not os.path.exists(tmp_path):
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
try:
created = manager.upload_asset_from_temp_path(
spec,
temp_path=tmp_path,
client_filename=file_client_name,
owner_id=owner_id,
expected_asset_hash=spec.hash,
)
status = 201 if created.created_new else 200
return web.json_response(created.model_dump(mode="json"), status=status)
except ValueError as e:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
msg = str(e)
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
return _error_response(
400,
"HASH_MISMATCH",
"Uploaded file hash does not match provided hash.",
)
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
except Exception:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
async def update_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = manager.update_asset(
asset_info_id=asset_info_id,
name=body.name,
mime_type=body.mime_type,
preview_id=body.preview_id,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
logging.exception(
"update_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
async def delete_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
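# delete_content defaults to true; only an explicit "0", "false" or "no" keeps orphaned content on disk.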
delete_content = request.query.get("delete_content")
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
try:
deleted = manager.delete_asset_reference(
asset_info_id=asset_info_id,
owner_id=USER_MANAGER.get_request_user_id(request),
delete_content_if_orphan=delete_content,
)
except Exception:
logging.exception(
"delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
if not deleted:
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
return web.Response(status=204)
@ROUTES.get("/api/tags")
async def get_tags(request: web.Request) -> web.Response:
"""
@@ -435,109 +98,5 @@ async def get_tags(request: web.Request) -> web.Response:
order=query.order,
include_zero=query.include_zero,
owner_id=USER_MANAGER.get_request_user_id(request),
include_public=query.include_public,
)
return web.json_response(result.model_dump(mode="json"))
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def add_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
payload = await request.json()
data = schemas_in.TagsAdd.model_validate(payload)
except ValidationError as ve:
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = manager.add_tags_to_asset(
asset_info_id=asset_info_id,
tags=data.tags,
origin="manual",
owner_id=USER_MANAGER.get_request_user_id(request),
)
except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
logging.exception(
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def delete_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
payload = await request.json()
data = schemas_in.TagsRemove.model_validate(payload)
except ValidationError as ve:
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = manager.remove_tags_from_asset(
asset_info_id=asset_info_id,
tags=data.tags,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
logging.exception(
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.get("/api/assets/tags/refine")
async def get_asset_tag_histogram(request: web.Request) -> web.Response:
"""
GET request to get a tag histogram for filtered assets.
"""
query_dict = get_query_dict(request)
try:
q = schemas_in.TagsRefineQuery.model_validate(query_dict)
except ValidationError as ve:
return _validation_error_response("INVALID_QUERY", ve)
payload = manager.get_tag_histogram(
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
include_public=q.include_public,
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(payload.model_dump(mode="json"))
@ROUTES.post("/api/assets/scan/seed")
async def seed_assets(request: web.Request) -> web.Response:
try:
payload = await request.json()
except Exception:
payload = {}
try:
body = schemas_in.ScheduleAssetScanBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
try:
scanner.seed_assets(body.roots)
except Exception:
logging.exception("seed_assets failed for roots=%s", body.roots)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response({"synced": True, "roots": body.roots}, status=200)

app/assets/api/schemas_in.py

@@ -8,10 +8,8 @@ from pydantic import (
Field,
conint,
field_validator,
model_validator,
)
from app.assets.helpers import RootType
class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
@@ -27,8 +25,6 @@ class ListAssetsQuery(BaseModel):
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
order: Literal["asc", "desc"] = "desc"
include_public: bool = True
@field_validator("include_tags", "exclude_tags", mode="before")
@classmethod
def _split_csv_tags(cls, v):
@@ -61,73 +57,6 @@ class ListAssetsQuery(BaseModel):
return None
class UpdateAssetBody(BaseModel):
name: str | None = None
mime_type: str | None = None
preview_id: str | None = None
user_metadata: dict[str, Any] | None = None
@field_validator("preview_id", mode="before")
@classmethod
def _norm_uuid(cls, v):
if v is None:
return None
s = str(v).strip()
if not s:
return None
try:
uuid.UUID(s)
except Exception:
raise ValueError("preview_id must be a UUID")
return s
@model_validator(mode="after")
def _at_least_one(self):
if self.name is None and self.mime_type is None and self.preview_id is None and self.user_metadata is None:
raise ValueError("Provide at least one of: name, mime_type, preview_id, user_metadata.")
return self
class CreateFromHashBody(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
hash: str
name: str
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
@field_validator("hash")
@classmethod
def _require_blake3(cls, v):
s = (v or "").strip().lower()
if ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return s
@field_validator("tags", mode="before")
@classmethod
def _tags_norm(cls, v):
if v is None:
return []
if isinstance(v, list):
out = [str(t).strip().lower() for t in v if str(t).strip()]
seen = set()
dedup = []
for t in out:
if t not in seen:
seen.add(t)
dedup.append(t)
return dedup
if isinstance(v, str):
return [t.strip().lower() for t in v.split(",") if t.strip()]
return []
class TagsListQuery(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
@@ -136,7 +65,6 @@ class TagsListQuery(BaseModel):
offset: int = Field(0, ge=0, le=10_000_000)
order: Literal["count_desc", "name_asc"] = "count_desc"
include_zero: bool = True
include_public: bool = True
@field_validator("prefix")
@classmethod
@@ -147,188 +75,10 @@ class TagsListQuery(BaseModel):
return v.lower() or None
class TagsAdd(BaseModel):
model_config = ConfigDict(extra="ignore")
tags: list[str] = Field(..., min_length=1)
@field_validator("tags")
@classmethod
def normalize_tags(cls, v: list[str]) -> list[str]:
out = []
for t in v:
if not isinstance(t, str):
raise TypeError("tags must be strings")
tnorm = t.strip().lower()
if tnorm:
out.append(tnorm)
seen = set()
deduplicated = []
for x in out:
if x not in seen:
seen.add(x)
deduplicated.append(x)
return deduplicated
class TagsRemove(TagsAdd):
pass
class UploadAssetSpec(BaseModel):
"""Upload Asset operation.
- tags: ordered; first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
- name: display name
- user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
and the original extension is preserved when available.
"""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(..., min_length=1)
name: str | None = Field(default=None, max_length=512, description="Display Name")
user_metadata: dict[str, Any] = Field(default_factory=dict)
hash: str | None = Field(default=None)
@field_validator("hash", mode="before")
@classmethod
def _parse_hash(cls, v):
if v is None:
return None
s = str(v).strip().lower()
if not s:
return None
if ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return f"{algo}:{digest}"
@field_validator("tags", mode="before")
@classmethod
def _parse_tags(cls, v):
"""
Accepts a list of strings (possibly multiple form fields),
where each string can be:
- JSON array (e.g., '["models","loras","foo"]')
- comma-separated ('models, loras, foo')
- single token ('models')
Returns a normalized, deduplicated, ordered list.
"""
items: list[str] = []
if v is None:
return []
if isinstance(v, str):
v = [v]
if isinstance(v, list):
for item in v:
if item is None:
continue
s = str(item).strip()
if not s:
continue
if s.startswith("["):
try:
arr = json.loads(s)
if isinstance(arr, list):
items.extend(str(x) for x in arr)
continue
except Exception:
pass # fallback to CSV parse below
items.extend([p for p in s.split(",") if p.strip()])
else:
return []
# normalize + dedupe
norm = []
seen = set()
for t in items:
tnorm = str(t).strip().lower()
if tnorm and tnorm not in seen:
seen.add(tnorm)
norm.append(tnorm)
return norm
@field_validator("user_metadata", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v or {}
if isinstance(v, str):
s = v.strip()
if not s:
return {}
try:
parsed = json.loads(s)
except Exception as e:
raise ValueError(f"user_metadata must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("user_metadata must be a JSON object")
return parsed
return {}
@model_validator(mode="after")
def _validate_order(self):
if not self.tags:
raise ValueError("tags must be provided and non-empty")
root = self.tags[0]
if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError("models uploads require a category tag as the second tag")
return self
class UploadAssetFromUrlBody(BaseModel):
"""JSON body for URL-based asset upload."""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
url: str = Field(..., description="HTTP/HTTPS URL to download the asset from")
name: str = Field(..., max_length=512, description="Display name for the asset")
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
class SetPreviewBody(BaseModel):
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
preview_id: str | None = None
@field_validator("url")
@classmethod
def _validate_url(cls, v):
s = (v or "").strip()
if not s:
raise ValueError("url must not be empty")
if not (s.startswith("http://") or s.startswith("https://")):
raise ValueError("url must start with http:// or https://")
return s
@field_validator("tags", mode="before")
@classmethod
def _parse_tags(cls, v):
if v is None:
return []
if isinstance(v, list):
out = [str(t).strip().lower() for t in v if str(t).strip()]
seen = set()
dedup = []
for t in out:
if t not in seen:
seen.add(t)
dedup.append(t)
return dedup
return []
@field_validator("user_metadata", mode="before")
@classmethod
def _parse_metadata(cls, v):
if v is None or isinstance(v, dict):
return v or {}
return {}
@field_validator("preview_id", mode="before")
@classmethod
def _norm_uuid(cls, v):
@@ -342,49 +92,3 @@ class UploadAssetFromUrlBody(BaseModel):
except Exception:
raise ValueError("preview_id must be a UUID")
return s
class ScheduleAssetScanBody(BaseModel):
roots: list[RootType] = Field(..., min_length=1)
class TagsRefineQuery(BaseModel):
"""Query parameters for tag histogram/refinement endpoint."""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
name_contains: str | None = None
metadata_filter: dict[str, Any] | None = None
limit: conint(ge=1, le=1000) = 100
include_public: bool = True
@field_validator("include_tags", "exclude_tags", mode="before")
@classmethod
def _split_csv_tags(cls, v):
if v is None:
return []
if isinstance(v, str):
return [t.strip() for t in v.split(",") if t.strip()]
if isinstance(v, list):
out: list[str] = []
for item in v:
if isinstance(item, str):
out.extend([t.strip() for t in item.split(",") if t.strip()])
return out
return v
@field_validator("metadata_filter", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v
if isinstance(v, str) and v.strip():
try:
parsed = json.loads(v)
except Exception as e:
raise ValueError(f"metadata_filter must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("metadata_filter must be a JSON object")
return parsed
return None

app/assets/api/schemas_out.py

@@ -29,21 +29,6 @@ class AssetsList(BaseModel):
has_more: bool
class AssetUpdated(BaseModel):
id: str
name: str
asset_hash: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _ser_updated(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetDetail(BaseModel):
id: str
name: str
@@ -63,10 +48,6 @@ class AssetDetail(BaseModel):
return v.isoformat() if v else None
class AssetCreated(AssetDetail):
created_new: bool
class TagUsage(BaseModel):
name: str
count: int
@@ -77,26 +58,3 @@ class TagsList(BaseModel):
tags: list[TagUsage] = Field(default_factory=list)
total: int
has_more: bool
class TagsAdd(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True)
added: list[str] = Field(default_factory=list)
already_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
class TagsRemove(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True)
removed: list[str] = Field(default_factory=list)
not_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
class TagHistogramEntry(BaseModel):
name: str
count: int
class TagHistogramResponse(BaseModel):
tags: list[TagHistogramEntry] = Field(default_factory=list)

app/assets/database/queries.py

@@ -1,17 +1,9 @@
import os
import logging
import sqlalchemy as sa
from collections import defaultdict
from datetime import datetime
from typing import Iterable, Any
from sqlalchemy import select, delete, exists, func
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy import select, exists, func
from sqlalchemy.orm import Session, contains_eager, noload
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import (
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
)
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import escape_like_prefix, normalize_tags
from typing import Sequence
@@ -23,22 +15,6 @@ def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
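# An empty owner_id marks a public row, so callers see public rows plus their own.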
return AssetInfo.owner_id.in_(["", owner_id])
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
"""
Return the best on-disk path among cache states:
1) Prefer a path that exists with needs_verify == False (already verified).
2) Otherwise, pick the first path that exists.
3) Otherwise return empty string.
"""
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
if not alive:
return ""
for s in alive:
if not getattr(s, "needs_verify", False):
return s.file_path
return alive[0].file_path
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
@@ -66,7 +42,6 @@
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
@@ -119,11 +94,7 @@
return stmt
def asset_exists_by_hash(
session: Session,
*,
asset_hash: str,
) -> bool:
def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
"""
Check if an asset with a given hash exists in the database.
"""
@@ -134,39 +105,9 @@
).first()
return row is not None
def asset_info_exists_for_asset_id(
session: Session,
*,
asset_id: str,
) -> bool:
q = (
select(sa.literal(True))
.select_from(AssetInfo)
.where(AssetInfo.asset_id == asset_id)
.limit(1)
)
return (session.execute(q)).first() is not None
def get_asset_by_hash(
session: Session,
*,
asset_hash: str,
) -> Asset | None:
return (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
def get_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
) -> AssetInfo | None:
def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None:
return session.get(AssetInfo, asset_info_id)
def list_asset_infos_page(
session: Session,
owner_id: str = "",
@@ -236,7 +177,6 @@
return infos, tag_map, total
def fetch_asset_info_asset_and_tags(
session: Session,
asset_info_id: str,
@@ -268,489 +208,6 @@
tags.append(tag_name)
return first_info, first_asset, tags
def fetch_asset_info_and_asset(
session: Session,
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset] | None:
stmt = (
select(AssetInfo, Asset)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.limit(1)
.options(noload(AssetInfo.tags))
)
row = session.execute(stmt)
pair = row.first()
if not pair:
return None
return pair[0], pair[1]
def list_cache_states_by_asset_id(
session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
return (
session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
)
).scalars().all()
def touch_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
ts: datetime | None = None,
only_if_newer: bool = True,
) -> None:
ts = ts or utcnow()
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
if only_if_newer:
stmt = stmt.where(
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
)
session.execute(stmt.values(last_access_time=ts))
def create_asset_info_for_existing_asset(
session: Session,
*,
asset_hash: str,
name: str,
user_metadata: dict | None = None,
tags: Sequence[str] | None = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> AssetInfo:
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
now = utcnow()
asset = get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
raise ValueError(f"Unknown asset hash {asset_hash}")
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
preview_id=None,
created_at=now,
updated_at=now,
last_access_time=now,
)
try:
with session.begin_nested():
session.add(info)
session.flush()
except IntegrityError:
existing = (
session.execute(
select(AssetInfo)
.options(noload(AssetInfo.tags))
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == name,
AssetInfo.owner_id == owner_id,
)
.limit(1)
)
).unique().scalars().first()
if not existing:
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
return existing
# metadata["filename"] hack
new_meta = dict(user_metadata or {})
computed_filename = None
try:
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
replace_asset_info_metadata_projection(
session,
asset_info_id=info.id,
user_metadata=new_meta,
)
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=info.id,
tags=tags,
origin=tag_origin,
)
return info
def set_asset_info_tags(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> dict:
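# Reconcile link rows against the desired tag list: create missing Tag rows, add new links, drop links no longer wanted.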
desired = normalize_tags(tags)
current = set(
tag_name for (tag_name,) in (
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
ensure_tags_exist(session, to_add, tag_type="user")
session.add_all([
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
for t in to_add
])
session.flush()
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
)
session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
def replace_asset_info_metadata_projection(
session: Session,
*,
asset_info_id: str,
user_metadata: dict | None = None,
) -> None:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {}
info.updated_at = utcnow()
session.flush()
session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
session.flush()
if not user_metadata:
return
rows: list[AssetInfoMeta] = []
for k, v in user_metadata.items():
for r in project_kv(k, v):
rows.append(
AssetInfoMeta(
asset_info_id=asset_info_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
session.flush()
def ingest_fs_asset(
session: Session,
*,
asset_hash: str,
abs_path: str,
size_bytes: int,
mtime_ns: int,
mime_type: str | None = None,
info_name: str | None = None,
owner_id: str = "",
preview_id: str | None = None,
user_metadata: dict | None = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
require_existing_tags: bool = False,
) -> dict:
"""
Idempotently upsert:
- Asset by content hash (create if missing)
- AssetCacheState(file_path) pointing to asset_id
- Optionally AssetInfo + tag links and metadata projection
Returns flags and ids.
"""
locator = os.path.abspath(abs_path)
now = utcnow()
if preview_id:
if not session.get(Asset, preview_id):
preview_id = None
out: dict[str, Any] = {
"asset_created": False,
"asset_updated": False,
"state_created": False,
"state_updated": False,
"asset_info_id": None,
}
# 1) Asset by hash
asset = (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
if not asset:
vals = {
"hash": asset_hash,
"size_bytes": int(size_bytes),
"mime_type": mime_type,
"created_at": now,
}
res = session.execute(
sqlite.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(index_elements=[Asset.hash])
)
if int(res.rowcount or 0) > 0:
out["asset_created"] = True
asset = (
session.execute(
select(Asset).where(Asset.hash == asset_hash).limit(1)
)
).scalars().first()
if not asset:
raise RuntimeError("Asset row not found after upsert.")
else:
changed = False
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and asset.mime_type != mime_type:
asset.mime_type = mime_type
changed = True
if changed:
out["asset_updated"] = True
# 2) AssetCacheState upsert by file_path (unique)
vals = {
"asset_id": asset.id,
"file_path": locator,
"mtime_ns": int(mtime_ns),
}
ins = (
sqlite.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
res = session.execute(ins)
if int(res.rowcount or 0) > 0:
out["state_created"] = True
else:
upd = (
sa.update(AssetCacheState)
.where(AssetCacheState.file_path == locator)
.where(
sa.or_(
AssetCacheState.asset_id != asset.id,
AssetCacheState.mtime_ns.is_(None),
AssetCacheState.mtime_ns != int(mtime_ns),
)
)
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
)
res2 = session.execute(upd)
if int(res2.rowcount or 0) > 0:
out["state_updated"] = True
# 3) Optional AssetInfo + tags + metadata
if info_name:
try:
with session.begin_nested():
info = AssetInfo(
owner_id=owner_id,
name=info_name,
asset_id=asset.id,
preview_id=preview_id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
out["asset_info_id"] = info.id
except IntegrityError:
pass
existing_info = (
session.execute(
select(AssetInfo)
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == info_name,
(AssetInfo.owner_id == owner_id),
)
.limit(1)
)
).unique().scalar_one_or_none()
if not existing_info:
raise RuntimeError("Failed to update or insert AssetInfo.")
if preview_id and existing_info.preview_id != preview_id:
existing_info.preview_id = preview_id
existing_info.updated_at = now
if existing_info.last_access_time < now:
existing_info.last_access_time = now
session.flush()
out["asset_info_id"] = existing_info.id
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
if norm and out["asset_info_id"] is not None:
if not require_existing_tags:
ensure_tags_exist(session, norm, tag_type="user")
existing_tag_names = set(
name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
)
missing = [t for t in norm if t not in existing_tag_names]
if missing and require_existing_tags:
raise ValueError(f"Unknown tags: {missing}")
existing_links = set(
tag_name
for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
)
).all()
)
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
if to_add:
session.add_all(
[
AssetInfoTag(
asset_info_id=out["asset_info_id"],
tag_name=t,
origin=tag_origin,
added_at=now,
)
for t in to_add
]
)
session.flush()
# metadata["filename"] hack
if out["asset_info_id"] is not None:
primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
computed_filename = compute_relative_filename(primary_path) if primary_path else None
current_meta = existing_info.user_metadata or {}
new_meta = dict(current_meta)
if user_metadata is not None:
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta != current_meta:
replace_asset_info_metadata_projection(
session,
asset_info_id=out["asset_info_id"],
user_metadata=new_meta,
)
try:
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
return out
def update_asset_info_full(
session: Session,
*,
asset_info_id: str,
name: str | None = None,
mime_type: str | None = None,
user_metadata: dict | None = None,
asset_info_row: Any = None,
) -> AssetInfo:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
else:
info = asset_info_row
touched = False
if name is not None and name != info.name:
info.name = name
touched = True
if mime_type is not None and info.asset:
if info.asset.mime_type != mime_type:
info.asset.mime_type = mime_type
touched = True
computed_filename = None
try:
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if user_metadata is not None:
new_meta = dict(user_metadata)
if computed_filename:
new_meta["filename"] = computed_filename
replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
else:
if computed_filename:
current_meta = info.user_metadata or {}
if current_meta.get("filename") != computed_filename:
new_meta = dict(current_meta)
new_meta["filename"] = computed_filename
replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
if touched and user_metadata is None:
info.updated_at = utcnow()
session.flush()
return info
def delete_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
owner_id: str,
) -> bool:
stmt = sa.delete(AssetInfo).where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
return int((session.execute(stmt)).rowcount or 0) > 0
def list_tags_with_usage(
session: Session,
prefix: str | None = None,
@@ -808,163 +265,3 @@
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
session.execute(ins)
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
return [
tag_name for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
]
def add_tags_to_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
asset_info_row: Any = None,
) -> dict:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total}
if create_if_missing:
ensure_tags_exist(session, norm, tag_type="user")
current = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
want = set(norm)
to_add = sorted(want - current)
if to_add:
with session.begin_nested() as nested:
try:
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=utcnow(),
)
for t in to_add
]
)
session.flush()
except IntegrityError:
nested.rollback()
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
return {
"added": sorted(((after - current) & want)),
"already_present": sorted(want & current),
"total_tags": sorted(after),
}
def remove_tags_from_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
) -> dict:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total}
existing = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
session.flush()
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
def remove_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
) -> None:
session.execute(
sa.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)
def set_asset_info_preview(
session: Session,
*,
asset_info_id: str,
preview_asset_id: str | None = None,
) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if preview_asset_id is None:
info.preview_id = None
else:
# validate preview asset exists
if not session.get(Asset, preview_asset_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found")
info.preview_id = preview_asset_id
info.updated_at = utcnow()
session.flush()

app/assets/helpers.py

@@ -1,6 +1,5 @@
import contextlib
import os
from decimal import Decimal
from aiohttp import web
from datetime import datetime, timezone
from pathlib import Path
@@ -88,40 +87,6 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
targets.append((name, paths))
return targets
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
root = tags[0]
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
else:
base_dir = os.path.abspath(
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
)
raw_subdirs = tags[1:]
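# Reject path-traversal components; the remaining tags become subdirectories under the base.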
for i in raw_subdirs:
if i in (".", ".."):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else []
def ensure_within_base(candidate: str, base: str) -> None:
cand_abs = os.path.abspath(candidate)
base_abs = os.path.abspath(base)
try:
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
raise ValueError("destination escapes base directory")
except Exception:
raise ValueError("invalid destination path")
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the model's path relative to the last well-known folder (the model category),
@@ -148,6 +113,7 @@
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
@@ -249,64 +215,3 @@ def collect_models_files() -> list[str]:
if allowed:
out.append(abs_path)
return out
def is_scalar(v):
if v is None:
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
def project_kv(key: str, value):
"""
Turn a metadata key/value into typed projection rows.
Returns list[dict] with keys:
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
"""
rows: list[dict] = []
def _null_row(ordinal: int) -> dict:
return {
"key": key, "ordinal": ordinal,
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
}
if value is None:
rows.append(_null_row(0))
return rows
if is_scalar(value):
if isinstance(value, bool):
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
elif isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
rows.append({"key": key, "ordinal": 0, "val_num": num})
elif isinstance(value, str):
rows.append({"key": key, "ordinal": 0, "val_str": value})
else:
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
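# Lists fan out to one row per element: homogeneous scalar lists use the typed
# columns, anything else falls back to val_json.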
if isinstance(value, list):
if all(is_scalar(x) for x in value):
for i, x in enumerate(value):
if x is None:
rows.append(_null_row(i))
elif isinstance(x, bool):
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
elif isinstance(x, (int, float, Decimal)):
num = x if isinstance(x, Decimal) else Decimal(str(x))
rows.append({"key": key, "ordinal": i, "val_num": num})
elif isinstance(x, str):
rows.append({"key": key, "ordinal": i, "val_str": x})
else:
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
for i, x in enumerate(value):
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows

app/assets/manager.py

@@ -1,34 +1,13 @@
import os
import mimetypes
import contextlib
from typing import Sequence
from app.database.db import create_session
from app.assets.api import schemas_out, schemas_in
from app.assets.api import schemas_out
from app.assets.database.queries import (
asset_exists_by_hash,
asset_info_exists_for_asset_id,
get_asset_by_hash,
get_asset_info_by_id,
fetch_asset_info_asset_and_tags,
fetch_asset_info_and_asset,
create_asset_info_for_existing_asset,
touch_asset_info_by_id,
update_asset_info_full,
delete_asset_info_by_id,
list_cache_states_by_asset_id,
list_asset_infos_page,
list_tags_with_usage,
get_asset_tags,
add_tags_to_asset_info,
remove_tags_from_asset_info,
pick_best_live_path,
ingest_fs_asset,
set_asset_info_preview,
)
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
from app.assets.database.models import Asset
import app.assets.hashing as hashing
def _safe_sort_field(requested: str | None) -> str:
@@ -40,28 +19,11 @@
return "created_at"
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
st = os.stat(path, follow_symlinks=True)
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
def _safe_filename(name: str | None, fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
if n:
return n
return fallback
def asset_exists(*, asset_hash: str) -> bool:
"""
Check if an asset with a given hash exists in the database.
"""
def asset_exists(asset_hash: str) -> bool:
with create_session() as session:
return asset_exists_by_hash(session, asset_hash=asset_hash)
def list_assets(
*,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
@@ -71,7 +33,6 @@
sort: str = "created_at",
order: str = "desc",
owner_id: str = "",
include_public: bool = True,
) -> schemas_out.AssetsList:
sort = _safe_sort_field(sort)
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
@@ -115,12 +76,7 @@
has_more=(offset + len(summaries)) < total,
)
def get_asset(
*,
asset_info_id: str,
owner_id: str = "",
) -> schemas_out.AssetDetail:
def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
with create_session() as session:
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
@ -141,356 +97,6 @@ def get_asset(
last_access_time=info.last_access_time,
)
def resolve_asset_content_for_download(
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[str, str, str]:
with create_session() as session:
pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not pair:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset = pair
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
abs_path = pick_best_live_path(states)
if not abs_path:
raise FileNotFoundError
touch_asset_info_by_id(session, asset_info_id=asset_info_id)
session.commit()
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
download_name = info.name or os.path.basename(abs_path)
return abs_path, ctype, download_name
def upload_asset_from_temp_path(
spec: schemas_in.UploadAssetSpec,
*,
temp_path: str,
client_filename: str | None = None,
owner_id: str = "",
expected_asset_hash: str | None = None,
) -> schemas_out.AssetCreated:
try:
digest = hashing.blake3_hash(temp_path)
except Exception as e:
raise RuntimeError(f"failed to hash uploaded file: {e}")
asset_hash = "blake3:" + digest
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
raise ValueError("HASH_MISMATCH")
with create_session() as session:
existing = get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
info = create_asset_info_for_existing_asset(
session,
asset_hash=asset_hash,
name=display_name,
user_metadata=spec.user_metadata or {},
tags=spec.tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=existing.hash,
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
mime_type=existing.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)
src_for_ext = (client_filename or spec.name or "").strip()
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
ext = _ext if 0 < len(_ext) <= 16 else ""
hashed_basename = f"{digest}{ext}"
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
ensure_within_base(dest_abs, base_dir)
content_type = (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
)
try:
os.replace(temp_path, dest_abs)
except Exception as e:
raise RuntimeError(f"failed to move uploaded file into place: {e}")
try:
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
except OSError as e:
raise RuntimeError(f"failed to stat destination file: {e}")
with create_session() as session:
result = ingest_fs_asset(
session,
asset_hash=asset_hash,
abs_path=dest_abs,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
owner_id=owner_id,
preview_id=None,
user_metadata=spec.user_metadata or {},
tags=spec.tags,
tag_origin="manual",
require_existing_tags=False,
)
info_id = result["asset_info_id"]
if not info_id:
raise RuntimeError("failed to create asset metadata")
pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
info, asset = pair
tag_names = get_asset_tags(session, asset_info_id=info.id)
session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=result["asset_created"],
)
def update_asset(
*,
asset_info_id: str,
name: str | None = None,
mime_type: str | None = None,
preview_id: str | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> schemas_out.AssetUpdated:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
info = update_asset_info_full(
session,
asset_info_id=asset_info_id,
name=name,
mime_type=mime_type,
user_metadata=user_metadata,
asset_info_row=info_row,
)
if preview_id is not None:
set_asset_info_preview(
session,
asset_info_id=asset_info_id,
preview_asset_id=preview_id if preview_id else None,
)
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
session.commit()
return schemas_out.AssetUpdated(
id=info.id,
name=info.name,
asset_hash=info.asset.hash if info.asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
updated_at=info.updated_at,
)
def set_asset_preview(
*,
asset_info_id: str,
preview_asset_id: str | None = None,
owner_id: str = "",
) -> schemas_out.AssetDetail:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
set_asset_info_preview(
session,
asset_info_id=asset_info_id,
preview_asset_id=preview_asset_id,
)
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise RuntimeError("State changed during preview update")
info, asset, tags = res
session.commit()
return schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
asset_id = info_row.asset_id if info_row else None
deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not deleted:
session.commit()
return False
if not delete_content_if_orphan or not asset_id:
session.commit()
return True
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
if still_exists:
session.commit()
return True
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
asset_row = session.get(Asset, asset_id)
if asset_row is not None:
session.delete(asset_row)
session.commit()
for p in file_paths:
with contextlib.suppress(Exception):
if p and os.path.isfile(p):
os.remove(p)
return True
def create_asset_from_hash(
*,
hash_str: str,
name: str,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> schemas_out.AssetCreated | None:
canonical = hash_str.strip().lower()
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=canonical)
if not asset:
return None
info = create_asset_info_for_existing_asset(
session,
asset_hash=canonical,
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
def add_tags_to_asset(
*,
asset_info_id: str,
tags: list[str],
origin: str = "manual",
owner_id: str = "",
) -> schemas_out.TagsAdd:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = add_tags_to_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=origin,
create_if_missing=True,
asset_info_row=info_row,
)
session.commit()
return schemas_out.TagsAdd(**data)
def remove_tags_from_asset(
*,
asset_info_id: str,
tags: list[str],
owner_id: str = "",
) -> schemas_out.TagsRemove:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = remove_tags_from_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
)
session.commit()
return schemas_out.TagsRemove(**data)
def list_tags(
prefix: str | None = None,
limit: int = 100,
@ -498,7 +104,6 @@ def list_tags(
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
include_public: bool = True,
) -> schemas_out.TagsList:
limit = max(1, min(1000, limit))
offset = max(0, offset)
@ -516,17 +121,3 @@ def list_tags(
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)
def get_tag_histogram(
*,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 100,
include_public: bool = True,
owner_id: str = "",
) -> schemas_out.TagHistogramResponse:
# TODO: Implement actual histogram query in queries.py
return schemas_out.TagHistogramResponse(tags=[])
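
Throughout this service, hashes are stored and compared in a canonical form: the `blake3:`-prefixed hex digest, lowercased and stripped. A small sketch of that convention (hypothetical helper names, not the service code):

```python
def canonical_hash(h: str) -> str:
    # Hashes compare case-insensitively, with surrounding whitespace ignored.
    return h.strip().lower()

def check_expected(asset_hash: str, expected: str | None) -> None:
    # Mirrors the guard in upload_asset_from_temp_path: a mismatch aborts the upload.
    if expected and asset_hash != canonical_hash(expected):
        raise ValueError("HASH_MISMATCH")

check_expected("blake3:abc123", "  BLAKE3:ABC123 ")  # passes after canonicalization
```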

View File

@ -8,7 +8,6 @@ class LatentFormat:
latent_rgb_factors_bias = None
latent_rgb_factors_reshape = None
taesd_decoder_name = None
spacial_downscale_ratio = 8
def process_in(self, latent):
return latent * self.scale_factor
@ -182,7 +181,6 @@ class Flux(SD3):
class Flux2(LatentFormat):
latent_channels = 128
spacial_downscale_ratio = 16
def __init__(self):
self.latent_rgb_factors =[
@ -751,7 +749,6 @@ class ACEAudio(LatentFormat):
class ChromaRadiance(LatentFormat):
latent_channels = 3
spacial_downscale_ratio = 1
def __init__(self):
self.latent_rgb_factors = [

View File

@ -18,12 +18,12 @@ class CompressedTimestep:
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
"""
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
patches_per_frame: Number of spatial patches per frame (height * width in latent space)
"""
self.batch_size, num_tokens, self.feature_dim = tensor.shape
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
self.patches_per_frame = patches_per_frame
self.num_frames = num_tokens // patches_per_frame
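
The guard above only enables compression when the token layout is frame-major, i.e. `num_tokens = num_frames * patches_per_frame`. The arithmetic in isolation (illustrative numbers):

```python
patches_per_frame = 64              # height * width in latent space
num_tokens = 8 * patches_per_frame  # 8 frames, frame-major token layout

assert num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame
num_frames = num_tokens // patches_per_frame  # 8: one timestep row kept per frame
print(num_frames)
```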
@ -215,9 +215,22 @@ class BasicAVTransformerBlock(nn.Module):
return (*scale_shift_ada_values, *gate_ada_values)
def forward(
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
self,
x: Tuple[torch.Tensor, torch.Tensor],
v_context=None,
a_context=None,
attention_mask=None,
v_timestep=None,
a_timestep=None,
v_pe=None,
a_pe=None,
v_cross_pe=None,
a_cross_pe=None,
v_cross_scale_shift_timestep=None,
a_cross_scale_shift_timestep=None,
v_cross_gate_timestep=None,
a_cross_gate_timestep=None,
transformer_options=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
run_vx = transformer_options.get("run_vx", True)
run_ax = transformer_options.get("run_ax", True)
@ -227,102 +240,144 @@ class BasicAVTransformerBlock(nn.Module):
run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
# video
if run_vx:
# video self-attention
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
vshift_msa, vscale_msa, vgate_msa = (
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
)
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
del vshift_msa, vscale_msa
attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
del norm_vx
# video cross-attention
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
vx.addcmul_(attn1_out, vgate_msa)
del vgate_msa, attn1_out
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
vx += self.attn2(
comfy.ldm.common_dit.rms_norm(vx),
context=v_context,
mask=attention_mask,
transformer_options=transformer_options,
)
del vshift_msa, vscale_msa, vgate_msa
# audio
if run_ax:
# audio self-attention
ashift_msa, ascale_msa = (self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 2)))
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
del ashift_msa, ascale_msa
attn1_out = self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
del norm_ax
# audio cross-attention
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
ax.addcmul_(attn1_out, agate_msa)
del agate_msa, attn1_out
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
ashift_msa, ascale_msa, agate_msa = (
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
)
# video - audio cross attention.
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
ax += (
self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
* agate_msa
)
ax += self.audio_attn2(
comfy.ldm.common_dit.rms_norm(ax),
context=a_context,
mask=attention_mask,
transformer_options=transformer_options,
)
del ashift_msa, ascale_msa, agate_msa
# Audio - Video cross attention.
if run_a2v or run_v2a:
# norm3
vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
# audio to video cross attention
(
scale_ca_audio_hidden_states_a2v,
shift_ca_audio_hidden_states_a2v,
scale_ca_audio_hidden_states_v2a,
shift_ca_audio_hidden_states_v2a,
gate_out_v2a,
) = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_audio,
ax.shape[0],
a_cross_scale_shift_timestep,
a_cross_gate_timestep,
)
(
scale_ca_video_hidden_states_a2v,
shift_ca_video_hidden_states_a2v,
scale_ca_video_hidden_states_v2a,
shift_ca_video_hidden_states_v2a,
gate_out_a2v,
) = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_video,
vx.shape[0],
v_cross_scale_shift_timestep,
v_cross_gate_timestep,
)
if run_a2v:
scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values(
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2]
scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values(
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2]
vx_scaled = (
vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
+ shift_ca_video_hidden_states_a2v
)
ax_scaled = (
ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
+ shift_ca_audio_hidden_states_a2v
)
vx += (
self.audio_to_video_attn(
vx_scaled,
context=ax_scaled,
pe=v_cross_pe,
k_pe=a_cross_pe,
transformer_options=transformer_options,
)
* gate_out_a2v
)
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
del scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v, scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v
del gate_out_a2v
del scale_ca_video_hidden_states_a2v,\
shift_ca_video_hidden_states_a2v,\
scale_ca_audio_hidden_states_a2v,\
shift_ca_audio_hidden_states_a2v
a2v_out = self.audio_to_video_attn(vx_scaled, context=ax_scaled, pe=v_cross_pe, k_pe=a_cross_pe, transformer_options=transformer_options)
del vx_scaled, ax_scaled
gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0]
vx.addcmul_(a2v_out, gate_out_a2v)
del gate_out_a2v, a2v_out
# video to audio cross attention
if run_v2a:
scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values(
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4]
scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values(
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4]
ax_scaled = (
ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
+ shift_ca_audio_hidden_states_v2a
)
vx_scaled = (
vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
+ shift_ca_video_hidden_states_v2a
)
ax += (
self.video_to_audio_attn(
ax_scaled,
context=vx_scaled,
pe=a_cross_pe,
k_pe=v_cross_pe,
transformer_options=transformer_options,
)
* gate_out_v2a
)
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
del scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a, scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a
del gate_out_v2a
del scale_ca_video_hidden_states_v2a,\
shift_ca_video_hidden_states_v2a,\
scale_ca_audio_hidden_states_v2a,\
shift_ca_audio_hidden_states_v2a
v2a_out = self.video_to_audio_attn(ax_scaled, context=vx_scaled, pe=a_cross_pe, k_pe=v_cross_pe, transformer_options=transformer_options)
del ax_scaled, vx_scaled
gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0]
ax.addcmul_(v2a_out, gate_out_v2a)
del gate_out_v2a, v2a_out
del vx_norm3, ax_norm3
# video feedforward
if run_vx:
vshift_mlp, vscale_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, 5))
vshift_mlp, vscale_mlp, vgate_mlp = (
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
)
vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
del vshift_mlp, vscale_mlp
vx += self.ff(vx_scaled) * vgate_mlp
del vshift_mlp, vscale_mlp, vgate_mlp
ff_out = self.ff(vx_scaled)
del vx_scaled
vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0]
vx.addcmul_(ff_out, vgate_mlp)
del vgate_mlp, ff_out
# audio feedforward
if run_ax:
ashift_mlp, ascale_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, 5))
ashift_mlp, ascale_mlp, agate_mlp = (
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
)
ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
del ashift_mlp, ascale_mlp
ax += self.audio_ff(ax_scaled) * agate_mlp
ff_out = self.audio_ff(ax_scaled)
del ax_scaled
del ashift_mlp, ascale_mlp, agate_mlp
agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0]
ax.addcmul_(ff_out, agate_mlp)
del agate_mlp, ff_out
return vx, ax
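
The two gating forms interleaved in this hunk are numerically equivalent; the `addcmul_` spelling fuses the multiply-accumulate in place and avoids allocating the `out * gate` intermediate. A standalone comparison (illustrative shapes):

```python
import torch

x = torch.randn(2, 16, 8)
out = torch.randn(2, 16, 8)
gate = torch.randn(2, 1, 8)   # broadcasts over the token dimension

y = x.clone()
y += out * gate               # allocates a temporary the size of out
x.addcmul_(out, gate)         # in-place fused multiply-add, no temporary
assert torch.allclose(x, y)
```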
@ -534,20 +589,9 @@ class LTXAVModel(LTXVModel):
audio_length = kwargs.get("audio_length", 0)
# Separate audio and video latents
vx, ax = self.separate_audio_and_video_latents(x, audio_length)
has_spatial_mask = False
if denoise_mask is not None:
# check if any frame has spatial variation (inpainting)
for frame_idx in range(denoise_mask.shape[2]):
frame_mask = denoise_mask[0, 0, frame_idx]
if frame_mask.numel() > 0 and frame_mask.min() != frame_mask.max():
has_spatial_mask = True
break
[vx, v_pixel_coords, additional_args] = super()._process_input(
vx, keyframe_idxs, denoise_mask, **kwargs
)
additional_args["has_spatial_mask"] = has_spatial_mask
ax, a_latent_coords = self.a_patchifier.patchify(ax)
ax = self.audio_patchify_proj(ax)
@ -574,9 +618,8 @@ class LTXAVModel(LTXVModel):
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
orig_shape = kwargs.get("orig_shape")
has_spatial_mask = kwargs.get("has_spatial_mask", None)
v_patches_per_frame = None
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
if orig_shape is not None and len(orig_shape) == 5:
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
v_patches_per_frame = orig_shape[3] * orig_shape[4]
@ -619,11 +662,10 @@ class LTXAVModel(LTXVModel):
)
# Compress cross-attention timesteps (only video side, audio is too small to benefit)
# v_patches_per_frame is None for spatial masks, set for temporal masks or no mask
cross_av_timestep_ss = [
av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed
av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
]

View File

@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from comfy.ldm.modules.diffusionmodules.model import vae_attention, torch_cat_if_needed
from comfy.ldm.modules.diffusionmodules.model import vae_attention
import comfy.ops
ops = comfy.ops.disable_weight_init
@ -20,29 +20,22 @@ class CausalConv3d(ops.Conv3d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = 2 * self.padding[0]
self.padding = (0, self.padding[1], self.padding[2])
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
if cache_list is not None:
cache_x = cache_list[cache_idx]
cache_list[cache_idx] = None
if cache_x is None and x.shape[2] == 1:
#Fast path - the op will pad for us by truncating the weight
#and save math on a pile of zeros.
return super().forward(x, autopad="causal_zero")
if self._padding > 0:
padding_needed = self._padding
if cache_x is not None:
cache_x = cache_x.to(x.device)
padding_needed = max(0, padding_needed - cache_x.shape[2])
padding_shape = list(x.shape)
padding_shape[2] = padding_needed
padding = torch.zeros(padding_shape, device=x.device, dtype=x.dtype)
x = torch_cat_if_needed([padding, cache_x, x], dim=2)
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
del cache_x
x = F.pad(x, padding)
return super().forward(x)
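
For a 5D `(N, C, T, H, W)` input, `F.pad` takes a 6-tuple ordered `(W_left, W_right, H_top, H_bottom, T_front, T_back)`, which is why the causal variant above puts `2 * padding[0]` in the fifth slot and zero in the sixth: all temporal padding lands in front of the clip. A sketch:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 5, 8, 8)   # (N, C, T, H, W)
pad = (1, 1, 1, 1, 2, 0)         # e.g. kernel (3, 3, 3): 2 frames of history, no future
y = F.pad(x, pad)
print(y.shape)                   # torch.Size([1, 4, 7, 10, 10])
```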

View File

@ -260,7 +260,6 @@ def model_lora_keys_unet(model, key_map={}):
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
key_map[k[:-len(".weight")]] = to #DiffSynth lora format
for k in sdk:
hidden_size = model.model_config.unet_config.get("hidden_size", 0)
if k.endswith(".weight") and ".linear1." in k:

View File

@ -203,9 +203,7 @@ class disable_weight_init:
def reset_parameters(self):
return None
def _conv_forward(self, input, weight, bias, autopad=None, *args, **kwargs):
if autopad == "causal_zero":
weight = weight[:, :, -input.shape[2]:, :, :]
def _conv_forward(self, input, weight, bias, *args, **kwargs):
if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
if bias is not None:
@ -214,15 +212,15 @@ class disable_weight_init:
else:
return super()._conv_forward(input, weight, bias, *args, **kwargs)
def forward_comfy_cast_weights(self, input, autopad=None):
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias, autopad=autopad)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or "autopad" in kwargs:
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)

View File

@ -37,18 +37,12 @@ def prepare_noise(latent_image, seed, noise_inds=None):
return noises
def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None):
def fix_empty_latent_channels(model, latent_image):
if latent_image.is_nested:
return latent_image
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
if torch.count_nonzero(latent_image) == 0:
if latent_format.latent_channels != latent_image.shape[1]:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
if downscale_ratio_spacial is not None:
if downscale_ratio_spacial != latent_format.spacial_downscale_ratio:
ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio
latent_image = comfy.utils.common_upscale(latent_image, round(latent_image.shape[-1] * ratio), round(latent_image.shape[-2] * ratio), "nearest-exact", crop="disabled")
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
if latent_format.latent_dimensions == 3 and latent_image.ndim == 4:
latent_image = latent_image.unsqueeze(2)
return latent_image
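
The channel fix only touches all-zero placeholder latents: it widens them to the model's channel count by repetition. A sketch with plain torch standing in for `comfy.utils.repeat_to_batch_size(..., dim=1)`:

```python
import torch

latent = torch.zeros(1, 4, 64, 64)   # empty placeholder latent
target_channels = 16                 # latent_format.latent_channels

if torch.count_nonzero(latent) == 0 and latent.shape[1] != target_channels:
    reps = -(-target_channels // latent.shape[1])   # ceil division
    latent = latent.repeat(1, reps, 1, 1)[:, :target_channels]
print(latent.shape)                  # torch.Size([1, 16, 64, 64])
```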

View File

@ -1,10 +1,12 @@
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
from .video_types import VideoInput
from .video_types import VideoInput, VideoOp, SliceOp
__all__ = [
"ImageInput",
"AudioInput",
"VideoInput",
"VideoOp",
"SliceOp",
"MaskInput",
"LatentInput",
]

View File

@ -1,11 +1,48 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from fractions import Fraction
from typing import Optional, Union, IO
import copy
import io
import av
from .._util import VideoContainer, VideoCodec, VideoComponents
class VideoOp(ABC):
"""Base class for lazy video operations."""
@abstractmethod
def apply(self, components: VideoComponents) -> VideoComponents:
pass
@abstractmethod
def compute_frame_count(self, input_frame_count: int) -> int:
pass
@dataclass(frozen=True)
class SliceOp(VideoOp):
"""Extract a range of frames from the video."""
start_frame: int
frame_count: int
def apply(self, components: VideoComponents) -> VideoComponents:
total = components.images.shape[0]
start = max(0, min(self.start_frame, total))
end = min(start + self.frame_count, total)
return VideoComponents(
images=components.images[start:end],
audio=components.audio,
frame_rate=components.frame_rate,
metadata=getattr(components, 'metadata', None),
)
def compute_frame_count(self, input_frame_count: int) -> int:
start = max(0, min(self.start_frame, input_frame_count))
return min(self.frame_count, input_frame_count - start)
class VideoInput(ABC):
"""
Abstract base class for video input types.
@ -21,6 +58,12 @@ class VideoInput(ABC):
"""
pass
def sliced(self, start_frame: int, frame_count: int) -> "VideoInput":
"""Return a copy of this video with a slice operation appended."""
new = copy.copy(self)
new._operations = getattr(self, '_operations', []) + [SliceOp(start_frame, frame_count)]
return new
@abstractmethod
def save_to(
self,

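`sliced()` only records intent, so frame counts compose without decoding anything. The clamping arithmetic from `SliceOp.compute_frame_count`, reproduced standalone to show how chained slices behave:

```python
def compute_frame_count(start_frame: int, frame_count: int, input_frames: int) -> int:
    start = max(0, min(start_frame, input_frames))
    return min(frame_count, input_frames - start)

total = 120                                   # source video frames
first = compute_frame_count(10, 100, total)   # 100
second = compute_frame_count(0, 5, first)     # 5
print(first, second)  # video.sliced(10, 100).sliced(0, 5) -> 5 frames
```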
View File

@ -1,7 +1,8 @@
from .video_types import VideoFromFile, VideoFromComponents
from .._input import SliceOp
__all__ = [
# Implementations
"VideoFromFile",
"VideoFromComponents",
"SliceOp",
]

View File

@ -3,7 +3,7 @@ from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
from .._input import AudioInput, VideoInput
from .._input import AudioInput, VideoInput, VideoOp
import av
import io
import json
@ -63,6 +63,8 @@ class VideoFromFile(VideoInput):
containing the file contents.
"""
self.__file = file
self._operations: list[VideoOp] = []
self.__materialized: Optional[VideoFromComponents] = None
def get_stream_source(self) -> str | io.BytesIO:
"""
@ -161,6 +163,10 @@ class VideoFromFile(VideoInput):
if frame_count == 0:
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
# Apply operations to get final frame count
for op in self._operations:
frame_count = op.compute_frame_count(frame_count)
return frame_count
def get_frame_rate(self) -> Fraction:
@ -239,10 +245,18 @@ class VideoFromFile(VideoInput):
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
def get_components(self) -> VideoComponents:
if self.__materialized is not None:
return self.__materialized.get_components()
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
return self.get_components_internal(container)
components = self.get_components_internal(container)
for op in self._operations:
components = op.apply(components)
self.__materialized = VideoFromComponents(components)
self._operations = []
return components
raise ValueError(f"No video stream found in file '{self.__file}'")
def save_to(
@ -317,14 +331,27 @@ class VideoFromComponents(VideoInput):
def __init__(self, components: VideoComponents):
self.__components = components
self._operations: list[VideoOp] = []
def get_components(self) -> VideoComponents:
if self._operations:
components = self.__components
for op in self._operations:
components = op.apply(components)
self.__components = components
self._operations = []
return VideoComponents(
images=self.__components.images,
audio=self.__components.audio,
frame_rate=self.__components.frame_rate
)
def get_frame_count(self) -> int:
count = int(self.__components.images.shape[0])
for op in self._operations:
count = op.compute_frame_count(count)
return count
def save_to(
self,
path: str,
@ -332,6 +359,9 @@ class VideoFromComponents(VideoInput):
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
# Materialize ops before saving
components = self.get_components()
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
@ -345,22 +375,22 @@ class VideoFromComponents(VideoInput):
for key, value in metadata.items():
output.metadata[key] = json.dumps(value)
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
frame_rate = Fraction(round(components.frame_rate * 1000), 1000)
# Create a video stream
video_stream = output.add_stream('h264', rate=frame_rate)
video_stream.width = self.__components.images.shape[2]
video_stream.height = self.__components.images.shape[1]
video_stream.width = components.images.shape[2]
video_stream.height = components.images.shape[1]
video_stream.pix_fmt = 'yuv420p'
# Create an audio stream
audio_sample_rate = 1
audio_stream: Optional[av.AudioStream] = None
if self.__components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate'])
if components.audio:
audio_sample_rate = int(components.audio['sample_rate'])
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
# Encode video
for i, frame in enumerate(self.__components.images):
for i, frame in enumerate(components.images):
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
@ -371,9 +401,9 @@ class VideoFromComponents(VideoInput):
packet = video_stream.encode(None)
output.mux(packet)
if audio_stream and self.__components.audio:
waveform = self.__components.audio['waveform']
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
if audio_stream and components.audio:
waveform = components.audio['waveform']
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * components.images.shape[0])]
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
frame.sample_rate = audio_sample_rate
frame.pts = 0
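
Both implementations share a materialize-once shape: pending ops are applied on first access, the result is cached, and the op list is cleared so later calls are idempotent. A hypothetical minimal object (not the real class) reduced to that pattern:

```python
import copy

class LazyThing:
    def __init__(self, value):
        self._value = value
        self._operations = []

    def mapped(self, fn):
        # Like sliced(): copy the wrapper, append the op, leave self untouched.
        new = copy.copy(self)
        new._operations = self._operations + [fn]
        return new

    def get(self):
        # Like get_components(): apply pending ops once, then clear them.
        for op in self._operations:
            self._value = op(self._value)
        self._operations = []
        return self._value

x = LazyThing(3).mapped(lambda v: v * 2).mapped(lambda v: v + 1)
print(x.get(), x.get())  # 7 7 - ops are not re-applied on the second call
```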

View File

@ -28,7 +28,6 @@ class AlignYourStepsScheduler(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="AlignYourStepsScheduler",
search_aliases=["AYS scheduler"],
category="sampling/custom_sampling/schedulers",
inputs=[
io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]),

View File

@ -71,7 +71,6 @@ class CLIPAttentionMultiply(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="CLIPAttentionMultiply",
search_aliases=["clip attention scale", "text encoder attention"],
category="_for_testing/attention_experiments",
inputs=[
io.Clip.Input("clip"),

View File

@ -69,7 +69,6 @@ class VAEEncodeAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="VAEEncodeAudio",
search_aliases=["audio to latent"],
display_name="VAE Encode Audio",
category="latent/audio",
inputs=[
@ -98,7 +97,6 @@ class VAEDecodeAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="VAEDecodeAudio",
search_aliases=["latent to audio"],
display_name="VAE Decode Audio",
category="latent/audio",
inputs=[
@ -124,7 +122,6 @@ class SaveAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="SaveAudio",
search_aliases=["export flac"],
display_name="Save Audio (FLAC)",
category="audio",
inputs=[
@ -149,7 +146,6 @@ class SaveAudioMP3(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="SaveAudioMP3",
search_aliases=["export mp3"],
display_name="Save Audio (MP3)",
category="audio",
inputs=[
@ -177,7 +173,6 @@ class SaveAudioOpus(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="SaveAudioOpus",
search_aliases=["export opus"],
display_name="Save Audio (Opus)",
category="audio",
inputs=[
@ -205,7 +200,6 @@ class PreviewAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="PreviewAudio",
search_aliases=["play audio"],
display_name="Preview Audio",
category="audio",
inputs=[
@ -265,7 +259,6 @@ class LoadAudio(IO.ComfyNode):
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
return IO.Schema(
node_id="LoadAudio",
search_aliases=["import audio", "open audio", "audio file"],
display_name="Load Audio",
category="audio",
inputs=[
@ -303,7 +296,6 @@ class RecordAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="RecordAudio",
search_aliases=["microphone input", "audio capture", "voice input"],
display_name="Record Audio",
category="audio",
inputs=[
@ -328,7 +320,6 @@ class TrimAudioDuration(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TrimAudioDuration",
search_aliases=["cut audio", "audio clip", "shorten audio"],
display_name="Trim Audio Duration",
description="Trim audio tensor into chosen time range.",
category="audio",
@ -381,7 +372,6 @@ class SplitAudioChannels(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="SplitAudioChannels",
search_aliases=["stereo to mono"],
display_name="Split Audio Channels",
description="Separates the audio into left and right channels.",
category="audio",
@ -482,7 +472,6 @@ class AudioConcat(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="AudioConcat",
search_aliases=["join audio", "combine audio", "append audio"],
display_name="Audio Concat",
description="Concatenates the audio1 to audio2 in the specified direction.",
category="audio",
@ -530,7 +519,6 @@ class AudioMerge(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="AudioMerge",
search_aliases=["mix audio", "overlay audio", "layer audio"],
display_name="Audio Merge",
description="Combine two audio tracks by overlaying their waveforms.",
category="audio",
@ -591,7 +579,6 @@ class AudioAdjustVolume(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="AudioAdjustVolume",
search_aliases=["audio gain", "loudness", "audio level"],
display_name="Audio Adjust Volume",
category="audio",
inputs=[
@ -627,7 +614,6 @@ class EmptyAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="EmptyAudio",
search_aliases=["blank audio"],
display_name="Empty Audio",
category="audio",
inputs=[

View File

@ -10,7 +10,6 @@ class Canny(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="Canny",
search_aliases=["edge detection", "outline", "contour detection", "line art"],
category="image/preprocessors",
inputs=[
io.Image.Input("image"),

View File

@ -109,7 +109,6 @@ class PorterDuffImageComposite(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="PorterDuffImageComposite",
search_aliases=["alpha composite", "blend modes", "layer blend", "transparency blend"],
display_name="Porter-Duff Image Composite",
category="mask/compositing",
inputs=[
@ -166,7 +165,6 @@ class SplitImageWithAlpha(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SplitImageWithAlpha",
search_aliases=["extract alpha", "separate transparency", "remove alpha"],
display_name="Split Image with Alpha",
category="mask/compositing",
inputs=[
@ -190,7 +188,6 @@ class JoinImageWithAlpha(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="JoinImageWithAlpha",
search_aliases=["add transparency", "apply alpha", "composite alpha", "RGBA"],
display_name="Join Image with Alpha",
category="mask/compositing",
inputs=[

View File

@ -38,7 +38,6 @@ class ControlNetInpaintingAliMamaApply(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ControlNetInpaintingAliMamaApply",
search_aliases=["masked controlnet"],
category="conditioning/controlnet",
inputs=[
io.Conditioning.Input("positive"),

View File

@ -297,7 +297,6 @@ class ExtendIntermediateSigmas(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ExtendIntermediateSigmas",
search_aliases=["interpolate sigmas"],
category="sampling/custom_sampling/sigmas",
inputs=[
io.Sigmas.Input("sigmas"),
@ -741,7 +740,7 @@ class SamplerCustom(io.ComfyNode):
latent = latent_image
latent_image = latent["samples"]
latent = latent.copy()
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None))
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
latent["samples"] = latent_image
if not add_noise:
@ -760,7 +759,6 @@ class SamplerCustom(io.ComfyNode):
samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
out = latent.copy()
out.pop("downscale_ratio_spacial", None)
out["samples"] = samples
if "x0" in x0_output:
x0_out = model.model.process_latent_out(x0_output["x0"].cpu())
@ -858,7 +856,6 @@ class DualCFGGuider(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="DualCFGGuider",
search_aliases=["dual prompt guidance"],
category="sampling/custom_sampling/guiders",
inputs=[
io.Model.Input("model"),
@ -886,7 +883,6 @@ class DisableNoise(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="DisableNoise",
search_aliases=["zero noise"],
category="sampling/custom_sampling/noise",
inputs=[],
outputs=[io.Noise.Output()]
@ -940,7 +936,7 @@ class SamplerCustomAdvanced(io.ComfyNode):
latent = latent_image
latent_image = latent["samples"]
latent = latent.copy()
latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image, latent.get("downscale_ratio_spacial", None))
latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image)
latent["samples"] = latent_image
noise_mask = None
@ -955,7 +951,6 @@ class SamplerCustomAdvanced(io.ComfyNode):
samples = samples.to(comfy.model_management.intermediate_device())
out = latent.copy()
out.pop("downscale_ratio_spacial", None)
out["samples"] = samples
if "x0" in x0_output:
x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
@ -1024,7 +1019,6 @@ class ManualSigmas(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ManualSigmas",
search_aliases=["custom noise schedule", "define sigmas"],
category="_for_testing/custom_sampling",
is_experimental=True,
inputs=[

View File

@ -1223,11 +1223,11 @@ class ResolutionBucket(io.ComfyNode):
class MakeTrainingDataset(io.ComfyNode):
"""Encode images with VAE and texts with CLIP to create a training dataset."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MakeTrainingDataset",
search_aliases=["encode dataset"],
display_name="Make Training Dataset",
category="dataset",
is_experimental=True,
@ -1309,11 +1309,11 @@ class MakeTrainingDataset(io.ComfyNode):
class SaveTrainingDataset(io.ComfyNode):
"""Save encoded training dataset (latents + conditioning) to disk."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SaveTrainingDataset",
search_aliases=["export training data"],
display_name="Save Training Dataset",
category="dataset",
is_experimental=True,
@ -1410,11 +1410,11 @@ class SaveTrainingDataset(io.ComfyNode):
class LoadTrainingDataset(io.ComfyNode):
"""Load encoded training dataset from disk."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoadTrainingDataset",
search_aliases=["import dataset", "training data"],
display_name="Load Training Dataset",
category="dataset",
is_experimental=True,

View File

@ -11,7 +11,6 @@ class DifferentialDiffusion(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="DifferentialDiffusion",
search_aliases=["inpaint gradient", "variable denoise strength"],
display_name="Differential Diffusion",
category="_for_testing",
inputs=[

View File

@ -58,7 +58,6 @@ class FreSca(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="FreSca",
search_aliases=["frequency guidance"],
display_name="FreSca",
category="_for_testing",
description="Applies frequency-dependent scaling to the guidance",

View File

@ -38,7 +38,6 @@ class CLIPTextEncodeHiDream(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeHiDream",
search_aliases=["hidream prompt"],
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),

View File

@ -259,7 +259,6 @@ class SetClipHooks:
return (clip,)
class ConditioningTimestepsRange:
SEARCH_ALIASES = ["prompt scheduling", "timestep segments", "conditioning phases"]
NodeId = 'ConditioningTimestepsRange'
NodeName = 'Timesteps Range'
@classmethod
@ -469,7 +468,6 @@ class SetHookKeyframes:
return (hooks,)
class CreateHookKeyframe:
SEARCH_ALIASES = ["hook scheduling", "strength animation", "timed hook"]
NodeId = 'CreateHookKeyframe'
NodeName = 'Create Hook Keyframe'
@classmethod
@ -499,7 +497,6 @@ class CreateHookKeyframe:
return (prev_hook_kf,)
class CreateHookKeyframesInterpolated:
SEARCH_ALIASES = ["ease hook strength", "smooth hook transition", "interpolate keyframes"]
NodeId = 'CreateHookKeyframesInterpolated'
NodeName = 'Create Hook Keyframes Interp.'
@classmethod
@ -547,7 +544,6 @@ class CreateHookKeyframesInterpolated:
return (prev_hook_kf,)
class CreateHookKeyframesFromFloats:
SEARCH_ALIASES = ["batch keyframes", "strength list to keyframes"]
NodeId = 'CreateHookKeyframesFromFloats'
NodeName = 'Create Hook Keyframes From Floats'
@classmethod
@ -622,7 +618,6 @@ class SetModelHooksOnCond:
# Combine Hooks
#------------------------------------------
class CombineHooks:
SEARCH_ALIASES = ["merge hooks"]
NodeId = 'CombineHooks2'
NodeName = 'Combine Hooks [2]'
@classmethod

View File

@ -618,7 +618,6 @@ class SaveGLB(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="SaveGLB",
search_aliases=["export 3d model", "save mesh"],
category="3d",
is_output_node=True,
inputs=[

View File

@ -22,7 +22,6 @@ class ImageCrop(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ImageCrop",
search_aliases=["trim"],
display_name="Image Crop",
category="image/transform",
inputs=[
@ -52,7 +51,6 @@ class RepeatImageBatch(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="RepeatImageBatch",
search_aliases=["duplicate image", "clone image"],
category="image/batch",
inputs=[
IO.Image.Input("image"),
@ -74,7 +72,6 @@ class ImageFromBatch(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ImageFromBatch",
search_aliases=["select image", "pick from batch", "extract image"],
category="image/batch",
inputs=[
IO.Image.Input("image"),
@ -100,7 +97,6 @@ class ImageAddNoise(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ImageAddNoise",
search_aliases=["film grain"],
category="image",
inputs=[
IO.Image.Input("image"),
@ -198,11 +194,11 @@ class SaveAnimatedPNG(IO.ComfyNode):
class ImageStitch(IO.ComfyNode):
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ImageStitch",
search_aliases=["combine images", "join images", "concatenate images", "side by side"],
display_name="Image Stitch",
description="Stitches image2 to image1 in the specified direction.\n"
"If image2 is not provided, returns image1 unchanged.\n"
@ -373,11 +369,11 @@ class ImageStitch(IO.ComfyNode):
class ResizeAndPadImage(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ResizeAndPadImage",
search_aliases=["fit to size"],
category="image/transform",
inputs=[
IO.Image.Input("image"),
@ -424,11 +420,11 @@ class ResizeAndPadImage(IO.ComfyNode):
class SaveSVGNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SaveSVGNode",
search_aliases=["export vector", "save vector graphics"],
description="Save SVG files on disk.",
category="image/save",
inputs=[
@ -496,11 +492,11 @@ class SaveSVGNode(IO.ComfyNode):
class GetImageSize(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GetImageSize",
search_aliases=["dimensions", "resolution", "image info"],
display_name="Get Image Size",
description="Returns width and height of the image, and passes it through unchanged.",
category="image",
@ -531,11 +527,11 @@ class GetImageSize(IO.ComfyNode):
class ImageRotate(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ImageRotate",
search_aliases=["turn", "flip orientation"],
category="image/transform",
inputs=[
IO.Image.Input("image"),
@ -561,11 +557,11 @@ class ImageRotate(IO.ComfyNode):
class ImageFlip(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ImageFlip",
search_aliases=["mirror", "reflect"],
category="image/transform",
inputs=[
IO.Image.Input("image"),

View File

@ -104,7 +104,6 @@ class CLIPTextEncodeKandinsky5(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeKandinsky5",
search_aliases=["kandinsky prompt"],
category="advanced/conditioning/kandinsky5",
inputs=[
io.Clip.Input("clip"),

View File

@ -21,7 +21,6 @@ class LatentAdd(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentAdd",
search_aliases=["combine latents", "sum latents"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
@ -48,7 +47,6 @@ class LatentSubtract(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentSubtract",
search_aliases=["difference latent", "remove features"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
@ -75,7 +73,6 @@ class LatentMultiply(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentMultiply",
search_aliases=["scale latent", "amplify latent", "latent gain"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
@ -99,7 +96,6 @@ class LatentInterpolate(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentInterpolate",
search_aliases=["blend latent", "mix latent", "lerp latent", "transition"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
@ -138,7 +134,6 @@ class LatentConcat(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentConcat",
search_aliases=["join latents", "stitch latents"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
@ -178,7 +173,6 @@ class LatentCut(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentCut",
search_aliases=["crop latent", "slice latent", "extract region"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
@ -219,7 +213,6 @@ class LatentCutToBatch(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentCutToBatch",
search_aliases=["slice to batch", "split latent", "tile latent"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
@ -261,7 +254,6 @@ class LatentBatch(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentBatch",
search_aliases=["combine latents", "merge latents", "join latents"],
category="latent/batch",
is_deprecated=True,
inputs=[
@ -318,7 +310,6 @@ class LatentApplyOperation(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentApplyOperation",
search_aliases=["transform latent"],
category="latent/advanced/operations",
is_experimental=True,
inputs=[
@ -374,7 +365,6 @@ class LatentOperationTonemapReinhard(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentOperationTonemapReinhard",
search_aliases=["hdr latent"],
category="latent/advanced/operations",
is_experimental=True,
inputs=[

View File

@ -75,7 +75,6 @@ class Preview3D(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="Preview3D",
search_aliases=["view mesh", "3d viewer"],
display_name="Preview 3D & Animation",
category="3d",
is_experimental=True,

View File

@ -224,7 +224,6 @@ class ConvertStringToComboNode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ConvertStringToComboNode",
search_aliases=["string to dropdown", "text to combo"],
display_name="Convert String to Combo",
category="logic",
inputs=[io.String.Input("string")],
@ -240,7 +239,6 @@ class InvertBooleanNode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="InvertBooleanNode",
search_aliases=["not", "toggle", "negate", "flip boolean"],
display_name="Invert Boolean",
category="logic",
inputs=[io.Boolean.Input("boolean")],

View File

@ -78,7 +78,6 @@ class LoraSave(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LoraSave",
search_aliases=["export lora"],
display_name="Extract and Save Lora",
category="_for_testing",
inputs=[

View File

@ -79,7 +79,6 @@ class CLIPTextEncodeLumina2(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeLumina2",
search_aliases=["lumina prompt"],
display_name="CLIP Text Encode for Lumina2",
category="conditioning",
description="Encodes a system prompt and a user prompt using a CLIP model into an embedding "

View File

@ -50,7 +50,6 @@ class LatentCompositeMasked(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="LatentCompositeMasked",
search_aliases=["overlay latent", "layer latent", "paste latent", "inpaint latent"],
category="latent",
inputs=[
IO.Latent.Input("destination"),
@ -79,7 +78,6 @@ class ImageCompositeMasked(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ImageCompositeMasked",
search_aliases=["paste image", "overlay", "layer"],
category="image",
inputs=[
IO.Image.Input("destination"),
@ -107,7 +105,6 @@ class MaskToImage(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="MaskToImage",
search_aliases=["convert mask"],
display_name="Convert Mask to Image",
category="mask",
inputs=[
@ -129,7 +126,6 @@ class ImageToMask(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ImageToMask",
search_aliases=["extract channel", "channel to mask"],
display_name="Convert Image to Mask",
category="mask",
inputs=[
@ -153,7 +149,6 @@ class ImageColorToMask(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ImageColorToMask",
search_aliases=["color keying", "chroma key"],
category="mask",
inputs=[
IO.Image.Input("image"),
@ -199,7 +194,6 @@ class InvertMask(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="InvertMask",
search_aliases=["reverse mask", "flip mask"],
category="mask",
inputs=[
IO.Mask.Input("mask"),
@ -220,7 +214,6 @@ class CropMask(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="CropMask",
search_aliases=["cut mask", "extract mask region", "mask slice"],
category="mask",
inputs=[
IO.Mask.Input("mask"),
@ -246,7 +239,6 @@ class MaskComposite(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="MaskComposite",
search_aliases=["combine masks", "blend masks", "layer masks"],
category="mask",
inputs=[
IO.Mask.Input("destination"),
@ -295,7 +287,6 @@ class FeatherMask(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="FeatherMask",
search_aliases=["soft edge mask", "blur mask edges", "gradient mask edge"],
category="mask",
inputs=[
IO.Mask.Input("mask"),
@ -342,7 +333,6 @@ class GrowMask(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="GrowMask",
search_aliases=["expand mask", "shrink mask"],
display_name="Grow Mask",
category="mask",
inputs=[
@ -380,7 +370,6 @@ class ThresholdMask(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ThresholdMask",
search_aliases=["binary mask"],
category="mask",
inputs=[
IO.Mask.Input("mask"),
@ -405,7 +394,6 @@ class MaskPreview(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="MaskPreview",
search_aliases=["show mask", "view mask", "inspect mask", "debug mask"],
display_name="Preview Mask",
category="mask",
description="Saves the input images to your ComfyUI output directory.",

View File

@ -299,7 +299,6 @@ class RescaleCFG:
return (m, )
class ModelComputeDtype:
SEARCH_ALIASES = ["model precision", "change dtype"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),

View File

@ -91,7 +91,6 @@ class CLIPMergeSimple:
class CLIPSubtract:
SEARCH_ALIASES = ["clip difference", "text encoder subtract"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
@ -114,7 +113,6 @@ class CLIPSubtract:
class CLIPAdd:
SEARCH_ALIASES = ["combine clip"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
@ -227,7 +225,6 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)
class CheckpointSave:
SEARCH_ALIASES = ["save model", "export checkpoint", "merge save"]
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@ -340,7 +337,6 @@ class VAESave:
return {}
class ModelSave:
SEARCH_ALIASES = ["export model", "checkpoint save"]
def __init__(self):
self.output_dir = folder_paths.get_output_directory()

View File

@ -12,7 +12,6 @@ class Morphology(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="Morphology",
search_aliases=["erode", "dilate"],
display_name="ImageMorphology",
category="image/postprocessing",
inputs=[
@ -58,7 +57,6 @@ class ImageRGBToYUV(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ImageRGBToYUV",
search_aliases=["color space conversion"],
category="image/batch",
inputs=[
io.Image.Input("image"),
@ -80,7 +78,6 @@ class ImageYUVToRGB(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ImageYUVToRGB",
search_aliases=["color space conversion"],
category="image/batch",
inputs=[
io.Image.Input("Y"),

View File

@ -7,7 +7,6 @@ class CLIPTextEncodePixArtAlpha(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodePixArtAlpha",
search_aliases=["pixart prompt"],
category="advanced/conditioning",
description="Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma.",
inputs=[

View File

@ -402,6 +402,7 @@ def scale_to_multiple_cover(input: torch.Tensor, multiple: int, scale_method: st
return input[:, y0:y1, x0:x1]
class ResizeImageMaskNode(io.ComfyNode):
scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
crop_methods = ["disabled", "center"]
@ -420,62 +421,46 @@ class ResizeImageMaskNode(io.ComfyNode):
@classmethod
def define_schema(cls):
template = io.MatchType.Template("input_type", [io.Image, io.Mask])
crop_combo = io.Combo.Input(
"crop",
options=cls.crop_methods,
default="center",
tooltip="How to handle aspect ratio mismatch: 'disabled' stretches to fit, 'center' crops to maintain aspect ratio.",
)
crop_combo = io.Combo.Input("crop", options=cls.crop_methods, default="center")
return io.Schema(
node_id="ResizeImageMaskNode",
display_name="Resize Image/Mask",
description="Resize an image or mask using various scaling methods.",
category="transform",
search_aliases=["resize", "resize image", "resize mask", "scale", "scale image", "scale mask", "image resize", "change size", "dimensions", "shrink", "enlarge"],
inputs=[
io.MatchType.Input("input", template=template),
io.DynamicCombo.Input(
"resize_type",
tooltip="Select how to resize: by exact dimensions, scale factor, matching another image, etc.",
options=[
io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target width in pixels. Set to 0 to auto-calculate from height while preserving aspect ratio."),
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target height in pixels. Set to 0 to auto-calculate from width while preserving aspect ratio."),
crop_combo,
io.DynamicCombo.Input("resize_type", options=[
io.DynamicCombo.Option(ResizeType.SCALE_BY, [
io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01),
]),
io.DynamicCombo.Option(ResizeType.SCALE_BY, [
io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01, tooltip="Scale factor (e.g., 2.0 doubles size, 0.5 halves size)."),
io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1),
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1),
crop_combo,
]),
io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [
io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="The longer edge will be resized to this value. Aspect ratio is preserved."),
io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [
io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
]),
io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [
io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="The shorter edge will be resized to this value. Aspect ratio is preserved."),
io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [
io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
]),
io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target width in pixels. Height auto-adjusts to preserve aspect ratio."),
io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1),
]),
io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target height in pixels. Width auto-adjusts to preserve aspect ratio."),
io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1),
]),
io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01, tooltip="Target total megapixels (e.g., 1.0 ≈ 1024×1024). Aspect ratio is preserved."),
io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
]),
io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [
io.MultiType.Input("match", [io.Image, io.Mask], tooltip="Resize input to match the dimensions of this reference image or mask."),
crop_combo,
io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [
io.MultiType.Input("match", [io.Image, io.Mask]),
crop_combo,
]),
io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [
io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1, tooltip="Resize so width and height are divisible by this number. Useful for latent alignment (e.g., 8 or 64)."),
io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [
io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1),
]),
],
),
io.Combo.Input(
"scale_method",
options=cls.scale_methods,
default="area",
tooltip="Interpolation algorithm. 'area' is best for downscaling, 'lanczos' for upscaling, 'nearest-exact' for pixel art.",
),
]),
io.Combo.Input("scale_method", options=cls.scale_methods, default="area"),
],
outputs=[io.MatchType.Output(template=template, display_name="resized")]
)
@ -584,7 +569,6 @@ class BatchMasksNode(io.ComfyNode):
autogrow_template = io.Autogrow.TemplatePrefix(io.Mask.Input("mask"), prefix="mask", min=2, max=50)
return io.Schema(
node_id="BatchMasksNode",
search_aliases=["combine masks", "stack masks", "merge masks"],
display_name="Batch Masks",
category="mask",
inputs=[
@ -605,7 +589,6 @@ class BatchLatentsNode(io.ComfyNode):
autogrow_template = io.Autogrow.TemplatePrefix(io.Latent.Input("latent"), prefix="latent", min=2, max=50)
return io.Schema(
node_id="BatchLatentsNode",
search_aliases=["combine latents", "stack latents", "merge latents"],
display_name="Batch Latents",
category="latent",
inputs=[
@ -629,7 +612,6 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
prefix="input", min=1, max=50)
return io.Schema(
node_id="BatchImagesMasksLatentsNode",
search_aliases=["combine batch", "merge batch", "stack inputs"],
display_name="Batch Images/Masks/Latents",
category="util",
inputs=[

View File

@ -16,7 +16,7 @@ class PreviewAny():
OUTPUT_NODE = True
CATEGORY = "utils"
SEARCH_ALIASES = ["show output", "inspect", "debug", "print value", "show text"]
SEARCH_ALIASES = ["preview", "show", "display", "view", "show text", "display text", "preview text", "show output", "inspect", "debug"]
def main(self, source=None):
value = 'None'

View File

@ -55,7 +55,7 @@ class EmptySD3LatentImage(io.ComfyNode):
@classmethod
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8})
return io.NodeOutput({"samples":latent})
generate = execute # TODO: remove
@ -65,7 +65,6 @@ class CLIPTextEncodeSD3(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeSD3",
search_aliases=["sd3 prompt"],
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),

View File

@ -32,7 +32,6 @@ class StringSubstring(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StringSubstring",
search_aliases=["extract text", "text portion"],
display_name="Substring",
category="utils/string",
inputs=[
@ -55,7 +54,6 @@ class StringLength(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StringLength",
search_aliases=["character count", "text size"],
display_name="Length",
category="utils/string",
inputs=[
@ -76,7 +74,6 @@ class CaseConverter(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="CaseConverter",
search_aliases=["text case", "uppercase", "lowercase", "capitalize"],
display_name="Case Converter",
category="utils/string",
inputs=[
@ -109,7 +106,6 @@ class StringTrim(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StringTrim",
search_aliases=["clean whitespace", "remove whitespace"],
display_name="Trim",
category="utils/string",
inputs=[
@ -140,7 +136,6 @@ class StringReplace(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StringReplace",
search_aliases=["find and replace", "substitute", "swap text"],
display_name="Replace",
category="utils/string",
inputs=[
@ -163,7 +158,6 @@ class StringContains(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StringContains",
search_aliases=["text includes", "string includes"],
display_name="Contains",
category="utils/string",
inputs=[
@ -191,7 +185,6 @@ class StringCompare(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StringCompare",
search_aliases=["text match", "string equals", "starts with", "ends with"],
display_name="Compare",
category="utils/string",
inputs=[
@ -227,7 +220,6 @@ class RegexMatch(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="RegexMatch",
search_aliases=["pattern match", "text contains", "string match"],
display_name="Regex Match",
category="utils/string",
inputs=[
@ -268,7 +260,6 @@ class RegexExtract(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="RegexExtract",
search_aliases=["pattern extract", "text parser", "parse text"],
display_name="Regex Extract",
category="utils/string",
inputs=[
@ -343,7 +334,6 @@ class RegexReplace(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="RegexReplace",
search_aliases=["pattern replace", "find and replace", "substitution"],
display_name="Regex Replace",
category="utils/string",
description="Find and replace text using regex patterns.",

View File

@ -1101,7 +1101,6 @@ class SaveLoRA(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SaveLoRA",
search_aliases=["export lora"],
display_name="Save LoRA Weights",
category="loaders",
is_experimental=True,
@ -1145,7 +1144,6 @@ class LossGraphNode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LossGraphNode",
search_aliases=["training chart", "training visualization", "plot loss"],
display_name="Plot Loss Graph",
category="training",
is_experimental=True,

View File

@ -16,7 +16,6 @@ class SaveWEBM(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SaveWEBM",
search_aliases=["export webm"],
category="image/video",
is_experimental=True,
inputs=[
@ -70,7 +69,6 @@ class SaveVideo(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SaveVideo",
search_aliases=["export video"],
display_name="Save Video",
category="image/video",
description="Saves the input images to your ComfyUI output directory.",
@ -118,7 +116,6 @@ class CreateVideo(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="CreateVideo",
search_aliases=["images to video"],
display_name="Create Video",
category="image/video",
description="Create a video from images.",
@ -143,7 +140,6 @@ class GetVideoComponents(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="GetVideoComponents",
search_aliases=["extract frames", "split video", "video to images", "demux"],
display_name="Get Video Components",
category="image/video",
description="Extracts all components from a video: frames, audio, and framerate.",
@ -163,6 +159,29 @@ class GetVideoComponents(io.ComfyNode):
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
class VideoSlice(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VideoSlice",
display_name="Video Slice",
category="image/video",
description="Extract a range of frames from a video.",
inputs=[
io.Video.Input("video", tooltip="The video to slice."),
io.Int.Input("start_frame", default=0, min=0, tooltip="The frame index to start from (0-indexed)."),
io.Int.Input("frame_count", default=1, min=1, tooltip="Number of frames to extract."),
],
outputs=[
io.Video.Output(tooltip="The sliced video."),
],
)
@classmethod
def execute(cls, video: Input.Video, start_frame: int, frame_count: int) -> io.NodeOutput:
return io.NodeOutput(video.sliced(start_frame, frame_count))
class LoadVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -171,7 +190,6 @@ class LoadVideo(io.ComfyNode):
files = folder_paths.filter_files_content_types(files, ["video"])
return io.Schema(
node_id="LoadVideo",
search_aliases=["import video", "open video", "video file"],
display_name="Load Video",
category="image/video",
inputs=[
@ -211,6 +229,7 @@ class VideoExtension(ComfyExtension):
SaveVideo,
CreateVideo,
GetVideoComponents,
VideoSlice,
LoadVideo,
]
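
Note: the VideoSlice node above does not copy any frames itself; it only asks the video object for a lazily-sliced copy. A minimal sketch of the machinery this implies, grounded in the new test file at the end of this diff (the class name `LazyVideo` and the exact copy details are assumptions; `SliceOp.apply`, `compute_frame_count`, `sliced()`, and `_operations` mirror what the tests exercise):

```python
import copy
from dataclasses import dataclass


@dataclass
class SliceOp:
    """A deferred frame-range selection, applied only when the video materializes."""
    start_frame: int
    frame_count: int

    def apply(self, components):
        # components.images is an (N, H, W, C) frame tensor; audio handling is
        # omitted in this sketch.
        out = copy.copy(components)
        out.images = components.images[self.start_frame:self.start_frame + self.frame_count]
        return out

    def compute_frame_count(self, total_frames: int) -> int:
        # Clamp, so a slice reaching past the last frame reports only what exists:
        # SliceOp(8, 5) over 10 frames yields 2, as the tests assert.
        start = min(self.start_frame, total_frames)
        return max(0, min(self.frame_count, total_frames - start))


class LazyVideo:
    """Stand-in for the copy-on-write behaviour VideoInput.sliced() shows in the tests."""

    def __init__(self):
        self._operations: list[SliceOp] = []

    def sliced(self, start_frame: int, frame_count: int) -> "LazyVideo":
        # Return a copy with one more pending op; the receiver's list is untouched,
        # which is what test_operations_list_is_immutable asserts.
        clone = copy.copy(self)
        clone._operations = [*self._operations, SliceOp(start_frame, frame_count)]
        return clone
```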

View File

@ -287,7 +287,6 @@ class WanVaceToVideo(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="WanVaceToVideo",
search_aliases=["video conditioning", "video control"],
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
@ -706,7 +705,6 @@ class WanTrackToVideo(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="WanTrackToVideo",
search_aliases=["motion tracking", "trajectory video", "point tracking", "keypoint animation"],
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),

View File

@ -324,7 +324,6 @@ class GenerateTracks(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="GenerateTracks",
search_aliases=["motion paths", "camera movement", "trajectory"],
category="conditioning/video_models",
inputs=[
io.Int.Input("width", default=832, min=16, max=4096, step=16),

View File

@ -5,7 +5,6 @@ MAX_RESOLUTION = nodes.MAX_RESOLUTION
class WebcamCapture(nodes.LoadImage):
SEARCH_ALIASES = ["camera input", "live capture", "camera feed", "snapshot"]
@classmethod
def INPUT_TYPES(s):
return {

View File

@ -326,7 +326,7 @@ def setup_database():
if dependencies_available():
init_db()
if not args.disable_assets_autoscan:
seed_assets(["models", "input", "output"], enable_logging=True)
seed_assets(["models"], enable_logging=True)
except Exception as e:
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
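
This narrows the startup autoscan to the models root only (the same change is applied to the `/object_info` route below). If seeding the input/output roots is still wanted, it could presumably be invoked explicitly; a sketch, assuming `seed_assets` is importable from the assets package (the exact module path is not shown in this diff, and the signature is taken from the call above):

```python
# Hypothetical manual re-seed of the roots dropped from autoscan.
from app.assets.scanner import seed_assets  # import path is an assumption

seed_assets(["input", "output"], enable_logging=True)
```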

View File

@ -93,8 +93,6 @@ class ConditioningCombine:
return (conditioning_1 + conditioning_2, )
class ConditioningAverage :
SEARCH_ALIASES = ["blend prompts", "interpolate conditioning", "mix prompts", "style fusion", "weighted blend"]
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", ),
@ -161,8 +159,6 @@ class ConditioningConcat:
return (out, )
class ConditioningSetArea:
SEARCH_ALIASES = ["regional prompt", "area prompt", "spatial conditioning", "localized prompt"]
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
@ -221,8 +217,6 @@ class ConditioningSetAreaStrength:
class ConditioningSetMask:
SEARCH_ALIASES = ["masked prompt", "regional inpaint conditioning", "mask conditioning"]
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
@ -248,8 +242,6 @@ class ConditioningSetMask:
return (c, )
class ConditioningZeroOut:
SEARCH_ALIASES = ["null conditioning", "clear conditioning"]
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", )}}
@ -475,8 +467,6 @@ class InpaintModelConditioning:
class SaveLatent:
SEARCH_ALIASES = ["export latent"]
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@ -528,8 +518,6 @@ class SaveLatent:
class LoadLatent:
SEARCH_ALIASES = ["import latent", "open latent"]
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
@ -566,8 +554,6 @@ class LoadLatent:
class CheckpointLoader:
SEARCH_ALIASES = ["load model", "model loader"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ),
@ -607,8 +593,6 @@ class CheckpointLoaderSimple:
return out[:3]
class DiffusersLoader:
SEARCH_ALIASES = ["load diffusers model"]
@classmethod
def INPUT_TYPES(cls):
paths = []
@ -1079,8 +1063,6 @@ class StyleModelLoader:
class StyleModelApply:
SEARCH_ALIASES = ["style transfer"]
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
@ -1230,12 +1212,10 @@ class EmptyLatentImage:
def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
return ({"samples": latent, "downscale_ratio_spacial": 8}, )
return ({"samples":latent}, )
class LatentFromBatch:
SEARCH_ALIASES = ["select from batch", "pick latent", "batch subset"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
@ -1268,8 +1248,6 @@ class LatentFromBatch:
return (s,)
class RepeatLatentBatch:
SEARCH_ALIASES = ["duplicate latent", "clone latent"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
@ -1296,8 +1274,6 @@ class RepeatLatentBatch:
return (s,)
class LatentUpscale:
SEARCH_ALIASES = ["enlarge latent", "resize latent"]
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
crop_methods = ["disabled", "center"]
@ -1332,8 +1308,6 @@ class LatentUpscale:
return (s,)
class LatentUpscaleBy:
SEARCH_ALIASES = ["enlarge latent", "resize latent", "scale latent"]
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
@classmethod
@ -1377,8 +1351,6 @@ class LatentRotate:
return (s,)
class LatentFlip:
SEARCH_ALIASES = ["mirror latent"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
@ -1399,8 +1371,6 @@ class LatentFlip:
return (s,)
class LatentComposite:
SEARCH_ALIASES = ["overlay latent", "layer latent", "paste latent"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples_to": ("LATENT",),
@ -1443,8 +1413,6 @@ class LatentComposite:
return (samples_out,)
class LatentBlend:
SEARCH_ALIASES = ["mix latents", "interpolate latents"]
@classmethod
def INPUT_TYPES(s):
return {"required": {
@ -1486,8 +1454,6 @@ class LatentBlend:
raise ValueError(f"Unsupported blend mode: {mode}")
class LatentCrop:
SEARCH_ALIASES = ["trim latent", "cut latent"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
@ -1538,7 +1504,7 @@ class SetLatentNoiseMask:
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
latent_image = latent["samples"]
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None))
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
if disable_noise:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
@ -1556,7 +1522,6 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
out = latent.copy()
out.pop("downscale_ratio_spacial", None)
out["samples"] = samples
return (out, )
@ -1774,8 +1739,6 @@ class LoadImage:
return True
class LoadImageMask:
SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]
_color_channels = ["alpha", "red", "green", "blue"]
@classmethod
def INPUT_TYPES(s):
@ -1826,8 +1789,6 @@ class LoadImageMask:
class LoadImageOutput(LoadImage):
SEARCH_ALIASES = ["output image", "previous generation"]
@classmethod
def INPUT_TYPES(s):
return {
@ -1901,7 +1862,6 @@ class ImageScaleBy:
return (s,)
class ImageInvert:
SEARCH_ALIASES = ["reverse colors"]
@classmethod
def INPUT_TYPES(s):
@ -1917,7 +1877,6 @@ class ImageInvert:
return (s,)
class ImageBatch:
SEARCH_ALIASES = ["combine images", "merge images", "stack images"]
@classmethod
def INPUT_TYPES(s):
@ -1963,7 +1922,6 @@ class EmptyImage:
return (torch.cat((r, g, b), dim=-1), )
class ImagePadForOutpaint:
SEARCH_ALIASES = ["extend canvas", "expand image"]
@classmethod
def INPUT_TYPES(s):

View File

@ -22,7 +22,6 @@ alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.7
blake3
#non essential dependencies:
kornia>=0.7.1

View File

@ -689,7 +689,7 @@ class PromptServer():
@routes.get("/object_info")
async def get_object_info(request):
try:
seed_assets(["models", "input", "output"])
seed_assets(["models"])
except Exception as e:
logging.error(f"Failed to seed assets: {e}")
with folder_paths.cache_helper:

View File

@ -1,177 +0,0 @@
"""Tests for Assets API endpoints (app/assets/api/routes.py)
Tests cover:
- Schema validation for query parameters and request bodies
"""
import pytest
from pydantic import ValidationError
from app.assets.api import schemas_in, schemas_out
class TestListAssetsQuery:
"""Tests for ListAssetsQuery schema."""
def test_defaults(self):
"""Test default values."""
q = schemas_in.ListAssetsQuery()
assert q.include_tags == []
assert q.exclude_tags == []
assert q.limit == 20
assert q.offset == 0
assert q.sort == "created_at"
assert q.order == "desc"
assert q.include_public == True
def test_include_public_false(self):
"""Test include_public can be set to False."""
q = schemas_in.ListAssetsQuery(include_public=False)
assert q.include_public == False
def test_csv_tags_parsing(self):
"""Test comma-separated tags are parsed correctly."""
q = schemas_in.ListAssetsQuery.model_validate({"include_tags": "a,b,c"})
assert q.include_tags == ["a", "b", "c"]
def test_metadata_filter_json_string(self):
"""Test metadata_filter accepts JSON string."""
q = schemas_in.ListAssetsQuery.model_validate({"metadata_filter": '{"key": "value"}'})
assert q.metadata_filter == {"key": "value"}
class TestTagsListQuery:
"""Tests for TagsListQuery schema."""
def test_defaults(self):
"""Test default values."""
q = schemas_in.TagsListQuery()
assert q.prefix is None
assert q.limit == 100
assert q.offset == 0
assert q.order == "count_desc"
assert q.include_zero == True
assert q.include_public == True
def test_include_public_false(self):
"""Test include_public can be set to False."""
q = schemas_in.TagsListQuery(include_public=False)
assert q.include_public == False
class TestUpdateAssetBody:
"""Tests for UpdateAssetBody schema."""
def test_requires_at_least_one_field(self):
"""Test that at least one field is required."""
with pytest.raises(ValidationError):
schemas_in.UpdateAssetBody()
def test_name_only(self):
"""Test updating name only."""
body = schemas_in.UpdateAssetBody(name="new name")
assert body.name == "new name"
assert body.mime_type is None
assert body.preview_id is None
def test_mime_type_only(self):
"""Test updating mime_type only."""
body = schemas_in.UpdateAssetBody(mime_type="image/png")
assert body.mime_type == "image/png"
def test_preview_id_only(self):
"""Test updating preview_id only."""
body = schemas_in.UpdateAssetBody(preview_id="550e8400-e29b-41d4-a716-446655440000")
assert body.preview_id == "550e8400-e29b-41d4-a716-446655440000"
def test_preview_id_invalid_uuid(self):
"""Test invalid UUID for preview_id."""
with pytest.raises(ValidationError):
schemas_in.UpdateAssetBody(preview_id="not-a-uuid")
def test_all_fields(self):
"""Test all fields together."""
body = schemas_in.UpdateAssetBody(
name="test",
mime_type="application/json",
preview_id="550e8400-e29b-41d4-a716-446655440000",
user_metadata={"key": "value"}
)
assert body.name == "test"
assert body.mime_type == "application/json"
class TestUploadAssetFromUrlBody:
"""Tests for UploadAssetFromUrlBody schema (JSON URL upload)."""
def test_valid_url(self):
"""Test valid HTTP URL."""
body = schemas_in.UploadAssetFromUrlBody(
url="https://example.com/model.safetensors",
name="model.safetensors"
)
assert body.url == "https://example.com/model.safetensors"
assert body.name == "model.safetensors"
def test_http_url(self):
"""Test HTTP URL (not just HTTPS)."""
body = schemas_in.UploadAssetFromUrlBody(
url="http://example.com/file.bin",
name="file.bin"
)
assert body.url == "http://example.com/file.bin"
def test_invalid_url_scheme(self):
"""Test invalid URL scheme raises error."""
with pytest.raises(ValidationError):
schemas_in.UploadAssetFromUrlBody(
url="ftp://example.com/file.bin",
name="file.bin"
)
def test_tags_normalized(self):
"""Test tags are normalized to lowercase."""
body = schemas_in.UploadAssetFromUrlBody(
url="https://example.com/model.safetensors",
name="model",
tags=["Models", "LORAS"]
)
assert body.tags == ["models", "loras"]
class TestTagsRefineQuery:
"""Tests for TagsRefineQuery schema."""
def test_defaults(self):
"""Test default values."""
q = schemas_in.TagsRefineQuery()
assert q.include_tags == []
assert q.exclude_tags == []
assert q.limit == 100
assert q.include_public == True
def test_include_public_false(self):
"""Test include_public can be set to False."""
q = schemas_in.TagsRefineQuery(include_public=False)
assert q.include_public == False
class TestTagHistogramResponse:
"""Tests for TagHistogramResponse schema."""
def test_empty_response(self):
"""Test empty response."""
resp = schemas_out.TagHistogramResponse()
assert resp.tags == []
def test_with_entries(self):
"""Test response with entries."""
resp = schemas_out.TagHistogramResponse(
tags=[
schemas_out.TagHistogramEntry(name="models", count=10),
schemas_out.TagHistogramEntry(name="loras", count=5),
]
)
assert len(resp.tags) == 2
assert resp.tags[0].name == "models"
assert resp.tags[0].count == 10

View File

@ -0,0 +1,150 @@
import pytest
import torch
import tempfile
import os
import av
from fractions import Fraction
from comfy_api.input_impl.video_types import (
VideoFromFile,
VideoFromComponents,
SliceOp,
)
from comfy_api.util.video_types import VideoComponents
def create_test_video(width=4, height=4, frames=10, fps=30):
"""Helper to create a temporary video file."""
tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
with av.open(tmp.name, mode="w") as container:
stream = container.add_stream("h264", rate=fps)
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
for i in range(frames):
frame_data = torch.ones(height, width, 3, dtype=torch.uint8) * (i * 25)
frame = av.VideoFrame.from_ndarray(frame_data.numpy(), format="rgb24")
frame = frame.reformat(format="yuv420p")
packet = stream.encode(frame)
container.mux(packet)
packet = stream.encode(None)
container.mux(packet)
return tmp.name
@pytest.fixture
def video_file_10_frames():
file_path = create_test_video(frames=10)
yield file_path
os.unlink(file_path)
@pytest.fixture
def video_components_10_frames():
images = torch.rand(10, 4, 4, 3)
return VideoComponents(images=images, frame_rate=Fraction(30))
class TestSliceOp:
def test_apply_slices_correctly(self, video_components_10_frames):
op = SliceOp(start_frame=2, frame_count=3)
result = op.apply(video_components_10_frames)
assert result.images.shape[0] == 3
assert torch.equal(result.images, video_components_10_frames.images[2:5])
def test_compute_frame_count(self):
op = SliceOp(start_frame=2, frame_count=5)
assert op.compute_frame_count(10) == 5
def test_compute_frame_count_clamps(self):
op = SliceOp(start_frame=8, frame_count=5)
assert op.compute_frame_count(10) == 2
class TestVideoSliced:
def test_sliced_returns_new_instance(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
assert video is not sliced
assert len(video._operations) == 0
assert len(sliced._operations) == 1
def test_get_components_applies_operations(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
components = sliced.get_components()
assert components.images.shape[0] == 3
assert torch.equal(components.images, video_components_10_frames.images[2:5])
def test_get_frame_count(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
assert sliced.get_frame_count() == 3
def test_get_duration(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(0, 3)
assert sliced.get_duration() == pytest.approx(0.1)
def test_chained_slices_compose(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 6).sliced(1, 3)
components = sliced.get_components()
assert components.images.shape[0] == 3
assert torch.equal(components.images, video_components_10_frames.images[3:6])
def test_operations_list_is_immutable(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced1 = video.sliced(0, 5)
sliced2 = sliced1.sliced(1, 2)
assert len(video._operations) == 0
assert len(sliced1._operations) == 1
assert len(sliced2._operations) == 2
def test_from_file(self, video_file_10_frames):
video = VideoFromFile(video_file_10_frames)
sliced = video.sliced(2, 3)
components = sliced.get_components()
assert components.images.shape[0] == 3
assert sliced.get_frame_count() == 3
def test_save_sliced_video(self, video_components_10_frames, tmp_path):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
output_path = str(tmp_path / "sliced_output.mp4")
sliced.save_to(output_path)
saved_video = VideoFromFile(output_path)
assert saved_video.get_frame_count() == 3
def test_materialization_clears_ops(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
assert len(sliced._operations) == 1
sliced.get_components()
assert len(sliced._operations) == 0
def test_second_get_components_uses_cache(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
first = sliced.get_components()
second = sliced.get_components()
assert first.images.shape == second.images.shape
assert torch.equal(first.images, second.images)
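
Taken together, these tests pin down the intended lifecycle: `sliced()` is cheap and non-destructive, `get_frame_count()` accounts for clamping without decoding, and `get_components()` materializes once and then serves cached components. A quick sanity-check script along those lines (frame counts and sizes are illustrative; the imports match this test file):

```python
from fractions import Fraction

import torch
from comfy_api.input_impl.video_types import VideoFromComponents
from comfy_api.util.video_types import VideoComponents

# 90 random frames at 30 fps -> a 3-second clip.
video = VideoFromComponents(
    VideoComponents(images=torch.rand(90, 64, 64, 3), frame_rate=Fraction(30))
)

clip = video.sliced(30, 30)             # lazy: records a SliceOp, copies nothing
print(clip.get_frame_count())           # 30
print(clip.get_duration())              # 1.0 (seconds, at 30 fps)
frames = clip.get_components().images   # materializes; ops applied, then cached
print(tuple(frames.shape))              # (30, 64, 64, 3)
```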