mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-29 00:06:18 +08:00
Compare commits
1 Commits
assets-api
...
cb/video-s
| Author | SHA1 | Date | |
|---|---|---|---|
| e4f3d335dc |
1
.gitignore
vendored
1
.gitignore
vendored
@ -21,7 +21,6 @@ venv/
|
||||
*.log
|
||||
web_custom_versions/
|
||||
.DS_Store
|
||||
*:Zone.Identifier
|
||||
openapi.yaml
|
||||
filtered-openapi.yaml
|
||||
uv.lock
|
||||
|
||||
139
PLAN.md
139
PLAN.md
@ -1,139 +0,0 @@
|
||||
# Plan: Align Local Asset/Tag Endpoints with Cloud
|
||||
|
||||
## Endpoint Comparison
|
||||
|
||||
| Endpoint | Cloud (openapi.yaml) | Local (routes.py) |
|
||||
|----------|---------------------|-------------------|
|
||||
| `GET /api/assets` | ✅ + `include_public` param | ✅ |
|
||||
| `POST /api/assets` | ✅ multipart + JSON URL upload | ✅ multipart only |
|
||||
| `GET /api/assets/{id}` | ✅ | ✅ |
|
||||
| `PUT /api/assets/{id}` | ✅ (`name`, `mime_type`, `preview_id`, `user_metadata`) | ✅ (`name`, `tags`, `user_metadata`) |
|
||||
| `DELETE /api/assets/{id}` | ✅ | ✅ |
|
||||
| `GET /api/assets/{id}/content` | ❌ | ✅ |
|
||||
| `POST /api/assets/{id}/tags` | ✅ | ✅ |
|
||||
| `DELETE /api/assets/{id}/tags` | ✅ | ✅ |
|
||||
| `PUT /api/assets/{id}/preview` | ❌ | ✅ |
|
||||
| `POST /api/assets/from-hash` | ✅ | ✅ |
|
||||
| `HEAD /api/assets/hash/{hash}` | ✅ | ✅ |
|
||||
| `GET /api/assets/remote-metadata` | ✅ | ❌ |
|
||||
| `POST /api/assets/download` | ✅ (background download) | ❌ |
|
||||
| `GET /api/assets/tags/refine` | ✅ (tag histogram) | ❌ |
|
||||
| `GET /api/tags` | ✅ + `include_public` param | ✅ |
|
||||
| `POST /api/assets/scan/seed` | ❌ | ✅ (local only) |
|
||||
|
||||
---
|
||||
|
||||
## Phase 1: Add Missing Cloud Endpoints to Local
|
||||
|
||||
### 1.1 `GET /api/assets/remote-metadata` *(deferred)*
|
||||
Fetch metadata from remote URLs (CivitAI, HuggingFace) without downloading the file.
|
||||
|
||||
**Status:** Not supported yet. Add stub/placeholder that returns 501 Not Implemented.
|
||||
|
||||
**Parameters:**
|
||||
- `url` (required): Download URL to retrieve metadata from
|
||||
|
||||
**Returns:** Asset metadata (name, size, hash if available, etc.)
|
||||
|
||||
### 1.2 `POST /api/assets/download` *(deferred)*
|
||||
Initiate background download job for large files from HuggingFace or CivitAI.
|
||||
|
||||
**Status:** Not supported yet. Add stub/placeholder that returns 501 Not Implemented.
|
||||
|
||||
**Request body:**
|
||||
- `source_url` (required): URL to download from
|
||||
- `tags`: Optional tags for the asset
|
||||
- `user_metadata`: Optional metadata
|
||||
- `preview_id`: Optional preview asset ID
|
||||
|
||||
**Returns:**
|
||||
- 200 if file already exists (returns asset immediately)
|
||||
- 202 with `task_id` for background download tracking via `GET /api/tasks/{task_id}`
|
||||
|
||||
### 1.3 `GET /api/assets/tags/refine`
|
||||
Get tag histogram for filtered assets (useful for search refinement UI).
|
||||
|
||||
**Parameters:**
|
||||
- `include_tags`: Filter assets with ALL these tags
|
||||
- `exclude_tags`: Exclude assets with ANY of these tags
|
||||
- `name_contains`: Filter by name substring
|
||||
- `metadata_filter`: JSON filter for metadata fields
|
||||
- `limit`: Max tags to return (default 100)
|
||||
- `include_public`: Include public/shared assets
|
||||
|
||||
**Returns:** List of tags with counts for matching assets
|
||||
|
||||
---
|
||||
|
||||
## Phase 2: Update Existing Endpoints for Parity
|
||||
|
||||
### 2.1 `GET /api/assets`
|
||||
- Add `include_public` query parameter (boolean, default true)
|
||||
|
||||
### 2.2 `POST /api/assets`
|
||||
- Add JSON body upload path for URL-based uploads:
|
||||
```json
|
||||
{
|
||||
"url": "https://...",
|
||||
"name": "model.safetensors",
|
||||
"tags": ["models", "checkpoints"],
|
||||
"user_metadata": {},
|
||||
"preview_id": "uuid"
|
||||
}
|
||||
```
|
||||
- Keep existing multipart upload support
|
||||
|
||||
### 2.3 `PUT /api/assets/{id}`
|
||||
- Add `mime_type` field support
|
||||
- Add `preview_id` field support
|
||||
- Remove direct `tags` field (recommend using dedicated `POST/DELETE /api/assets/{id}/tags` endpoints instead)
|
||||
|
||||
### 2.4 `GET /api/tags`
|
||||
- Add `include_public` query parameter (boolean, default true)
|
||||
|
||||
---
|
||||
|
||||
## Phase 3: Local-Only Endpoints
|
||||
|
||||
These endpoints exist locally but not in cloud.
|
||||
|
||||
### 3.1 `GET /api/assets/{id}/content`
|
||||
Download asset file content. Cloud uses signed URLs instead. **Keep for local.**
|
||||
|
||||
### 3.2 `PUT /api/assets/{id}/preview`
|
||||
**Remove this endpoint.** Merge functionality into `PUT /api/assets/{id}` by adding `preview_id` field support (aligns with cloud).
|
||||
|
||||
### 3.3 `POST /api/assets/scan/seed`
|
||||
Filesystem seeding/scanning for local asset discovery. Not applicable to cloud. **Keep as local-only.**
|
||||
|
||||
---
|
||||
|
||||
## Phase 4: Testing
|
||||
|
||||
Add tests for all new and modified endpoints to ensure functionality matches cloud behavior.
|
||||
|
||||
### 4.1 New Endpoint Tests
|
||||
- `GET /api/assets/remote-metadata` – Test with valid/invalid URLs, various sources (CivitAI, HuggingFace)
|
||||
- `POST /api/assets/download` – Test background download initiation, existing file detection, task tracking
|
||||
- `GET /api/assets/tags/refine` – Test histogram generation with various filter combinations
|
||||
|
||||
### 4.2 Updated Endpoint Tests
|
||||
- `GET /api/assets` – Test `include_public` param filtering
|
||||
- `POST /api/assets` – Test JSON URL upload path alongside existing multipart tests
|
||||
- `PUT /api/assets/{id}` – Test `mime_type` and `preview_id` field updates
|
||||
- `GET /api/tags` – Test `include_public` param filtering
|
||||
|
||||
### 4.3 Removed Endpoint Tests
|
||||
- Remove tests for `PUT /api/assets/{id}/preview`
|
||||
- Add tests for `preview_id` in `PUT /api/assets/{id}` to cover the merged functionality
|
||||
|
||||
---
|
||||
|
||||
## Implementation Order
|
||||
|
||||
1. Phase 2.1, 2.4 – Add `include_public` params (low effort, high compatibility)
|
||||
2. Phase 2.3 – Update PUT endpoint fields + remove preview endpoint
|
||||
3. Phase 2.2 – Add JSON URL upload to POST
|
||||
4. Phase 1.3 – Add tags/refine endpoint
|
||||
5. Phase 1.1, 1.2 – Add stub endpoints returning 501 (deferred implementation)
|
||||
6. Phase 4 – Add tests for each phase as implemented
|
||||
@ -1,20 +1,14 @@
|
||||
import logging
|
||||
import uuid
|
||||
import urllib.parse
|
||||
import os
|
||||
import contextlib
|
||||
from aiohttp import web
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
import app.assets.manager as manager
|
||||
import app.assets.scanner as scanner
|
||||
from app import user_manager
|
||||
from app.assets.api import schemas_in
|
||||
from app.assets.helpers import get_query_dict
|
||||
|
||||
import folder_paths
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
USER_MANAGER: user_manager.UserManager | None = None
|
||||
|
||||
@ -34,18 +28,6 @@ def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
|
||||
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
|
||||
|
||||
|
||||
@ROUTES.head("/api/assets/hash/{hash}")
|
||||
async def head_asset_by_hash(request: web.Request) -> web.Response:
|
||||
hash_str = request.match_info.get("hash", "").strip().lower()
|
||||
if not hash_str or ":" not in hash_str:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
algo, digest = hash_str.split(":", 1)
|
||||
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
exists = manager.asset_exists(asset_hash=hash_str)
|
||||
return web.Response(status=200 if exists else 404)
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets")
|
||||
async def list_assets(request: web.Request) -> web.Response:
|
||||
"""
|
||||
@ -67,29 +49,10 @@ async def list_assets(request: web.Request) -> web.Response:
|
||||
sort=q.sort,
|
||||
order=q.order,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
include_public=q.include_public,
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json"))
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets/remote-metadata")
|
||||
async def get_remote_asset_metadata(request: web.Request) -> web.Response:
|
||||
"""
|
||||
Fetch metadata from remote URLs (CivitAI, HuggingFace) without downloading.
|
||||
Status: Not implemented yet.
|
||||
"""
|
||||
return _error_response(501, "NOT_IMPLEMENTED", "Remote metadata fetching is not yet supported.")
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/download")
|
||||
async def create_asset_download(request: web.Request) -> web.Response:
|
||||
"""
|
||||
Initiate background download job for large files from HuggingFace or CivitAI.
|
||||
Status: Not implemented yet.
|
||||
"""
|
||||
return _error_response(501, "NOT_IMPLEMENTED", "Background asset download is not yet supported.")
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def get_asset(request: web.Request) -> web.Response:
|
||||
"""
|
||||
@ -113,306 +76,6 @@ async def get_asset(request: web.Request) -> web.Response:
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
|
||||
async def download_asset_content(request: web.Request) -> web.Response:
|
||||
# question: do we need disposition? could we just stick with one of these?
|
||||
disposition = request.query.get("disposition", "attachment").lower().strip()
|
||||
if disposition not in {"inline", "attachment"}:
|
||||
disposition = "attachment"
|
||||
|
||||
try:
|
||||
abs_path, content_type, filename = manager.resolve_asset_content_for_download(
|
||||
asset_info_id=str(uuid.UUID(request.match_info["id"])),
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
|
||||
except NotImplementedError as nie:
|
||||
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
|
||||
except FileNotFoundError:
|
||||
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
|
||||
|
||||
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
|
||||
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
|
||||
|
||||
resp = web.FileResponse(abs_path)
|
||||
resp.content_type = content_type
|
||||
resp.headers["Content-Disposition"] = cd
|
||||
return resp
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/from-hash")
|
||||
async def create_asset_from_hash(request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
body = schemas_in.CreateFromHashBody.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
result = manager.create_asset_from_hash(
|
||||
hash_str=body.hash,
|
||||
name=body.name,
|
||||
tags=body.tags,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
if result is None:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
|
||||
return web.json_response(result.model_dump(mode="json"), status=201)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets")
|
||||
async def upload_asset(request: web.Request) -> web.Response:
|
||||
"""Asset upload endpoint supporting multipart/form-data (file upload) or application/json (URL-based)."""
|
||||
|
||||
content_type = (request.content_type or "").lower()
|
||||
|
||||
if content_type.startswith("application/json"):
|
||||
try:
|
||||
payload = await request.json()
|
||||
schemas_in.UploadAssetFromUrlBody.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
return _error_response(501, "NOT_IMPLEMENTED", "URL-based asset upload is not yet supported. Use multipart/form-data file upload.")
|
||||
|
||||
if not content_type.startswith("multipart/"):
|
||||
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads or application/json for URL-based uploads.")
|
||||
|
||||
reader = await request.multipart()
|
||||
|
||||
file_present = False
|
||||
file_client_name: str | None = None
|
||||
tags_raw: list[str] = []
|
||||
provided_name: str | None = None
|
||||
user_metadata_raw: str | None = None
|
||||
provided_hash: str | None = None
|
||||
provided_hash_exists: bool | None = None
|
||||
|
||||
file_written = 0
|
||||
tmp_path: str | None = None
|
||||
while True:
|
||||
field = await reader.next()
|
||||
if field is None:
|
||||
break
|
||||
|
||||
fname = getattr(field, "name", "") or ""
|
||||
|
||||
if fname == "hash":
|
||||
try:
|
||||
s = ((await field.text()) or "").strip().lower()
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
|
||||
if s:
|
||||
if ":" not in s:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
provided_hash = f"{algo}:{digest}"
|
||||
try:
|
||||
provided_hash_exists = manager.asset_exists(asset_hash=provided_hash)
|
||||
except Exception:
|
||||
provided_hash_exists = None # do not fail the whole request here
|
||||
|
||||
elif fname == "file":
|
||||
file_present = True
|
||||
file_client_name = (field.filename or "").strip()
|
||||
|
||||
if provided_hash and provided_hash_exists is True:
|
||||
# If client supplied a hash that we know exists, drain but do not write to disk
|
||||
try:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
|
||||
continue # Do not create temp file; we will create AssetInfo from the existing content
|
||||
|
||||
# Otherwise, store to temp for hashing/ingest
|
||||
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
|
||||
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
|
||||
os.makedirs(unique_dir, exist_ok=True)
|
||||
tmp_path = os.path.join(unique_dir, ".upload.part")
|
||||
|
||||
try:
|
||||
with open(tmp_path, "wb") as f:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
try:
|
||||
if os.path.exists(tmp_path or ""):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
|
||||
elif fname == "tags":
|
||||
tags_raw.append((await field.text()) or "")
|
||||
elif fname == "name":
|
||||
provided_name = (await field.text()) or None
|
||||
elif fname == "user_metadata":
|
||||
user_metadata_raw = (await field.text()) or None
|
||||
|
||||
# If client did not send file, and we are not doing a from-hash fast path -> error
|
||||
if not file_present and not (provided_hash and provided_hash_exists):
|
||||
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
|
||||
|
||||
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
|
||||
# Empty upload is only acceptable if we are fast-pathing from existing hash
|
||||
try:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
|
||||
|
||||
try:
|
||||
spec = schemas_in.UploadAssetSpec.model_validate({
|
||||
"tags": tags_raw,
|
||||
"name": provided_name,
|
||||
"user_metadata": user_metadata_raw,
|
||||
"hash": provided_hash,
|
||||
})
|
||||
except ValidationError as ve:
|
||||
try:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
|
||||
# Validate models category against configured folders (consistent with previous behavior)
|
||||
if spec.tags and spec.tags[0] == "models":
|
||||
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
return _error_response(
|
||||
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
|
||||
)
|
||||
|
||||
owner_id = USER_MANAGER.get_request_user_id(request)
|
||||
|
||||
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
|
||||
if spec.hash and provided_hash_exists is True:
|
||||
try:
|
||||
result = manager.create_asset_from_hash(
|
||||
hash_str=spec.hash,
|
||||
name=spec.name or (spec.hash.split(":", 1)[1]),
|
||||
tags=spec.tags,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
owner_id=owner_id,
|
||||
)
|
||||
except Exception:
|
||||
logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
if result is None:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
|
||||
|
||||
# Drain temp if we accidentally saved (e.g., hash field came after file)
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
with contextlib.suppress(Exception):
|
||||
os.remove(tmp_path)
|
||||
|
||||
status = 200 if (not result.created_new) else 201
|
||||
return web.json_response(result.model_dump(mode="json"), status=status)
|
||||
|
||||
# Otherwise, we must have a temp file path to ingest
|
||||
if not tmp_path or not os.path.exists(tmp_path):
|
||||
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
|
||||
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
|
||||
|
||||
try:
|
||||
created = manager.upload_asset_from_temp_path(
|
||||
spec,
|
||||
temp_path=tmp_path,
|
||||
client_filename=file_client_name,
|
||||
owner_id=owner_id,
|
||||
expected_asset_hash=spec.hash,
|
||||
)
|
||||
status = 201 if created.created_new else 200
|
||||
return web.json_response(created.model_dump(mode="json"), status=status)
|
||||
except ValueError as e:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
msg = str(e)
|
||||
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
|
||||
return _error_response(
|
||||
400,
|
||||
"HASH_MISMATCH",
|
||||
"Uploaded file hash does not match provided hash.",
|
||||
)
|
||||
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
|
||||
except Exception:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
|
||||
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def update_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = manager.update_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
name=body.name,
|
||||
mime_type=body.mime_type,
|
||||
preview_id=body.preview_id,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except (ValueError, PermissionError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"update_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def delete_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
delete_content = request.query.get("delete_content")
|
||||
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
|
||||
|
||||
try:
|
||||
deleted = manager.delete_asset_reference(
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
delete_content_if_orphan=delete_content,
|
||||
)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
if not deleted:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
|
||||
return web.Response(status=204)
|
||||
|
||||
|
||||
@ROUTES.get("/api/tags")
|
||||
async def get_tags(request: web.Request) -> web.Response:
|
||||
"""
|
||||
@ -435,109 +98,5 @@ async def get_tags(request: web.Request) -> web.Response:
|
||||
order=query.order,
|
||||
include_zero=query.include_zero,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
include_public=query.include_public,
|
||||
)
|
||||
return web.json_response(result.model_dump(mode="json"))
|
||||
|
||||
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
async def add_asset_tags(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
payload = await request.json()
|
||||
data = schemas_in.TagsAdd.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = manager.add_tags_to_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=data.tags,
|
||||
origin="manual",
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except (ValueError, PermissionError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
async def delete_asset_tags(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
payload = await request.json()
|
||||
data = schemas_in.TagsRemove.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = manager.remove_tags_from_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=data.tags,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets/tags/refine")
|
||||
async def get_asset_tag_histogram(request: web.Request) -> web.Response:
|
||||
"""
|
||||
GET request to get a tag histogram for filtered assets.
|
||||
"""
|
||||
query_dict = get_query_dict(request)
|
||||
try:
|
||||
q = schemas_in.TagsRefineQuery.model_validate(query_dict)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_QUERY", ve)
|
||||
|
||||
payload = manager.get_tag_histogram(
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
include_public=q.include_public,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json"))
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/scan/seed")
|
||||
async def seed_assets(request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
except Exception:
|
||||
payload = {}
|
||||
|
||||
try:
|
||||
body = schemas_in.ScheduleAssetScanBody.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
|
||||
try:
|
||||
scanner.seed_assets(body.roots)
|
||||
except Exception:
|
||||
logging.exception("seed_assets failed for roots=%s", body.roots)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response({"synced": True, "roots": body.roots}, status=200)
|
||||
|
||||
@ -8,10 +8,8 @@ from pydantic import (
|
||||
Field,
|
||||
conint,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from app.assets.helpers import RootType
|
||||
|
||||
class ListAssetsQuery(BaseModel):
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
@ -27,8 +25,6 @@ class ListAssetsQuery(BaseModel):
|
||||
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
|
||||
order: Literal["asc", "desc"] = "desc"
|
||||
|
||||
include_public: bool = True
|
||||
|
||||
@field_validator("include_tags", "exclude_tags", mode="before")
|
||||
@classmethod
|
||||
def _split_csv_tags(cls, v):
|
||||
@ -61,73 +57,6 @@ class ListAssetsQuery(BaseModel):
|
||||
return None
|
||||
|
||||
|
||||
class UpdateAssetBody(BaseModel):
|
||||
name: str | None = None
|
||||
mime_type: str | None = None
|
||||
preview_id: str | None = None
|
||||
user_metadata: dict[str, Any] | None = None
|
||||
|
||||
@field_validator("preview_id", mode="before")
|
||||
@classmethod
|
||||
def _norm_uuid(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
s = str(v).strip()
|
||||
if not s:
|
||||
return None
|
||||
try:
|
||||
uuid.UUID(s)
|
||||
except Exception:
|
||||
raise ValueError("preview_id must be a UUID")
|
||||
return s
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _at_least_one(self):
|
||||
if self.name is None and self.mime_type is None and self.preview_id is None and self.user_metadata is None:
|
||||
raise ValueError("Provide at least one of: name, mime_type, preview_id, user_metadata.")
|
||||
return self
|
||||
|
||||
|
||||
class CreateFromHashBody(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
hash: str
|
||||
name: str
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("hash")
|
||||
@classmethod
|
||||
def _require_blake3(cls, v):
|
||||
s = (v or "").strip().lower()
|
||||
if ":" not in s:
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3":
|
||||
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
raise ValueError("hash digest must be lowercase hex")
|
||||
return s
|
||||
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def _tags_norm(cls, v):
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, list):
|
||||
out = [str(t).strip().lower() for t in v if str(t).strip()]
|
||||
seen = set()
|
||||
dedup = []
|
||||
for t in out:
|
||||
if t not in seen:
|
||||
seen.add(t)
|
||||
dedup.append(t)
|
||||
return dedup
|
||||
if isinstance(v, str):
|
||||
return [t.strip().lower() for t in v.split(",") if t.strip()]
|
||||
return []
|
||||
|
||||
|
||||
class TagsListQuery(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
@ -136,7 +65,6 @@ class TagsListQuery(BaseModel):
|
||||
offset: int = Field(0, ge=0, le=10_000_000)
|
||||
order: Literal["count_desc", "name_asc"] = "count_desc"
|
||||
include_zero: bool = True
|
||||
include_public: bool = True
|
||||
|
||||
@field_validator("prefix")
|
||||
@classmethod
|
||||
@ -147,188 +75,10 @@ class TagsListQuery(BaseModel):
|
||||
return v.lower() or None
|
||||
|
||||
|
||||
class TagsAdd(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
tags: list[str] = Field(..., min_length=1)
|
||||
|
||||
@field_validator("tags")
|
||||
@classmethod
|
||||
def normalize_tags(cls, v: list[str]) -> list[str]:
|
||||
out = []
|
||||
for t in v:
|
||||
if not isinstance(t, str):
|
||||
raise TypeError("tags must be strings")
|
||||
tnorm = t.strip().lower()
|
||||
if tnorm:
|
||||
out.append(tnorm)
|
||||
seen = set()
|
||||
deduplicated = []
|
||||
for x in out:
|
||||
if x not in seen:
|
||||
seen.add(x)
|
||||
deduplicated.append(x)
|
||||
return deduplicated
|
||||
|
||||
|
||||
class TagsRemove(TagsAdd):
|
||||
pass
|
||||
|
||||
|
||||
class UploadAssetSpec(BaseModel):
|
||||
"""Upload Asset operation.
|
||||
- tags: ordered; first is root ('models'|'input'|'output');
|
||||
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
|
||||
- name: display name
|
||||
- user_metadata: arbitrary JSON object (optional)
|
||||
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
|
||||
|
||||
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
|
||||
and the original extension is preserved when available.
|
||||
"""
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
tags: list[str] = Field(..., min_length=1)
|
||||
name: str | None = Field(default=None, max_length=512, description="Display Name")
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
hash: str | None = Field(default=None)
|
||||
|
||||
@field_validator("hash", mode="before")
|
||||
@classmethod
|
||||
def _parse_hash(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
s = str(v).strip().lower()
|
||||
if not s:
|
||||
return None
|
||||
if ":" not in s:
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3":
|
||||
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
raise ValueError("hash digest must be lowercase hex")
|
||||
return f"{algo}:{digest}"
|
||||
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def _parse_tags(cls, v):
|
||||
"""
|
||||
Accepts a list of strings (possibly multiple form fields),
|
||||
where each string can be:
|
||||
- JSON array (e.g., '["models","loras","foo"]')
|
||||
- comma-separated ('models, loras, foo')
|
||||
- single token ('models')
|
||||
Returns a normalized, deduplicated, ordered list.
|
||||
"""
|
||||
items: list[str] = []
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
v = [v]
|
||||
|
||||
if isinstance(v, list):
|
||||
for item in v:
|
||||
if item is None:
|
||||
continue
|
||||
s = str(item).strip()
|
||||
if not s:
|
||||
continue
|
||||
if s.startswith("["):
|
||||
try:
|
||||
arr = json.loads(s)
|
||||
if isinstance(arr, list):
|
||||
items.extend(str(x) for x in arr)
|
||||
continue
|
||||
except Exception:
|
||||
pass # fallback to CSV parse below
|
||||
items.extend([p for p in s.split(",") if p.strip()])
|
||||
else:
|
||||
return []
|
||||
|
||||
# normalize + dedupe
|
||||
norm = []
|
||||
seen = set()
|
||||
for t in items:
|
||||
tnorm = str(t).strip().lower()
|
||||
if tnorm and tnorm not in seen:
|
||||
seen.add(tnorm)
|
||||
norm.append(tnorm)
|
||||
return norm
|
||||
|
||||
@field_validator("user_metadata", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
if v is None or isinstance(v, dict):
|
||||
return v or {}
|
||||
if isinstance(v, str):
|
||||
s = v.strip()
|
||||
if not s:
|
||||
return {}
|
||||
try:
|
||||
parsed = json.loads(s)
|
||||
except Exception as e:
|
||||
raise ValueError(f"user_metadata must be JSON: {e}") from e
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("user_metadata must be a JSON object")
|
||||
return parsed
|
||||
return {}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_order(self):
|
||||
if not self.tags:
|
||||
raise ValueError("tags must be provided and non-empty")
|
||||
root = self.tags[0]
|
||||
if root not in {"models", "input", "output"}:
|
||||
raise ValueError("first tag must be one of: models, input, output")
|
||||
if root == "models":
|
||||
if len(self.tags) < 2:
|
||||
raise ValueError("models uploads require a category tag as the second tag")
|
||||
return self
|
||||
|
||||
|
||||
class UploadAssetFromUrlBody(BaseModel):
|
||||
"""JSON body for URL-based asset upload."""
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
url: str = Field(..., description="HTTP/HTTPS URL to download the asset from")
|
||||
name: str = Field(..., max_length=512, description="Display name for the asset")
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
class SetPreviewBody(BaseModel):
|
||||
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
|
||||
preview_id: str | None = None
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def _validate_url(cls, v):
|
||||
s = (v or "").strip()
|
||||
if not s:
|
||||
raise ValueError("url must not be empty")
|
||||
if not (s.startswith("http://") or s.startswith("https://")):
|
||||
raise ValueError("url must start with http:// or https://")
|
||||
return s
|
||||
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def _parse_tags(cls, v):
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, list):
|
||||
out = [str(t).strip().lower() for t in v if str(t).strip()]
|
||||
seen = set()
|
||||
dedup = []
|
||||
for t in out:
|
||||
if t not in seen:
|
||||
seen.add(t)
|
||||
dedup.append(t)
|
||||
return dedup
|
||||
return []
|
||||
|
||||
@field_validator("user_metadata", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata(cls, v):
|
||||
if v is None or isinstance(v, dict):
|
||||
return v or {}
|
||||
return {}
|
||||
|
||||
@field_validator("preview_id", mode="before")
|
||||
@classmethod
|
||||
def _norm_uuid(cls, v):
|
||||
@ -342,49 +92,3 @@ class UploadAssetFromUrlBody(BaseModel):
|
||||
except Exception:
|
||||
raise ValueError("preview_id must be a UUID")
|
||||
return s
|
||||
|
||||
|
||||
class ScheduleAssetScanBody(BaseModel):
|
||||
roots: list[RootType] = Field(..., min_length=1)
|
||||
|
||||
|
||||
class TagsRefineQuery(BaseModel):
|
||||
"""Query parameters for tag histogram/refinement endpoint."""
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
exclude_tags: list[str] = Field(default_factory=list)
|
||||
name_contains: str | None = None
|
||||
metadata_filter: dict[str, Any] | None = None
|
||||
limit: conint(ge=1, le=1000) = 100
|
||||
include_public: bool = True
|
||||
|
||||
@field_validator("include_tags", "exclude_tags", mode="before")
|
||||
@classmethod
|
||||
def _split_csv_tags(cls, v):
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
return [t.strip() for t in v.split(",") if t.strip()]
|
||||
if isinstance(v, list):
|
||||
out: list[str] = []
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
out.extend([t.strip() for t in item.split(",") if t.strip()])
|
||||
return out
|
||||
return v
|
||||
|
||||
@field_validator("metadata_filter", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
if v is None or isinstance(v, dict):
|
||||
return v
|
||||
if isinstance(v, str) and v.strip():
|
||||
try:
|
||||
parsed = json.loads(v)
|
||||
except Exception as e:
|
||||
raise ValueError(f"metadata_filter must be JSON: {e}") from e
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("metadata_filter must be a JSON object")
|
||||
return parsed
|
||||
return None
|
||||
|
||||
@ -29,21 +29,6 @@ class AssetsList(BaseModel):
|
||||
has_more: bool
|
||||
|
||||
|
||||
class AssetUpdated(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
updated_at: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("updated_at")
|
||||
def _ser_updated(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetDetail(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
@ -63,10 +48,6 @@ class AssetDetail(BaseModel):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetCreated(AssetDetail):
|
||||
created_new: bool
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
@ -77,26 +58,3 @@ class TagsList(BaseModel):
|
||||
tags: list[TagUsage] = Field(default_factory=list)
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class TagsAdd(BaseModel):
|
||||
model_config = ConfigDict(str_strip_whitespace=True)
|
||||
added: list[str] = Field(default_factory=list)
|
||||
already_present: list[str] = Field(default_factory=list)
|
||||
total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TagsRemove(BaseModel):
|
||||
model_config = ConfigDict(str_strip_whitespace=True)
|
||||
removed: list[str] = Field(default_factory=list)
|
||||
not_present: list[str] = Field(default_factory=list)
|
||||
total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TagHistogramEntry(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
|
||||
|
||||
class TagHistogramResponse(BaseModel):
|
||||
tags: list[TagHistogramEntry] = Field(default_factory=list)
|
||||
|
||||
@ -1,17 +1,9 @@
|
||||
import os
|
||||
import logging
|
||||
import sqlalchemy as sa
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Iterable, Any
|
||||
from sqlalchemy import select, delete, exists, func
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy import select, exists, func
|
||||
from sqlalchemy.orm import Session, contains_eager, noload
|
||||
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from app.assets.helpers import (
|
||||
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
|
||||
)
|
||||
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from app.assets.helpers import escape_like_prefix, normalize_tags
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
@ -23,22 +15,6 @@ def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
return AssetInfo.owner_id.in_(["", owner_id])
|
||||
|
||||
|
||||
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
|
||||
"""
|
||||
Return the best on-disk path among cache states:
|
||||
1) Prefer a path that exists with needs_verify == False (already verified).
|
||||
2) Otherwise, pick the first path that exists.
|
||||
3) Otherwise return empty string.
|
||||
"""
|
||||
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
|
||||
if not alive:
|
||||
return ""
|
||||
for s in alive:
|
||||
if not getattr(s, "needs_verify", False):
|
||||
return s.file_path
|
||||
return alive[0].file_path
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
@ -66,7 +42,6 @@ def apply_tag_filters(
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None = None,
|
||||
@ -119,11 +94,7 @@ def apply_metadata_filter(
|
||||
return stmt
|
||||
|
||||
|
||||
def asset_exists_by_hash(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
) -> bool:
|
||||
def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
@ -134,39 +105,9 @@ def asset_exists_by_hash(
|
||||
).first()
|
||||
return row is not None
|
||||
|
||||
|
||||
def asset_info_exists_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> bool:
|
||||
q = (
|
||||
select(sa.literal(True))
|
||||
.select_from(AssetInfo)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.limit(1)
|
||||
)
|
||||
return (session.execute(q)).first() is not None
|
||||
|
||||
|
||||
def get_asset_by_hash(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
) -> Asset | None:
|
||||
return (
|
||||
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
|
||||
|
||||
def get_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
) -> AssetInfo | None:
|
||||
def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None:
|
||||
return session.get(AssetInfo, asset_info_id)
|
||||
|
||||
|
||||
def list_asset_infos_page(
|
||||
session: Session,
|
||||
owner_id: str = "",
|
||||
@ -236,7 +177,6 @@ def list_asset_infos_page(
|
||||
|
||||
return infos, tag_map, total
|
||||
|
||||
|
||||
def fetch_asset_info_asset_and_tags(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
@ -268,489 +208,6 @@ def fetch_asset_info_asset_and_tags(
|
||||
tags.append(tag_name)
|
||||
return first_info, first_asset, tags
|
||||
|
||||
|
||||
def fetch_asset_info_and_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[AssetInfo, Asset] | None:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
.options(noload(AssetInfo.tags))
|
||||
)
|
||||
row = session.execute(stmt)
|
||||
pair = row.first()
|
||||
if not pair:
|
||||
return None
|
||||
return pair[0], pair[1]
|
||||
|
||||
def list_cache_states_by_asset_id(
|
||||
session: Session, *, asset_id: str
|
||||
) -> Sequence[AssetCacheState]:
|
||||
return (
|
||||
session.execute(
|
||||
select(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == asset_id)
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
|
||||
def touch_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
ts: datetime | None = None,
|
||||
only_if_newer: bool = True,
|
||||
) -> None:
|
||||
ts = ts or utcnow()
|
||||
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
|
||||
if only_if_newer:
|
||||
stmt = stmt.where(
|
||||
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
|
||||
)
|
||||
session.execute(stmt.values(last_access_time=ts))
|
||||
|
||||
|
||||
def create_asset_info_for_existing_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
name: str,
|
||||
user_metadata: dict | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> AssetInfo:
|
||||
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
|
||||
now = utcnow()
|
||||
asset = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if not asset:
|
||||
raise ValueError(f"Unknown asset hash {asset_hash}")
|
||||
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
asset_id=asset.id,
|
||||
preview_id=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
try:
|
||||
with session.begin_nested():
|
||||
session.add(info)
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
existing = (
|
||||
session.execute(
|
||||
select(AssetInfo)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == name,
|
||||
AssetInfo.owner_id == owner_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalars().first()
|
||||
if not existing:
|
||||
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
|
||||
return existing
|
||||
|
||||
# metadata["filename"] hack
|
||||
new_meta = dict(user_metadata or {})
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
if new_meta:
|
||||
replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
if tags is not None:
|
||||
set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
return info
|
||||
|
||||
|
||||
def set_asset_info_tags(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
) -> dict:
|
||||
desired = normalize_tags(tags)
|
||||
|
||||
current = set(
|
||||
tag_name for (tag_name,) in (
|
||||
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
|
||||
).all()
|
||||
)
|
||||
|
||||
to_add = [t for t in desired if t not in current]
|
||||
to_remove = [t for t in current if t not in desired]
|
||||
|
||||
if to_add:
|
||||
ensure_tags_exist(session, to_add, tag_type="user")
|
||||
session.add_all([
|
||||
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
|
||||
for t in to_add
|
||||
])
|
||||
session.flush()
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
|
||||
)
|
||||
session.flush()
|
||||
|
||||
return {"added": to_add, "removed": to_remove, "total": desired}
|
||||
|
||||
|
||||
def replace_asset_info_metadata_projection(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
user_metadata: dict | None = None,
|
||||
) -> None:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info.user_metadata = user_metadata or {}
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
|
||||
session.flush()
|
||||
|
||||
if not user_metadata:
|
||||
return
|
||||
|
||||
rows: list[AssetInfoMeta] = []
|
||||
for k, v in user_metadata.items():
|
||||
for r in project_kv(k, v):
|
||||
rows.append(
|
||||
AssetInfoMeta(
|
||||
asset_info_id=asset_info_id,
|
||||
key=r["key"],
|
||||
ordinal=int(r["ordinal"]),
|
||||
val_str=r.get("val_str"),
|
||||
val_num=r.get("val_num"),
|
||||
val_bool=r.get("val_bool"),
|
||||
val_json=r.get("val_json"),
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
session.add_all(rows)
|
||||
session.flush()
|
||||
|
||||
|
||||
def ingest_fs_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
abs_path: str,
|
||||
size_bytes: int,
|
||||
mtime_ns: int,
|
||||
mime_type: str | None = None,
|
||||
info_name: str | None = None,
|
||||
owner_id: str = "",
|
||||
preview_id: str | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
tags: Sequence[str] = (),
|
||||
tag_origin: str = "manual",
|
||||
require_existing_tags: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Idempotently upsert:
|
||||
- Asset by content hash (create if missing)
|
||||
- AssetCacheState(file_path) pointing to asset_id
|
||||
- Optionally AssetInfo + tag links and metadata projection
|
||||
Returns flags and ids.
|
||||
"""
|
||||
locator = os.path.abspath(abs_path)
|
||||
now = utcnow()
|
||||
|
||||
if preview_id:
|
||||
if not session.get(Asset, preview_id):
|
||||
preview_id = None
|
||||
|
||||
out: dict[str, Any] = {
|
||||
"asset_created": False,
|
||||
"asset_updated": False,
|
||||
"state_created": False,
|
||||
"state_updated": False,
|
||||
"asset_info_id": None,
|
||||
}
|
||||
|
||||
# 1) Asset by hash
|
||||
asset = (
|
||||
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
vals = {
|
||||
"hash": asset_hash,
|
||||
"size_bytes": int(size_bytes),
|
||||
"mime_type": mime_type,
|
||||
"created_at": now,
|
||||
}
|
||||
res = session.execute(
|
||||
sqlite.insert(Asset)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[Asset.hash])
|
||||
)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["asset_created"] = True
|
||||
asset = (
|
||||
session.execute(
|
||||
select(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
raise RuntimeError("Asset row not found after upsert.")
|
||||
else:
|
||||
changed = False
|
||||
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||
asset.size_bytes = int(size_bytes)
|
||||
changed = True
|
||||
if mime_type and asset.mime_type != mime_type:
|
||||
asset.mime_type = mime_type
|
||||
changed = True
|
||||
if changed:
|
||||
out["asset_updated"] = True
|
||||
|
||||
# 2) AssetCacheState upsert by file_path (unique)
|
||||
vals = {
|
||||
"asset_id": asset.id,
|
||||
"file_path": locator,
|
||||
"mtime_ns": int(mtime_ns),
|
||||
}
|
||||
ins = (
|
||||
sqlite.insert(AssetCacheState)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
)
|
||||
|
||||
res = session.execute(ins)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["state_created"] = True
|
||||
else:
|
||||
upd = (
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.file_path == locator)
|
||||
.where(
|
||||
sa.or_(
|
||||
AssetCacheState.asset_id != asset.id,
|
||||
AssetCacheState.mtime_ns.is_(None),
|
||||
AssetCacheState.mtime_ns != int(mtime_ns),
|
||||
)
|
||||
)
|
||||
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
|
||||
)
|
||||
res2 = session.execute(upd)
|
||||
if int(res2.rowcount or 0) > 0:
|
||||
out["state_updated"] = True
|
||||
|
||||
# 3) Optional AssetInfo + tags + metadata
|
||||
if info_name:
|
||||
try:
|
||||
with session.begin_nested():
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=info_name,
|
||||
asset_id=asset.id,
|
||||
preview_id=preview_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
session.add(info)
|
||||
session.flush()
|
||||
out["asset_info_id"] = info.id
|
||||
except IntegrityError:
|
||||
pass
|
||||
|
||||
existing_info = (
|
||||
session.execute(
|
||||
select(AssetInfo)
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == info_name,
|
||||
(AssetInfo.owner_id == owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalar_one_or_none()
|
||||
if not existing_info:
|
||||
raise RuntimeError("Failed to update or insert AssetInfo.")
|
||||
|
||||
if preview_id and existing_info.preview_id != preview_id:
|
||||
existing_info.preview_id = preview_id
|
||||
|
||||
existing_info.updated_at = now
|
||||
if existing_info.last_access_time < now:
|
||||
existing_info.last_access_time = now
|
||||
session.flush()
|
||||
out["asset_info_id"] = existing_info.id
|
||||
|
||||
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
if norm and out["asset_info_id"] is not None:
|
||||
if not require_existing_tags:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
existing_tag_names = set(
|
||||
name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
|
||||
)
|
||||
missing = [t for t in norm if t not in existing_tag_names]
|
||||
if missing and require_existing_tags:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
existing_links = set(
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
|
||||
)
|
||||
).all()
|
||||
)
|
||||
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
|
||||
if to_add:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=out["asset_info_id"],
|
||||
tag_name=t,
|
||||
origin=tag_origin,
|
||||
added_at=now,
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
|
||||
# metadata["filename"] hack
|
||||
if out["asset_info_id"] is not None:
|
||||
primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
computed_filename = compute_relative_filename(primary_path) if primary_path else None
|
||||
|
||||
current_meta = existing_info.user_metadata or {}
|
||||
new_meta = dict(current_meta)
|
||||
if user_metadata is not None:
|
||||
for k, v in user_metadata.items():
|
||||
new_meta[k] = v
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
|
||||
if new_meta != current_meta:
|
||||
replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=out["asset_info_id"],
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
try:
|
||||
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
except Exception:
|
||||
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
|
||||
return out
|
||||
|
||||
|
||||
def update_asset_info_full(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
asset_info_row: Any = None,
|
||||
) -> AssetInfo:
|
||||
if not asset_info_row:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
else:
|
||||
info = asset_info_row
|
||||
|
||||
touched = False
|
||||
if name is not None and name != info.name:
|
||||
info.name = name
|
||||
touched = True
|
||||
|
||||
if mime_type is not None and info.asset:
|
||||
if info.asset.mime_type != mime_type:
|
||||
info.asset.mime_type = mime_type
|
||||
touched = True
|
||||
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
|
||||
if user_metadata is not None:
|
||||
new_meta = dict(user_metadata)
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
else:
|
||||
if computed_filename:
|
||||
current_meta = info.user_metadata or {}
|
||||
if current_meta.get("filename") != computed_filename:
|
||||
new_meta = dict(current_meta)
|
||||
new_meta["filename"] = computed_filename
|
||||
replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def delete_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str,
|
||||
) -> bool:
|
||||
stmt = sa.delete(AssetInfo).where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
return int((session.execute(stmt)).rowcount or 0) > 0
|
||||
|
||||
|
||||
def list_tags_with_usage(
|
||||
session: Session,
|
||||
prefix: str | None = None,
|
||||
@ -808,163 +265,3 @@ def list_tags_with_usage(
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
session.execute(ins)
|
||||
|
||||
|
||||
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
|
||||
return [
|
||||
tag_name for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
|
||||
|
||||
def add_tags_to_asset_info(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
create_if_missing: bool = True,
|
||||
asset_info_row: Any = None,
|
||||
) -> dict:
|
||||
if not asset_info_row:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"added": [], "already_present": [], "total_tags": total}
|
||||
|
||||
if create_if_missing:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
current = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
want = set(norm)
|
||||
to_add = sorted(want - current)
|
||||
|
||||
if to_add:
|
||||
with session.begin_nested() as nested:
|
||||
try:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=asset_info_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=utcnow(),
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
nested.rollback()
|
||||
|
||||
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
|
||||
return {
|
||||
"added": sorted(((after - current) & want)),
|
||||
"already_present": sorted(want & current),
|
||||
"total_tags": sorted(after),
|
||||
}
|
||||
|
||||
|
||||
def remove_tags_from_asset_info(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
) -> dict:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": [], "not_present": [], "total_tags": total}
|
||||
|
||||
existing = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
to_remove = sorted(set(t for t in norm if t in existing))
|
||||
not_present = sorted(set(t for t in norm if t not in existing))
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(
|
||||
AssetInfoTag.asset_info_id == asset_info_id,
|
||||
AssetInfoTag.tag_name.in_(to_remove),
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
|
||||
|
||||
|
||||
def remove_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
session.execute(
|
||||
sa.delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def set_asset_info_preview(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
) -> None:
|
||||
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
if preview_asset_id is None:
|
||||
info.preview_id = None
|
||||
else:
|
||||
# validate preview asset exists
|
||||
if not session.get(Asset, preview_asset_id):
|
||||
raise ValueError(f"Preview Asset {preview_asset_id} not found")
|
||||
info.preview_id = preview_asset_id
|
||||
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import contextlib
|
||||
import os
|
||||
from decimal import Decimal
|
||||
from aiohttp import web
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
@ -88,40 +87,6 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
targets.append((name, paths))
|
||||
return targets
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||
root = tags[0]
|
||||
if root == "models":
|
||||
if len(tags) < 2:
|
||||
raise ValueError("at least two tags required for model asset")
|
||||
try:
|
||||
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||
if not bases:
|
||||
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||
base_dir = os.path.abspath(bases[0])
|
||||
raw_subdirs = tags[2:]
|
||||
else:
|
||||
base_dir = os.path.abspath(
|
||||
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
|
||||
)
|
||||
raw_subdirs = tags[1:]
|
||||
for i in raw_subdirs:
|
||||
if i in (".", ".."):
|
||||
raise ValueError("invalid path component in tags")
|
||||
|
||||
return base_dir, raw_subdirs if raw_subdirs else []
|
||||
|
||||
def ensure_within_base(candidate: str, base: str) -> None:
|
||||
cand_abs = os.path.abspath(candidate)
|
||||
base_abs = os.path.abspath(base)
|
||||
try:
|
||||
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
|
||||
raise ValueError("destination escapes base directory")
|
||||
except Exception:
|
||||
raise ValueError("invalid destination path")
|
||||
|
||||
def compute_relative_filename(file_path: str) -> str | None:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
@ -148,6 +113,7 @@ def compute_relative_filename(file_path: str) -> str | None:
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
|
||||
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
||||
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
||||
@ -249,64 +215,3 @@ def collect_models_files() -> list[str]:
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
|
||||
|
||||
def is_scalar(v):
|
||||
if v is None:
|
||||
return True
|
||||
if isinstance(v, bool):
|
||||
return True
|
||||
if isinstance(v, (int, float, Decimal, str)):
|
||||
return True
|
||||
return False
|
||||
|
||||
def project_kv(key: str, value):
|
||||
"""
|
||||
Turn a metadata key/value into typed projection rows.
|
||||
Returns list[dict] with keys:
|
||||
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
|
||||
"""
|
||||
rows: list[dict] = []
|
||||
|
||||
def _null_row(ordinal: int) -> dict:
|
||||
return {
|
||||
"key": key, "ordinal": ordinal,
|
||||
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
|
||||
}
|
||||
|
||||
if value is None:
|
||||
rows.append(_null_row(0))
|
||||
return rows
|
||||
|
||||
if is_scalar(value):
|
||||
if isinstance(value, bool):
|
||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||
elif isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
rows.append({"key": key, "ordinal": 0, "val_num": num})
|
||||
elif isinstance(value, str):
|
||||
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
if isinstance(value, list):
|
||||
if all(is_scalar(x) for x in value):
|
||||
for i, x in enumerate(value):
|
||||
if x is None:
|
||||
rows.append(_null_row(i))
|
||||
elif isinstance(x, bool):
|
||||
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||
elif isinstance(x, (int, float, Decimal)):
|
||||
num = x if isinstance(x, Decimal) else Decimal(str(x))
|
||||
rows.append({"key": key, "ordinal": i, "val_num": num})
|
||||
elif isinstance(x, str):
|
||||
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
for i, x in enumerate(value):
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
@ -1,34 +1,13 @@
|
||||
import os
|
||||
import mimetypes
|
||||
import contextlib
|
||||
from typing import Sequence
|
||||
|
||||
from app.database.db import create_session
|
||||
from app.assets.api import schemas_out, schemas_in
|
||||
from app.assets.api import schemas_out
|
||||
from app.assets.database.queries import (
|
||||
asset_exists_by_hash,
|
||||
asset_info_exists_for_asset_id,
|
||||
get_asset_by_hash,
|
||||
get_asset_info_by_id,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
fetch_asset_info_and_asset,
|
||||
create_asset_info_for_existing_asset,
|
||||
touch_asset_info_by_id,
|
||||
update_asset_info_full,
|
||||
delete_asset_info_by_id,
|
||||
list_cache_states_by_asset_id,
|
||||
list_asset_infos_page,
|
||||
list_tags_with_usage,
|
||||
get_asset_tags,
|
||||
add_tags_to_asset_info,
|
||||
remove_tags_from_asset_info,
|
||||
pick_best_live_path,
|
||||
ingest_fs_asset,
|
||||
set_asset_info_preview,
|
||||
)
|
||||
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
|
||||
from app.assets.database.models import Asset
|
||||
import app.assets.hashing as hashing
|
||||
|
||||
|
||||
def _safe_sort_field(requested: str | None) -> str:
|
||||
@ -40,28 +19,11 @@ def _safe_sort_field(requested: str | None) -> str:
|
||||
return "created_at"
|
||||
|
||||
|
||||
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
|
||||
st = os.stat(path, follow_symlinks=True)
|
||||
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
|
||||
|
||||
def _safe_filename(name: str | None, fallback: str) -> str:
|
||||
n = os.path.basename((name or "").strip() or fallback)
|
||||
if n:
|
||||
return n
|
||||
return fallback
|
||||
|
||||
|
||||
def asset_exists(*, asset_hash: str) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
def asset_exists(asset_hash: str) -> bool:
|
||||
with create_session() as session:
|
||||
return asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
|
||||
def list_assets(
|
||||
*,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
@ -71,7 +33,6 @@ def list_assets(
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
owner_id: str = "",
|
||||
include_public: bool = True,
|
||||
) -> schemas_out.AssetsList:
|
||||
sort = _safe_sort_field(sort)
|
||||
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
|
||||
@ -115,12 +76,7 @@ def list_assets(
|
||||
has_more=(offset + len(summaries)) < total,
|
||||
)
|
||||
|
||||
|
||||
def get_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
@ -141,356 +97,6 @@ def get_asset(
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
def resolve_asset_content_for_download(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[str, str, str]:
|
||||
with create_session() as session:
|
||||
pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info, asset = pair
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
|
||||
abs_path = pick_best_live_path(states)
|
||||
if not abs_path:
|
||||
raise FileNotFoundError
|
||||
|
||||
touch_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
session.commit()
|
||||
|
||||
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
|
||||
download_name = info.name or os.path.basename(abs_path)
|
||||
return abs_path, ctype, download_name
|
||||
|
||||
|
||||
def upload_asset_from_temp_path(
|
||||
spec: schemas_in.UploadAssetSpec,
|
||||
*,
|
||||
temp_path: str,
|
||||
client_filename: str | None = None,
|
||||
owner_id: str = "",
|
||||
expected_asset_hash: str | None = None,
|
||||
) -> schemas_out.AssetCreated:
|
||||
try:
|
||||
digest = hashing.blake3_hash(temp_path)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to hash uploaded file: {e}")
|
||||
asset_hash = "blake3:" + digest
|
||||
|
||||
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
|
||||
raise ValueError("HASH_MISMATCH")
|
||||
|
||||
with create_session() as session:
|
||||
existing = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if existing is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
|
||||
info = create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
name=display_name,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
session.commit()
|
||||
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=existing.hash,
|
||||
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
|
||||
mime_type=existing.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
|
||||
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
|
||||
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
src_for_ext = (client_filename or spec.name or "").strip()
|
||||
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
|
||||
ext = _ext if 0 < len(_ext) <= 16 else ""
|
||||
hashed_basename = f"{digest}{ext}"
|
||||
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||
ensure_within_base(dest_abs, base_dir)
|
||||
|
||||
content_type = (
|
||||
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
|
||||
try:
|
||||
os.replace(temp_path, dest_abs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to move uploaded file into place: {e}")
|
||||
|
||||
try:
|
||||
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"failed to stat destination file: {e}")
|
||||
|
||||
with create_session() as session:
|
||||
result = ingest_fs_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
abs_path=dest_abs,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=content_type,
|
||||
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
|
||||
owner_id=owner_id,
|
||||
preview_id=None,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags,
|
||||
tag_origin="manual",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
info_id = result["asset_info_id"]
|
||||
if not info_id:
|
||||
raise RuntimeError("failed to create asset metadata")
|
||||
|
||||
pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise RuntimeError("inconsistent DB state after ingest")
|
||||
info, asset = pair
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
session.commit()
|
||||
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=result["asset_created"],
|
||||
)
|
||||
|
||||
|
||||
def update_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
preview_id: str | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetUpdated:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
info = update_asset_info_full(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
name=name,
|
||||
mime_type=mime_type,
|
||||
user_metadata=user_metadata,
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
|
||||
if preview_id is not None:
|
||||
set_asset_info_preview(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
preview_asset_id=preview_id if preview_id else None,
|
||||
)
|
||||
|
||||
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
session.commit()
|
||||
|
||||
return schemas_out.AssetUpdated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=info.asset.hash if info.asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
updated_at=info.updated_at,
|
||||
)
|
||||
|
||||
|
||||
def set_asset_preview(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
set_asset_info_preview(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
preview_asset_id=preview_asset_id,
|
||||
)
|
||||
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise RuntimeError("State changed during preview update")
|
||||
info, asset, tags = res
|
||||
session.commit()
|
||||
|
||||
return schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
asset_id = info_row.asset_id if info_row else None
|
||||
deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not deleted:
|
||||
session.commit()
|
||||
return False
|
||||
|
||||
if not delete_content_if_orphan or not asset_id:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
|
||||
if still_exists:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
|
||||
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
|
||||
|
||||
asset_row = session.get(Asset, asset_id)
|
||||
if asset_row is not None:
|
||||
session.delete(asset_row)
|
||||
|
||||
session.commit()
|
||||
for p in file_paths:
|
||||
with contextlib.suppress(Exception):
|
||||
if p and os.path.isfile(p):
|
||||
os.remove(p)
|
||||
return True
|
||||
|
||||
|
||||
def create_asset_from_hash(
|
||||
*,
|
||||
hash_str: str,
|
||||
name: str,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetCreated | None:
|
||||
canonical = hash_str.strip().lower()
|
||||
with create_session() as session:
|
||||
asset = get_asset_by_hash(session, asset_hash=canonical)
|
||||
if not asset:
|
||||
return None
|
||||
|
||||
info = create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=canonical,
|
||||
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
session.commit()
|
||||
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
|
||||
|
||||
def add_tags_to_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsAdd:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
data = add_tags_to_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=origin,
|
||||
create_if_missing=True,
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
session.commit()
|
||||
return schemas_out.TagsAdd(**data)
|
||||
|
||||
|
||||
def remove_tags_from_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsRemove:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
data = remove_tags_from_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
)
|
||||
session.commit()
|
||||
return schemas_out.TagsRemove(**data)
|
||||
|
||||
|
||||
def list_tags(
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
@ -498,7 +104,6 @@ def list_tags(
|
||||
order: str = "count_desc",
|
||||
include_zero: bool = True,
|
||||
owner_id: str = "",
|
||||
include_public: bool = True,
|
||||
) -> schemas_out.TagsList:
|
||||
limit = max(1, min(1000, limit))
|
||||
offset = max(0, offset)
|
||||
@ -516,17 +121,3 @@ def list_tags(
|
||||
|
||||
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
|
||||
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)
|
||||
|
||||
|
||||
def get_tag_histogram(
|
||||
*,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 100,
|
||||
include_public: bool = True,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagHistogramResponse:
|
||||
# TODO: Implement actual histogram query in queries.py
|
||||
return schemas_out.TagHistogramResponse(tags=[])
|
||||
|
||||
@ -8,7 +8,6 @@ class LatentFormat:
|
||||
latent_rgb_factors_bias = None
|
||||
latent_rgb_factors_reshape = None
|
||||
taesd_decoder_name = None
|
||||
spacial_downscale_ratio = 8
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent * self.scale_factor
|
||||
@ -182,7 +181,6 @@ class Flux(SD3):
|
||||
|
||||
class Flux2(LatentFormat):
|
||||
latent_channels = 128
|
||||
spacial_downscale_ratio = 16
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors =[
|
||||
@ -751,7 +749,6 @@ class ACEAudio(LatentFormat):
|
||||
|
||||
class ChromaRadiance(LatentFormat):
|
||||
latent_channels = 3
|
||||
spacial_downscale_ratio = 1
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
|
||||
@ -18,12 +18,12 @@ class CompressedTimestep:
|
||||
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
|
||||
"""
|
||||
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
|
||||
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
|
||||
patches_per_frame: Number of spatial patches per frame (height * width in latent space)
|
||||
"""
|
||||
self.batch_size, num_tokens, self.feature_dim = tensor.shape
|
||||
|
||||
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
|
||||
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
||||
if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
||||
self.patches_per_frame = patches_per_frame
|
||||
self.num_frames = num_tokens // patches_per_frame
|
||||
|
||||
@ -215,9 +215,22 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
return (*scale_shift_ada_values, *gate_ada_values)
|
||||
|
||||
def forward(
|
||||
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
|
||||
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
|
||||
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
|
||||
self,
|
||||
x: Tuple[torch.Tensor, torch.Tensor],
|
||||
v_context=None,
|
||||
a_context=None,
|
||||
attention_mask=None,
|
||||
v_timestep=None,
|
||||
a_timestep=None,
|
||||
v_pe=None,
|
||||
a_pe=None,
|
||||
v_cross_pe=None,
|
||||
a_cross_pe=None,
|
||||
v_cross_scale_shift_timestep=None,
|
||||
a_cross_scale_shift_timestep=None,
|
||||
v_cross_gate_timestep=None,
|
||||
a_cross_gate_timestep=None,
|
||||
transformer_options=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
run_vx = transformer_options.get("run_vx", True)
|
||||
run_ax = transformer_options.get("run_ax", True)
|
||||
@ -227,102 +240,144 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
|
||||
run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
|
||||
|
||||
# video
|
||||
if run_vx:
|
||||
# video self-attention
|
||||
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
|
||||
vshift_msa, vscale_msa, vgate_msa = (
|
||||
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
|
||||
)
|
||||
|
||||
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
|
||||
del vshift_msa, vscale_msa
|
||||
attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
|
||||
del norm_vx
|
||||
# video cross-attention
|
||||
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
|
||||
vx.addcmul_(attn1_out, vgate_msa)
|
||||
del vgate_msa, attn1_out
|
||||
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
|
||||
vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
|
||||
vx += self.attn2(
|
||||
comfy.ldm.common_dit.rms_norm(vx),
|
||||
context=v_context,
|
||||
mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
del vshift_msa, vscale_msa, vgate_msa
|
||||
|
||||
# audio
|
||||
if run_ax:
|
||||
# audio self-attention
|
||||
ashift_msa, ascale_msa = (self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 2)))
|
||||
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
|
||||
del ashift_msa, ascale_msa
|
||||
attn1_out = self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
|
||||
del norm_ax
|
||||
# audio cross-attention
|
||||
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
|
||||
ax.addcmul_(attn1_out, agate_msa)
|
||||
del agate_msa, attn1_out
|
||||
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
|
||||
ashift_msa, ascale_msa, agate_msa = (
|
||||
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
|
||||
)
|
||||
|
||||
# video - audio cross attention.
|
||||
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
|
||||
ax += (
|
||||
self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
|
||||
* agate_msa
|
||||
)
|
||||
ax += self.audio_attn2(
|
||||
comfy.ldm.common_dit.rms_norm(ax),
|
||||
context=a_context,
|
||||
mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
del ashift_msa, ascale_msa, agate_msa
|
||||
|
||||
# Audio - Video cross attention.
|
||||
if run_a2v or run_v2a:
|
||||
# norm3
|
||||
vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
|
||||
ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
|
||||
|
||||
# audio to video cross attention
|
||||
(
|
||||
scale_ca_audio_hidden_states_a2v,
|
||||
shift_ca_audio_hidden_states_a2v,
|
||||
scale_ca_audio_hidden_states_v2a,
|
||||
shift_ca_audio_hidden_states_v2a,
|
||||
gate_out_v2a,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio,
|
||||
ax.shape[0],
|
||||
a_cross_scale_shift_timestep,
|
||||
a_cross_gate_timestep,
|
||||
)
|
||||
|
||||
(
|
||||
scale_ca_video_hidden_states_a2v,
|
||||
shift_ca_video_hidden_states_a2v,
|
||||
scale_ca_video_hidden_states_v2a,
|
||||
shift_ca_video_hidden_states_v2a,
|
||||
gate_out_a2v,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video,
|
||||
vx.shape[0],
|
||||
v_cross_scale_shift_timestep,
|
||||
v_cross_gate_timestep,
|
||||
)
|
||||
|
||||
if run_a2v:
|
||||
scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2]
|
||||
scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2]
|
||||
vx_scaled = (
|
||||
vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
|
||||
+ shift_ca_video_hidden_states_a2v
|
||||
)
|
||||
ax_scaled = (
|
||||
ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
|
||||
+ shift_ca_audio_hidden_states_a2v
|
||||
)
|
||||
vx += (
|
||||
self.audio_to_video_attn(
|
||||
vx_scaled,
|
||||
context=ax_scaled,
|
||||
pe=v_cross_pe,
|
||||
k_pe=a_cross_pe,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
* gate_out_a2v
|
||||
)
|
||||
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
|
||||
del scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v, scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v
|
||||
del gate_out_a2v
|
||||
del scale_ca_video_hidden_states_a2v,\
|
||||
shift_ca_video_hidden_states_a2v,\
|
||||
scale_ca_audio_hidden_states_a2v,\
|
||||
shift_ca_audio_hidden_states_a2v,\
|
||||
|
||||
a2v_out = self.audio_to_video_attn(vx_scaled, context=ax_scaled, pe=v_cross_pe, k_pe=a_cross_pe, transformer_options=transformer_options)
|
||||
del vx_scaled, ax_scaled
|
||||
|
||||
gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0]
|
||||
vx.addcmul_(a2v_out, gate_out_a2v)
|
||||
del gate_out_a2v, a2v_out
|
||||
|
||||
# video to audio cross attention
|
||||
if run_v2a:
|
||||
scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4]
|
||||
scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4]
|
||||
ax_scaled = (
|
||||
ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
|
||||
+ shift_ca_audio_hidden_states_v2a
|
||||
)
|
||||
vx_scaled = (
|
||||
vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
|
||||
+ shift_ca_video_hidden_states_v2a
|
||||
)
|
||||
ax += (
|
||||
self.video_to_audio_attn(
|
||||
ax_scaled,
|
||||
context=vx_scaled,
|
||||
pe=a_cross_pe,
|
||||
k_pe=v_cross_pe,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
* gate_out_v2a
|
||||
)
|
||||
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
|
||||
del scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a, scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a
|
||||
del gate_out_v2a
|
||||
del scale_ca_video_hidden_states_v2a,\
|
||||
shift_ca_video_hidden_states_v2a,\
|
||||
scale_ca_audio_hidden_states_v2a,\
|
||||
shift_ca_audio_hidden_states_v2a
|
||||
|
||||
v2a_out = self.video_to_audio_attn(ax_scaled, context=vx_scaled, pe=a_cross_pe, k_pe=v_cross_pe, transformer_options=transformer_options)
|
||||
del ax_scaled, vx_scaled
|
||||
|
||||
gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0]
|
||||
ax.addcmul_(v2a_out, gate_out_v2a)
|
||||
del gate_out_v2a, v2a_out
|
||||
|
||||
del vx_norm3, ax_norm3
|
||||
|
||||
# video feedforward
|
||||
if run_vx:
|
||||
vshift_mlp, vscale_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, 5))
|
||||
vshift_mlp, vscale_mlp, vgate_mlp = (
|
||||
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
|
||||
)
|
||||
|
||||
vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
|
||||
del vshift_mlp, vscale_mlp
|
||||
vx += self.ff(vx_scaled) * vgate_mlp
|
||||
del vshift_mlp, vscale_mlp, vgate_mlp
|
||||
|
||||
ff_out = self.ff(vx_scaled)
|
||||
del vx_scaled
|
||||
|
||||
vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0]
|
||||
vx.addcmul_(ff_out, vgate_mlp)
|
||||
del vgate_mlp, ff_out
|
||||
|
||||
# audio feedforward
|
||||
if run_ax:
|
||||
ashift_mlp, ascale_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, 5))
|
||||
ashift_mlp, ascale_mlp, agate_mlp = (
|
||||
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
|
||||
)
|
||||
|
||||
ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
|
||||
del ashift_mlp, ascale_mlp
|
||||
ax += self.audio_ff(ax_scaled) * agate_mlp
|
||||
|
||||
ff_out = self.audio_ff(ax_scaled)
|
||||
del ax_scaled
|
||||
del ashift_mlp, ascale_mlp, agate_mlp
|
||||
|
||||
agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0]
|
||||
ax.addcmul_(ff_out, agate_mlp)
|
||||
del agate_mlp, ff_out
|
||||
|
||||
return vx, ax
|
||||
|
||||
@ -534,20 +589,9 @@ class LTXAVModel(LTXVModel):
|
||||
audio_length = kwargs.get("audio_length", 0)
|
||||
# Separate audio and video latents
|
||||
vx, ax = self.separate_audio_and_video_latents(x, audio_length)
|
||||
|
||||
has_spatial_mask = False
|
||||
if denoise_mask is not None:
|
||||
# check if any frame has spatial variation (inpainting)
|
||||
for frame_idx in range(denoise_mask.shape[2]):
|
||||
frame_mask = denoise_mask[0, 0, frame_idx]
|
||||
if frame_mask.numel() > 0 and frame_mask.min() != frame_mask.max():
|
||||
has_spatial_mask = True
|
||||
break
|
||||
|
||||
[vx, v_pixel_coords, additional_args] = super()._process_input(
|
||||
vx, keyframe_idxs, denoise_mask, **kwargs
|
||||
)
|
||||
additional_args["has_spatial_mask"] = has_spatial_mask
|
||||
|
||||
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
||||
ax = self.audio_patchify_proj(ax)
|
||||
@ -574,9 +618,8 @@ class LTXAVModel(LTXVModel):
|
||||
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
|
||||
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
|
||||
orig_shape = kwargs.get("orig_shape")
|
||||
has_spatial_mask = kwargs.get("has_spatial_mask", None)
|
||||
v_patches_per_frame = None
|
||||
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
|
||||
if orig_shape is not None and len(orig_shape) == 5:
|
||||
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
|
||||
v_patches_per_frame = orig_shape[3] * orig_shape[4]
|
||||
|
||||
@ -619,11 +662,10 @@ class LTXAVModel(LTXVModel):
|
||||
)
|
||||
|
||||
# Compress cross-attention timesteps (only video side, audio is too small to benefit)
|
||||
# v_patches_per_frame is None for spatial masks, set for temporal masks or no mask
|
||||
cross_av_timestep_ss = [
|
||||
av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
|
||||
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
|
||||
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
|
||||
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed
|
||||
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed
|
||||
av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
|
||||
]
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from comfy.ldm.modules.diffusionmodules.model import vae_attention, torch_cat_if_needed
|
||||
from comfy.ldm.modules.diffusionmodules.model import vae_attention
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
@ -20,29 +20,22 @@ class CausalConv3d(ops.Conv3d):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._padding = 2 * self.padding[0]
|
||||
self.padding = (0, self.padding[1], self.padding[2])
|
||||
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
||||
self.padding[1], 2 * self.padding[0], 0)
|
||||
self.padding = (0, 0, 0)
|
||||
|
||||
def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
|
||||
if cache_list is not None:
|
||||
cache_x = cache_list[cache_idx]
|
||||
cache_list[cache_idx] = None
|
||||
|
||||
if cache_x is None and x.shape[2] == 1:
|
||||
#Fast path - the op will pad for use by truncating the weight
|
||||
#and save math on a pile of zeros.
|
||||
return super().forward(x, autopad="causal_zero")
|
||||
|
||||
if self._padding > 0:
|
||||
padding_needed = self._padding
|
||||
if cache_x is not None:
|
||||
cache_x = cache_x.to(x.device)
|
||||
padding_needed = max(0, padding_needed - cache_x.shape[2])
|
||||
padding_shape = list(x.shape)
|
||||
padding_shape[2] = padding_needed
|
||||
padding = torch.zeros(padding_shape, device=x.device, dtype=x.dtype)
|
||||
x = torch_cat_if_needed([padding, cache_x, x], dim=2)
|
||||
padding = list(self._padding)
|
||||
if cache_x is not None and self._padding[4] > 0:
|
||||
cache_x = cache_x.to(x.device)
|
||||
x = torch.cat([cache_x, x], dim=2)
|
||||
padding[4] -= cache_x.shape[2]
|
||||
del cache_x
|
||||
x = F.pad(x, padding)
|
||||
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
@ -260,7 +260,6 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
||||
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
||||
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
|
||||
key_map[k[:-len(".weight")]] = to #DiffSynth lora format
|
||||
for k in sdk:
|
||||
hidden_size = model.model_config.unet_config.get("hidden_size", 0)
|
||||
if k.endswith(".weight") and ".linear1." in k:
|
||||
|
||||
10
comfy/ops.py
10
comfy/ops.py
@ -203,9 +203,7 @@ class disable_weight_init:
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def _conv_forward(self, input, weight, bias, autopad=None, *args, **kwargs):
|
||||
if autopad == "causal_zero":
|
||||
weight = weight[:, :, -input.shape[2]:, :, :]
|
||||
def _conv_forward(self, input, weight, bias, *args, **kwargs):
|
||||
if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
|
||||
out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
|
||||
if bias is not None:
|
||||
@ -214,15 +212,15 @@ class disable_weight_init:
|
||||
else:
|
||||
return super()._conv_forward(input, weight, bias, *args, **kwargs)
|
||||
|
||||
def forward_comfy_cast_weights(self, input, autopad=None):
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._conv_forward(input, weight, bias, autopad=autopad)
|
||||
x = self._conv_forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or "autopad" in kwargs:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
@ -37,18 +37,12 @@ def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
|
||||
return noises
|
||||
|
||||
def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None):
|
||||
def fix_empty_latent_channels(model, latent_image):
|
||||
if latent_image.is_nested:
|
||||
return latent_image
|
||||
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
|
||||
if torch.count_nonzero(latent_image) == 0:
|
||||
if latent_format.latent_channels != latent_image.shape[1]:
|
||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
||||
if downscale_ratio_spacial is not None:
|
||||
if downscale_ratio_spacial != latent_format.spacial_downscale_ratio:
|
||||
ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio
|
||||
latent_image = comfy.utils.common_upscale(latent_image, round(latent_image.shape[-1] * ratio), round(latent_image.shape[-2] * ratio), "nearest-exact", crop="disabled")
|
||||
|
||||
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
|
||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
||||
if latent_format.latent_dimensions == 3 and latent_image.ndim == 4:
|
||||
latent_image = latent_image.unsqueeze(2)
|
||||
return latent_image
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
||||
from .video_types import VideoInput
|
||||
from .video_types import VideoInput, VideoOp, SliceOp
|
||||
|
||||
__all__ = [
|
||||
"ImageInput",
|
||||
"AudioInput",
|
||||
"VideoInput",
|
||||
"VideoOp",
|
||||
"SliceOp",
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
]
|
||||
|
||||
@ -1,11 +1,48 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from fractions import Fraction
|
||||
from typing import Optional, Union, IO
|
||||
import copy
|
||||
import io
|
||||
import av
|
||||
from .._util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
|
||||
class VideoOp(ABC):
|
||||
"""Base class for lazy video operations."""
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, components: VideoComponents) -> VideoComponents:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def compute_frame_count(self, input_frame_count: int) -> int:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SliceOp(VideoOp):
|
||||
"""Extract a range of frames from the video."""
|
||||
start_frame: int
|
||||
frame_count: int
|
||||
|
||||
def apply(self, components: VideoComponents) -> VideoComponents:
|
||||
total = components.images.shape[0]
|
||||
start = max(0, min(self.start_frame, total))
|
||||
end = min(start + self.frame_count, total)
|
||||
return VideoComponents(
|
||||
images=components.images[start:end],
|
||||
audio=components.audio,
|
||||
frame_rate=components.frame_rate,
|
||||
metadata=getattr(components, 'metadata', None),
|
||||
)
|
||||
|
||||
def compute_frame_count(self, input_frame_count: int) -> int:
|
||||
start = max(0, min(self.start_frame, input_frame_count))
|
||||
return min(self.frame_count, input_frame_count - start)
|
||||
|
||||
|
||||
class VideoInput(ABC):
|
||||
"""
|
||||
Abstract base class for video input types.
|
||||
@ -21,6 +58,12 @@ class VideoInput(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def sliced(self, start_frame: int, frame_count: int) -> "VideoInput":
|
||||
"""Return a copy of this video with a slice operation appended."""
|
||||
new = copy.copy(self)
|
||||
new._operations = getattr(self, '_operations', []) + [SliceOp(start_frame, frame_count)]
|
||||
return new
|
||||
|
||||
@abstractmethod
|
||||
def save_to(
|
||||
self,
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from .video_types import VideoFromFile, VideoFromComponents
|
||||
from .._input import SliceOp
|
||||
|
||||
__all__ = [
|
||||
# Implementations
|
||||
"VideoFromFile",
|
||||
"VideoFromComponents",
|
||||
"SliceOp",
|
||||
]
|
||||
|
||||
@ -3,7 +3,7 @@ from av.container import InputContainer
|
||||
from av.subtitles.stream import SubtitleStream
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from .._input import AudioInput, VideoInput
|
||||
from .._input import AudioInput, VideoInput, VideoOp
|
||||
import av
|
||||
import io
|
||||
import json
|
||||
@ -63,6 +63,8 @@ class VideoFromFile(VideoInput):
|
||||
containing the file contents.
|
||||
"""
|
||||
self.__file = file
|
||||
self._operations: list[VideoOp] = []
|
||||
self.__materialized: Optional[VideoFromComponents] = None
|
||||
|
||||
def get_stream_source(self) -> str | io.BytesIO:
|
||||
"""
|
||||
@ -161,6 +163,10 @@ class VideoFromFile(VideoInput):
|
||||
|
||||
if frame_count == 0:
|
||||
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
|
||||
|
||||
# Apply operations to get final frame count
|
||||
for op in self._operations:
|
||||
frame_count = op.compute_frame_count(frame_count)
|
||||
return frame_count
|
||||
|
||||
def get_frame_rate(self) -> Fraction:
|
||||
@ -239,10 +245,18 @@ class VideoFromFile(VideoInput):
|
||||
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
if self.__materialized is not None:
|
||||
return self.__materialized.get_components()
|
||||
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
return self.get_components_internal(container)
|
||||
components = self.get_components_internal(container)
|
||||
for op in self._operations:
|
||||
components = op.apply(components)
|
||||
self.__materialized = VideoFromComponents(components)
|
||||
self._operations = []
|
||||
return components
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def save_to(
|
||||
@ -317,14 +331,27 @@ class VideoFromComponents(VideoInput):
|
||||
|
||||
def __init__(self, components: VideoComponents):
|
||||
self.__components = components
|
||||
self._operations: list[VideoOp] = []
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
if self._operations:
|
||||
components = self.__components
|
||||
for op in self._operations:
|
||||
components = op.apply(components)
|
||||
self.__components = components
|
||||
self._operations = []
|
||||
return VideoComponents(
|
||||
images=self.__components.images,
|
||||
audio=self.__components.audio,
|
||||
frame_rate=self.__components.frame_rate
|
||||
)
|
||||
|
||||
def get_frame_count(self) -> int:
|
||||
count = int(self.__components.images.shape[0])
|
||||
for op in self._operations:
|
||||
count = op.compute_frame_count(count)
|
||||
return count
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
@ -332,6 +359,9 @@ class VideoFromComponents(VideoInput):
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
# Materialize ops before saving
|
||||
components = self.get_components()
|
||||
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||
@ -345,22 +375,22 @@ class VideoFromComponents(VideoInput):
|
||||
for key, value in metadata.items():
|
||||
output.metadata[key] = json.dumps(value)
|
||||
|
||||
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
||||
frame_rate = Fraction(round(components.frame_rate * 1000), 1000)
|
||||
# Create a video stream
|
||||
video_stream = output.add_stream('h264', rate=frame_rate)
|
||||
video_stream.width = self.__components.images.shape[2]
|
||||
video_stream.height = self.__components.images.shape[1]
|
||||
video_stream.width = components.images.shape[2]
|
||||
video_stream.height = components.images.shape[1]
|
||||
video_stream.pix_fmt = 'yuv420p'
|
||||
|
||||
# Create an audio stream
|
||||
audio_sample_rate = 1
|
||||
audio_stream: Optional[av.AudioStream] = None
|
||||
if self.__components.audio:
|
||||
audio_sample_rate = int(self.__components.audio['sample_rate'])
|
||||
if components.audio:
|
||||
audio_sample_rate = int(components.audio['sample_rate'])
|
||||
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
|
||||
|
||||
# Encode video
|
||||
for i, frame in enumerate(self.__components.images):
|
||||
for i, frame in enumerate(components.images):
|
||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
|
||||
@ -371,9 +401,9 @@ class VideoFromComponents(VideoInput):
|
||||
packet = video_stream.encode(None)
|
||||
output.mux(packet)
|
||||
|
||||
if audio_stream and self.__components.audio:
|
||||
waveform = self.__components.audio['waveform']
|
||||
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
|
||||
if audio_stream and components.audio:
|
||||
waveform = components.audio['waveform']
|
||||
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * components.images.shape[0])]
|
||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
|
||||
frame.sample_rate = audio_sample_rate
|
||||
frame.pts = 0
|
||||
|
||||
@ -28,7 +28,6 @@ class AlignYourStepsScheduler(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="AlignYourStepsScheduler",
|
||||
search_aliases=["AYS scheduler"],
|
||||
category="sampling/custom_sampling/schedulers",
|
||||
inputs=[
|
||||
io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]),
|
||||
|
||||
@ -71,7 +71,6 @@ class CLIPAttentionMultiply(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="CLIPAttentionMultiply",
|
||||
search_aliases=["clip attention scale", "text encoder attention"],
|
||||
category="_for_testing/attention_experiments",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
|
||||
@ -69,7 +69,6 @@ class VAEEncodeAudio(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VAEEncodeAudio",
|
||||
search_aliases=["audio to latent"],
|
||||
display_name="VAE Encode Audio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
@ -98,7 +97,6 @@ class VAEDecodeAudio(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VAEDecodeAudio",
|
||||
search_aliases=["latent to audio"],
|
||||
display_name="VAE Decode Audio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
@ -124,7 +122,6 @@ class SaveAudio(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveAudio",
|
||||
search_aliases=["export flac"],
|
||||
display_name="Save Audio (FLAC)",
|
||||
category="audio",
|
||||
inputs=[
|
||||
@ -149,7 +146,6 @@ class SaveAudioMP3(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveAudioMP3",
|
||||
search_aliases=["export mp3"],
|
||||
display_name="Save Audio (MP3)",
|
||||
category="audio",
|
||||
inputs=[
|
||||
@ -177,7 +173,6 @@ class SaveAudioOpus(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveAudioOpus",
|
||||
search_aliases=["export opus"],
|
||||
display_name="Save Audio (Opus)",
|
||||
category="audio",
|
||||
inputs=[
|
||||
@ -205,7 +200,6 @@ class PreviewAudio(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="PreviewAudio",
|
||||
search_aliases=["play audio"],
|
||||
display_name="Preview Audio",
|
||||
category="audio",
|
||||
inputs=[
|
||||
@ -265,7 +259,6 @@ class LoadAudio(IO.ComfyNode):
|
||||
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
|
||||
return IO.Schema(
|
||||
node_id="LoadAudio",
|
||||
search_aliases=["import audio", "open audio", "audio file"],
|
||||
display_name="Load Audio",
|
||||
category="audio",
|
||||
inputs=[
|
||||
@ -303,7 +296,6 @@ class RecordAudio(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RecordAudio",
|
||||
search_aliases=["microphone input", "audio capture", "voice input"],
|
||||
display_name="Record Audio",
|
||||
category="audio",
|
||||
inputs=[
|
||||
@ -328,7 +320,6 @@ class TrimAudioDuration(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TrimAudioDuration",
|
||||
search_aliases=["cut audio", "audio clip", "shorten audio"],
|
||||
display_name="Trim Audio Duration",
|
||||
description="Trim audio tensor into chosen time range.",
|
||||
category="audio",
|
||||
@ -381,7 +372,6 @@ class SplitAudioChannels(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SplitAudioChannels",
|
||||
search_aliases=["stereo to mono"],
|
||||
display_name="Split Audio Channels",
|
||||
description="Separates the audio into left and right channels.",
|
||||
category="audio",
|
||||
@ -482,7 +472,6 @@ class AudioConcat(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="AudioConcat",
|
||||
search_aliases=["join audio", "combine audio", "append audio"],
|
||||
display_name="Audio Concat",
|
||||
description="Concatenates the audio1 to audio2 in the specified direction.",
|
||||
category="audio",
|
||||
@ -530,7 +519,6 @@ class AudioMerge(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="AudioMerge",
|
||||
search_aliases=["mix audio", "overlay audio", "layer audio"],
|
||||
display_name="Audio Merge",
|
||||
description="Combine two audio tracks by overlaying their waveforms.",
|
||||
category="audio",
|
||||
@ -591,7 +579,6 @@ class AudioAdjustVolume(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="AudioAdjustVolume",
|
||||
search_aliases=["audio gain", "loudness", "audio level"],
|
||||
display_name="Audio Adjust Volume",
|
||||
category="audio",
|
||||
inputs=[
|
||||
@ -627,7 +614,6 @@ class EmptyAudio(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="EmptyAudio",
|
||||
search_aliases=["blank audio"],
|
||||
display_name="Empty Audio",
|
||||
category="audio",
|
||||
inputs=[
|
||||
|
||||
@ -10,7 +10,6 @@ class Canny(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Canny",
|
||||
search_aliases=["edge detection", "outline", "contour detection", "line art"],
|
||||
category="image/preprocessors",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
|
||||
@ -109,7 +109,6 @@ class PorterDuffImageComposite(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PorterDuffImageComposite",
|
||||
search_aliases=["alpha composite", "blend modes", "layer blend", "transparency blend"],
|
||||
display_name="Porter-Duff Image Composite",
|
||||
category="mask/compositing",
|
||||
inputs=[
|
||||
@ -166,7 +165,6 @@ class SplitImageWithAlpha(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SplitImageWithAlpha",
|
||||
search_aliases=["extract alpha", "separate transparency", "remove alpha"],
|
||||
display_name="Split Image with Alpha",
|
||||
category="mask/compositing",
|
||||
inputs=[
|
||||
@ -190,7 +188,6 @@ class JoinImageWithAlpha(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="JoinImageWithAlpha",
|
||||
search_aliases=["add transparency", "apply alpha", "composite alpha", "RGBA"],
|
||||
display_name="Join Image with Alpha",
|
||||
category="mask/compositing",
|
||||
inputs=[
|
||||
|
||||
@ -38,7 +38,6 @@ class ControlNetInpaintingAliMamaApply(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ControlNetInpaintingAliMamaApply",
|
||||
search_aliases=["masked controlnet"],
|
||||
category="conditioning/controlnet",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
|
||||
@ -297,7 +297,6 @@ class ExtendIntermediateSigmas(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ExtendIntermediateSigmas",
|
||||
search_aliases=["interpolate sigmas"],
|
||||
category="sampling/custom_sampling/sigmas",
|
||||
inputs=[
|
||||
io.Sigmas.Input("sigmas"),
|
||||
@ -741,7 +740,7 @@ class SamplerCustom(io.ComfyNode):
|
||||
latent = latent_image
|
||||
latent_image = latent["samples"]
|
||||
latent = latent.copy()
|
||||
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None))
|
||||
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
|
||||
latent["samples"] = latent_image
|
||||
|
||||
if not add_noise:
|
||||
@ -760,7 +759,6 @@ class SamplerCustom(io.ComfyNode):
|
||||
samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
|
||||
|
||||
out = latent.copy()
|
||||
out.pop("downscale_ratio_spacial", None)
|
||||
out["samples"] = samples
|
||||
if "x0" in x0_output:
|
||||
x0_out = model.model.process_latent_out(x0_output["x0"].cpu())
|
||||
@ -858,7 +856,6 @@ class DualCFGGuider(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DualCFGGuider",
|
||||
search_aliases=["dual prompt guidance"],
|
||||
category="sampling/custom_sampling/guiders",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
@ -886,7 +883,6 @@ class DisableNoise(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DisableNoise",
|
||||
search_aliases=["zero noise"],
|
||||
category="sampling/custom_sampling/noise",
|
||||
inputs=[],
|
||||
outputs=[io.Noise.Output()]
|
||||
@ -940,7 +936,7 @@ class SamplerCustomAdvanced(io.ComfyNode):
|
||||
latent = latent_image
|
||||
latent_image = latent["samples"]
|
||||
latent = latent.copy()
|
||||
latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image, latent.get("downscale_ratio_spacial", None))
|
||||
latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image)
|
||||
latent["samples"] = latent_image
|
||||
|
||||
noise_mask = None
|
||||
@ -955,7 +951,6 @@ class SamplerCustomAdvanced(io.ComfyNode):
|
||||
samples = samples.to(comfy.model_management.intermediate_device())
|
||||
|
||||
out = latent.copy()
|
||||
out.pop("downscale_ratio_spacial", None)
|
||||
out["samples"] = samples
|
||||
if "x0" in x0_output:
|
||||
x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
|
||||
@ -1024,7 +1019,6 @@ class ManualSigmas(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ManualSigmas",
|
||||
search_aliases=["custom noise schedule", "define sigmas"],
|
||||
category="_for_testing/custom_sampling",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
|
||||
@ -1223,11 +1223,11 @@ class ResolutionBucket(io.ComfyNode):
|
||||
|
||||
class MakeTrainingDataset(io.ComfyNode):
|
||||
"""Encode images with VAE and texts with CLIP to create a training dataset."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MakeTrainingDataset",
|
||||
search_aliases=["encode dataset"],
|
||||
display_name="Make Training Dataset",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
@ -1309,11 +1309,11 @@ class MakeTrainingDataset(io.ComfyNode):
|
||||
|
||||
class SaveTrainingDataset(io.ComfyNode):
|
||||
"""Save encoded training dataset (latents + conditioning) to disk."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveTrainingDataset",
|
||||
search_aliases=["export training data"],
|
||||
display_name="Save Training Dataset",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
@ -1410,11 +1410,11 @@ class SaveTrainingDataset(io.ComfyNode):
|
||||
|
||||
class LoadTrainingDataset(io.ComfyNode):
|
||||
"""Load encoded training dataset from disk."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadTrainingDataset",
|
||||
search_aliases=["import dataset", "training data"],
|
||||
display_name="Load Training Dataset",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
|
||||
@ -11,7 +11,6 @@ class DifferentialDiffusion(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DifferentialDiffusion",
|
||||
search_aliases=["inpaint gradient", "variable denoise strength"],
|
||||
display_name="Differential Diffusion",
|
||||
category="_for_testing",
|
||||
inputs=[
|
||||
|
||||
@ -58,7 +58,6 @@ class FreSca(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="FreSca",
|
||||
search_aliases=["frequency guidance"],
|
||||
display_name="FreSca",
|
||||
category="_for_testing",
|
||||
description="Applies frequency-dependent scaling to the guidance",
|
||||
|
||||
@ -38,7 +38,6 @@ class CLIPTextEncodeHiDream(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodeHiDream",
|
||||
search_aliases=["hidream prompt"],
|
||||
category="advanced/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
|
||||
@ -259,7 +259,6 @@ class SetClipHooks:
|
||||
return (clip,)
|
||||
|
||||
class ConditioningTimestepsRange:
|
||||
SEARCH_ALIASES = ["prompt scheduling", "timestep segments", "conditioning phases"]
|
||||
NodeId = 'ConditioningTimestepsRange'
|
||||
NodeName = 'Timesteps Range'
|
||||
@classmethod
|
||||
@ -469,7 +468,6 @@ class SetHookKeyframes:
|
||||
return (hooks,)
|
||||
|
||||
class CreateHookKeyframe:
|
||||
SEARCH_ALIASES = ["hook scheduling", "strength animation", "timed hook"]
|
||||
NodeId = 'CreateHookKeyframe'
|
||||
NodeName = 'Create Hook Keyframe'
|
||||
@classmethod
|
||||
@ -499,7 +497,6 @@ class CreateHookKeyframe:
|
||||
return (prev_hook_kf,)
|
||||
|
||||
class CreateHookKeyframesInterpolated:
|
||||
SEARCH_ALIASES = ["ease hook strength", "smooth hook transition", "interpolate keyframes"]
|
||||
NodeId = 'CreateHookKeyframesInterpolated'
|
||||
NodeName = 'Create Hook Keyframes Interp.'
|
||||
@classmethod
|
||||
@ -547,7 +544,6 @@ class CreateHookKeyframesInterpolated:
|
||||
return (prev_hook_kf,)
|
||||
|
||||
class CreateHookKeyframesFromFloats:
|
||||
SEARCH_ALIASES = ["batch keyframes", "strength list to keyframes"]
|
||||
NodeId = 'CreateHookKeyframesFromFloats'
|
||||
NodeName = 'Create Hook Keyframes From Floats'
|
||||
@classmethod
|
||||
@ -622,7 +618,6 @@ class SetModelHooksOnCond:
|
||||
# Combine Hooks
|
||||
#------------------------------------------
|
||||
class CombineHooks:
|
||||
SEARCH_ALIASES = ["merge hooks"]
|
||||
NodeId = 'CombineHooks2'
|
||||
NodeName = 'Combine Hooks [2]'
|
||||
@classmethod
|
||||
|
||||
@ -618,7 +618,6 @@ class SaveGLB(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveGLB",
|
||||
search_aliases=["export 3d model", "save mesh"],
|
||||
category="3d",
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
|
||||
@ -22,7 +22,6 @@ class ImageCrop(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageCrop",
|
||||
search_aliases=["trim"],
|
||||
display_name="Image Crop",
|
||||
category="image/transform",
|
||||
inputs=[
|
||||
@ -52,7 +51,6 @@ class RepeatImageBatch(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RepeatImageBatch",
|
||||
search_aliases=["duplicate image", "clone image"],
|
||||
category="image/batch",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -74,7 +72,6 @@ class ImageFromBatch(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageFromBatch",
|
||||
search_aliases=["select image", "pick from batch", "extract image"],
|
||||
category="image/batch",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -100,7 +97,6 @@ class ImageAddNoise(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageAddNoise",
|
||||
search_aliases=["film grain"],
|
||||
category="image",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -198,11 +194,11 @@ class SaveAnimatedPNG(IO.ComfyNode):
|
||||
|
||||
class ImageStitch(IO.ComfyNode):
|
||||
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageStitch",
|
||||
search_aliases=["combine images", "join images", "concatenate images", "side by side"],
|
||||
display_name="Image Stitch",
|
||||
description="Stitches image2 to image1 in the specified direction.\n"
|
||||
"If image2 is not provided, returns image1 unchanged.\n"
|
||||
@ -373,11 +369,11 @@ class ImageStitch(IO.ComfyNode):
|
||||
|
||||
|
||||
class ResizeAndPadImage(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ResizeAndPadImage",
|
||||
search_aliases=["fit to size"],
|
||||
category="image/transform",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -424,11 +420,11 @@ class ResizeAndPadImage(IO.ComfyNode):
|
||||
|
||||
|
||||
class SaveSVGNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveSVGNode",
|
||||
search_aliases=["export vector", "save vector graphics"],
|
||||
description="Save SVG files on disk.",
|
||||
category="image/save",
|
||||
inputs=[
|
||||
@ -496,11 +492,11 @@ class SaveSVGNode(IO.ComfyNode):
|
||||
|
||||
|
||||
class GetImageSize(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GetImageSize",
|
||||
search_aliases=["dimensions", "resolution", "image info"],
|
||||
display_name="Get Image Size",
|
||||
description="Returns width and height of the image, and passes it through unchanged.",
|
||||
category="image",
|
||||
@ -531,11 +527,11 @@ class GetImageSize(IO.ComfyNode):
|
||||
|
||||
|
||||
class ImageRotate(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageRotate",
|
||||
search_aliases=["turn", "flip orientation"],
|
||||
category="image/transform",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -561,11 +557,11 @@ class ImageRotate(IO.ComfyNode):
|
||||
|
||||
|
||||
class ImageFlip(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageFlip",
|
||||
search_aliases=["mirror", "reflect"],
|
||||
category="image/transform",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
|
||||
@ -104,7 +104,6 @@ class CLIPTextEncodeKandinsky5(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodeKandinsky5",
|
||||
search_aliases=["kandinsky prompt"],
|
||||
category="advanced/conditioning/kandinsky5",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
|
||||
@ -21,7 +21,6 @@ class LatentAdd(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentAdd",
|
||||
search_aliases=["combine latents", "sum latents"],
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples1"),
|
||||
@ -48,7 +47,6 @@ class LatentSubtract(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentSubtract",
|
||||
search_aliases=["difference latent", "remove features"],
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples1"),
|
||||
@ -75,7 +73,6 @@ class LatentMultiply(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentMultiply",
|
||||
search_aliases=["scale latent", "amplify latent", "latent gain"],
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples"),
|
||||
@ -99,7 +96,6 @@ class LatentInterpolate(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentInterpolate",
|
||||
search_aliases=["blend latent", "mix latent", "lerp latent", "transition"],
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples1"),
|
||||
@ -138,7 +134,6 @@ class LatentConcat(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentConcat",
|
||||
search_aliases=["join latents", "stitch latents"],
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples1"),
|
||||
@ -178,7 +173,6 @@ class LatentCut(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentCut",
|
||||
search_aliases=["crop latent", "slice latent", "extract region"],
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples"),
|
||||
@ -219,7 +213,6 @@ class LatentCutToBatch(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentCutToBatch",
|
||||
search_aliases=["slice to batch", "split latent", "tile latent"],
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples"),
|
||||
@ -261,7 +254,6 @@ class LatentBatch(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentBatch",
|
||||
search_aliases=["combine latents", "merge latents", "join latents"],
|
||||
category="latent/batch",
|
||||
is_deprecated=True,
|
||||
inputs=[
|
||||
@ -318,7 +310,6 @@ class LatentApplyOperation(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentApplyOperation",
|
||||
search_aliases=["transform latent"],
|
||||
category="latent/advanced/operations",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
@ -374,7 +365,6 @@ class LatentOperationTonemapReinhard(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentOperationTonemapReinhard",
|
||||
search_aliases=["hdr latent"],
|
||||
category="latent/advanced/operations",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
|
||||
@ -75,7 +75,6 @@ class Preview3D(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Preview3D",
|
||||
search_aliases=["view mesh", "3d viewer"],
|
||||
display_name="Preview 3D & Animation",
|
||||
category="3d",
|
||||
is_experimental=True,
|
||||
|
||||
@ -224,7 +224,6 @@ class ConvertStringToComboNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ConvertStringToComboNode",
|
||||
search_aliases=["string to dropdown", "text to combo"],
|
||||
display_name="Convert String to Combo",
|
||||
category="logic",
|
||||
inputs=[io.String.Input("string")],
|
||||
@ -240,7 +239,6 @@ class InvertBooleanNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="InvertBooleanNode",
|
||||
search_aliases=["not", "toggle", "negate", "flip boolean"],
|
||||
display_name="Invert Boolean",
|
||||
category="logic",
|
||||
inputs=[io.Boolean.Input("boolean")],
|
||||
|
||||
@ -78,7 +78,6 @@ class LoraSave(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoraSave",
|
||||
search_aliases=["export lora"],
|
||||
display_name="Extract and Save Lora",
|
||||
category="_for_testing",
|
||||
inputs=[
|
||||
|
||||
@ -79,7 +79,6 @@ class CLIPTextEncodeLumina2(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodeLumina2",
|
||||
search_aliases=["lumina prompt"],
|
||||
display_name="CLIP Text Encode for Lumina2",
|
||||
category="conditioning",
|
||||
description="Encodes a system prompt and a user prompt using a CLIP model into an embedding "
|
||||
|
||||
@ -50,7 +50,6 @@ class LatentCompositeMasked(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="LatentCompositeMasked",
|
||||
search_aliases=["overlay latent", "layer latent", "paste latent", "inpaint latent"],
|
||||
category="latent",
|
||||
inputs=[
|
||||
IO.Latent.Input("destination"),
|
||||
@ -79,7 +78,6 @@ class ImageCompositeMasked(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageCompositeMasked",
|
||||
search_aliases=["paste image", "overlay", "layer"],
|
||||
category="image",
|
||||
inputs=[
|
||||
IO.Image.Input("destination"),
|
||||
@ -107,7 +105,6 @@ class MaskToImage(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MaskToImage",
|
||||
search_aliases=["convert mask"],
|
||||
display_name="Convert Mask to Image",
|
||||
category="mask",
|
||||
inputs=[
|
||||
@ -129,7 +126,6 @@ class ImageToMask(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageToMask",
|
||||
search_aliases=["extract channel", "channel to mask"],
|
||||
display_name="Convert Image to Mask",
|
||||
category="mask",
|
||||
inputs=[
|
||||
@ -153,7 +149,6 @@ class ImageColorToMask(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageColorToMask",
|
||||
search_aliases=["color keying", "chroma key"],
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -199,7 +194,6 @@ class InvertMask(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="InvertMask",
|
||||
search_aliases=["reverse mask", "flip mask"],
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("mask"),
|
||||
@ -220,7 +214,6 @@ class CropMask(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="CropMask",
|
||||
search_aliases=["cut mask", "extract mask region", "mask slice"],
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("mask"),
|
||||
@ -246,7 +239,6 @@ class MaskComposite(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MaskComposite",
|
||||
search_aliases=["combine masks", "blend masks", "layer masks"],
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("destination"),
|
||||
@ -295,7 +287,6 @@ class FeatherMask(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="FeatherMask",
|
||||
search_aliases=["soft edge mask", "blur mask edges", "gradient mask edge"],
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("mask"),
|
||||
@ -342,7 +333,6 @@ class GrowMask(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GrowMask",
|
||||
search_aliases=["expand mask", "shrink mask"],
|
||||
display_name="Grow Mask",
|
||||
category="mask",
|
||||
inputs=[
|
||||
@ -380,7 +370,6 @@ class ThresholdMask(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ThresholdMask",
|
||||
search_aliases=["binary mask"],
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("mask"),
|
||||
@ -405,7 +394,6 @@ class MaskPreview(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MaskPreview",
|
||||
search_aliases=["show mask", "view mask", "inspect mask", "debug mask"],
|
||||
display_name="Preview Mask",
|
||||
category="mask",
|
||||
description="Saves the input images to your ComfyUI output directory.",
|
||||
|
||||
@ -299,7 +299,6 @@ class RescaleCFG:
|
||||
return (m, )
|
||||
|
||||
class ModelComputeDtype:
|
||||
SEARCH_ALIASES = ["model precision", "change dtype"]
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
|
||||
@ -91,7 +91,6 @@ class CLIPMergeSimple:
|
||||
|
||||
|
||||
class CLIPSubtract:
|
||||
SEARCH_ALIASES = ["clip difference", "text encoder subtract"]
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip1": ("CLIP",),
|
||||
@ -114,7 +113,6 @@ class CLIPSubtract:
|
||||
|
||||
|
||||
class CLIPAdd:
|
||||
SEARCH_ALIASES = ["combine clip"]
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip1": ("CLIP",),
|
||||
@ -227,7 +225,6 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
|
||||
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)
|
||||
|
||||
class CheckpointSave:
|
||||
SEARCH_ALIASES = ["save model", "export checkpoint", "merge save"]
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
|
||||
@ -340,7 +337,6 @@ class VAESave:
|
||||
return {}
|
||||
|
||||
class ModelSave:
|
||||
SEARCH_ALIASES = ["export model", "checkpoint save"]
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
|
||||
|
||||
@ -12,7 +12,6 @@ class Morphology(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Morphology",
|
||||
search_aliases=["erode", "dilate"],
|
||||
display_name="ImageMorphology",
|
||||
category="image/postprocessing",
|
||||
inputs=[
|
||||
@ -58,7 +57,6 @@ class ImageRGBToYUV(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ImageRGBToYUV",
|
||||
search_aliases=["color space conversion"],
|
||||
category="image/batch",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
@ -80,7 +78,6 @@ class ImageYUVToRGB(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ImageYUVToRGB",
|
||||
search_aliases=["color space conversion"],
|
||||
category="image/batch",
|
||||
inputs=[
|
||||
io.Image.Input("Y"),
|
||||
|
||||
@ -7,7 +7,6 @@ class CLIPTextEncodePixArtAlpha(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodePixArtAlpha",
|
||||
search_aliases=["pixart prompt"],
|
||||
category="advanced/conditioning",
|
||||
description="Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma.",
|
||||
inputs=[
|
||||
|
||||
@ -402,6 +402,7 @@ def scale_to_multiple_cover(input: torch.Tensor, multiple: int, scale_method: st
|
||||
return input[:, y0:y1, x0:x1]
|
||||
|
||||
class ResizeImageMaskNode(io.ComfyNode):
|
||||
|
||||
scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
|
||||
crop_methods = ["disabled", "center"]
|
||||
|
||||
@ -420,62 +421,46 @@ class ResizeImageMaskNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
template = io.MatchType.Template("input_type", [io.Image, io.Mask])
|
||||
crop_combo = io.Combo.Input(
|
||||
"crop",
|
||||
options=cls.crop_methods,
|
||||
default="center",
|
||||
tooltip="How to handle aspect ratio mismatch: 'disabled' stretches to fit, 'center' crops to maintain aspect ratio.",
|
||||
)
|
||||
crop_combo = io.Combo.Input("crop", options=cls.crop_methods, default="center")
|
||||
return io.Schema(
|
||||
node_id="ResizeImageMaskNode",
|
||||
display_name="Resize Image/Mask",
|
||||
description="Resize an image or mask using various scaling methods.",
|
||||
category="transform",
|
||||
search_aliases=["resize", "resize image", "resize mask", "scale", "scale image", "scale mask", "image resize", "change size", "dimensions", "shrink", "enlarge"],
|
||||
inputs=[
|
||||
io.MatchType.Input("input", template=template),
|
||||
io.DynamicCombo.Input(
|
||||
"resize_type",
|
||||
tooltip="Select how to resize: by exact dimensions, scale factor, matching another image, etc.",
|
||||
options=[
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [
|
||||
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target width in pixels. Set to 0 to auto-calculate from height while preserving aspect ratio."),
|
||||
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target height in pixels. Set to 0 to auto-calculate from width while preserving aspect ratio."),
|
||||
crop_combo,
|
||||
io.DynamicCombo.Input("resize_type", options=[
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_BY, [
|
||||
io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01),
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_BY, [
|
||||
io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01, tooltip="Scale factor (e.g., 2.0 doubles size, 0.5 halves size)."),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [
|
||||
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||
crop_combo,
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [
|
||||
io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="The longer edge will be resized to this value. Aspect ratio is preserved."),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [
|
||||
io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [
|
||||
io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="The shorter edge will be resized to this value. Aspect ratio is preserved."),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [
|
||||
io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [
|
||||
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target width in pixels. Height auto-adjusts to preserve aspect ratio."),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [
|
||||
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [
|
||||
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target height in pixels. Width auto-adjusts to preserve aspect ratio."),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [
|
||||
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [
|
||||
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01, tooltip="Target total megapixels (e.g., 1.0 ≈ 1024×1024). Aspect ratio is preserved."),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [
|
||||
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [
|
||||
io.MultiType.Input("match", [io.Image, io.Mask], tooltip="Resize input to match the dimensions of this reference image or mask."),
|
||||
crop_combo,
|
||||
io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [
|
||||
io.MultiType.Input("match", [io.Image, io.Mask]),
|
||||
crop_combo,
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [
|
||||
io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1, tooltip="Resize so width and height are divisible by this number. Useful for latent alignment (e.g., 8 or 64)."),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [
|
||||
io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1),
|
||||
]),
|
||||
],
|
||||
),
|
||||
io.Combo.Input(
|
||||
"scale_method",
|
||||
options=cls.scale_methods,
|
||||
default="area",
|
||||
tooltip="Interpolation algorithm. 'area' is best for downscaling, 'lanczos' for upscaling, 'nearest-exact' for pixel art.",
|
||||
),
|
||||
]),
|
||||
io.Combo.Input("scale_method", options=cls.scale_methods, default="area"),
|
||||
],
|
||||
outputs=[io.MatchType.Output(template=template, display_name="resized")]
|
||||
)
|
||||
@ -584,7 +569,6 @@ class BatchMasksNode(io.ComfyNode):
|
||||
autogrow_template = io.Autogrow.TemplatePrefix(io.Mask.Input("mask"), prefix="mask", min=2, max=50)
|
||||
return io.Schema(
|
||||
node_id="BatchMasksNode",
|
||||
search_aliases=["combine masks", "stack masks", "merge masks"],
|
||||
display_name="Batch Masks",
|
||||
category="mask",
|
||||
inputs=[
|
||||
@ -605,7 +589,6 @@ class BatchLatentsNode(io.ComfyNode):
|
||||
autogrow_template = io.Autogrow.TemplatePrefix(io.Latent.Input("latent"), prefix="latent", min=2, max=50)
|
||||
return io.Schema(
|
||||
node_id="BatchLatentsNode",
|
||||
search_aliases=["combine latents", "stack latents", "merge latents"],
|
||||
display_name="Batch Latents",
|
||||
category="latent",
|
||||
inputs=[
|
||||
@ -629,7 +612,6 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
|
||||
prefix="input", min=1, max=50)
|
||||
return io.Schema(
|
||||
node_id="BatchImagesMasksLatentsNode",
|
||||
search_aliases=["combine batch", "merge batch", "stack inputs"],
|
||||
display_name="Batch Images/Masks/Latents",
|
||||
category="util",
|
||||
inputs=[
|
||||
|
||||
@ -16,7 +16,7 @@ class PreviewAny():
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "utils"
|
||||
SEARCH_ALIASES = ["show output", "inspect", "debug", "print value", "show text"]
|
||||
SEARCH_ALIASES = ["preview", "show", "display", "view", "show text", "display text", "preview text", "show output", "inspect", "debug"]
|
||||
|
||||
def main(self, source=None):
|
||||
value = 'None'
|
||||
|
||||
@ -55,7 +55,7 @@ class EmptySD3LatentImage(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8})
|
||||
return io.NodeOutput({"samples":latent})
|
||||
|
||||
generate = execute # TODO: remove
|
||||
|
||||
@ -65,7 +65,6 @@ class CLIPTextEncodeSD3(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodeSD3",
|
||||
search_aliases=["sd3 prompt"],
|
||||
category="advanced/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
|
||||
@ -32,7 +32,6 @@ class StringSubstring(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringSubstring",
|
||||
search_aliases=["extract text", "text portion"],
|
||||
display_name="Substring",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
@ -55,7 +54,6 @@ class StringLength(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringLength",
|
||||
search_aliases=["character count", "text size"],
|
||||
display_name="Length",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
@ -76,7 +74,6 @@ class CaseConverter(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CaseConverter",
|
||||
search_aliases=["text case", "uppercase", "lowercase", "capitalize"],
|
||||
display_name="Case Converter",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
@ -109,7 +106,6 @@ class StringTrim(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringTrim",
|
||||
search_aliases=["clean whitespace", "remove whitespace"],
|
||||
display_name="Trim",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
@ -140,7 +136,6 @@ class StringReplace(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringReplace",
|
||||
search_aliases=["find and replace", "substitute", "swap text"],
|
||||
display_name="Replace",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
@ -163,7 +158,6 @@ class StringContains(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringContains",
|
||||
search_aliases=["text includes", "string includes"],
|
||||
display_name="Contains",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
@ -191,7 +185,6 @@ class StringCompare(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringCompare",
|
||||
search_aliases=["text match", "string equals", "starts with", "ends with"],
|
||||
display_name="Compare",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
@ -227,7 +220,6 @@ class RegexMatch(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RegexMatch",
|
||||
search_aliases=["pattern match", "text contains", "string match"],
|
||||
display_name="Regex Match",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
@ -268,7 +260,6 @@ class RegexExtract(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RegexExtract",
|
||||
search_aliases=["pattern extract", "text parser", "parse text"],
|
||||
display_name="Regex Extract",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
@ -343,7 +334,6 @@ class RegexReplace(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RegexReplace",
|
||||
search_aliases=["pattern replace", "find and replace", "substitution"],
|
||||
display_name="Regex Replace",
|
||||
category="utils/string",
|
||||
description="Find and replace text using regex patterns.",
|
||||
|
||||
@ -1101,7 +1101,6 @@ class SaveLoRA(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveLoRA",
|
||||
search_aliases=["export lora"],
|
||||
display_name="Save LoRA Weights",
|
||||
category="loaders",
|
||||
is_experimental=True,
|
||||
@ -1145,7 +1144,6 @@ class LossGraphNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LossGraphNode",
|
||||
search_aliases=["training chart", "training visualization", "plot loss"],
|
||||
display_name="Plot Loss Graph",
|
||||
category="training",
|
||||
is_experimental=True,
|
||||
|
||||
@ -16,7 +16,6 @@ class SaveWEBM(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveWEBM",
|
||||
search_aliases=["export webm"],
|
||||
category="image/video",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
@ -70,7 +69,6 @@ class SaveVideo(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveVideo",
|
||||
search_aliases=["export video"],
|
||||
display_name="Save Video",
|
||||
category="image/video",
|
||||
description="Saves the input images to your ComfyUI output directory.",
|
||||
@ -118,7 +116,6 @@ class CreateVideo(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CreateVideo",
|
||||
search_aliases=["images to video"],
|
||||
display_name="Create Video",
|
||||
category="image/video",
|
||||
description="Create a video from images.",
|
||||
@ -143,7 +140,6 @@ class GetVideoComponents(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="GetVideoComponents",
|
||||
search_aliases=["extract frames", "split video", "video to images", "demux"],
|
||||
display_name="Get Video Components",
|
||||
category="image/video",
|
||||
description="Extracts all components from a video: frames, audio, and framerate.",
|
||||
@ -163,6 +159,29 @@ class GetVideoComponents(io.ComfyNode):
|
||||
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
|
||||
|
||||
|
||||
class VideoSlice(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="VideoSlice",
|
||||
display_name="Video Slice",
|
||||
category="image/video",
|
||||
description="Extract a range of frames from a video.",
|
||||
inputs=[
|
||||
io.Video.Input("video", tooltip="The video to slice."),
|
||||
io.Int.Input("start_frame", default=0, min=0, tooltip="The frame index to start from (0-indexed)."),
|
||||
io.Int.Input("frame_count", default=1, min=1, tooltip="Number of frames to extract."),
|
||||
],
|
||||
outputs=[
|
||||
io.Video.Output(tooltip="The sliced video."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video: Input.Video, start_frame: int, frame_count: int) -> io.NodeOutput:
|
||||
return io.NodeOutput(video.sliced(start_frame, frame_count))
|
||||
|
||||
|
||||
class LoadVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -171,7 +190,6 @@ class LoadVideo(io.ComfyNode):
|
||||
files = folder_paths.filter_files_content_types(files, ["video"])
|
||||
return io.Schema(
|
||||
node_id="LoadVideo",
|
||||
search_aliases=["import video", "open video", "video file"],
|
||||
display_name="Load Video",
|
||||
category="image/video",
|
||||
inputs=[
|
||||
@ -211,6 +229,7 @@ class VideoExtension(ComfyExtension):
|
||||
SaveVideo,
|
||||
CreateVideo,
|
||||
GetVideoComponents,
|
||||
VideoSlice,
|
||||
LoadVideo,
|
||||
]
|
||||
|
||||
|
||||
@ -287,7 +287,6 @@ class WanVaceToVideo(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanVaceToVideo",
|
||||
search_aliases=["video conditioning", "video control"],
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
@ -706,7 +705,6 @@ class WanTrackToVideo(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanTrackToVideo",
|
||||
search_aliases=["motion tracking", "trajectory video", "point tracking", "keypoint animation"],
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
|
||||
@ -324,7 +324,6 @@ class GenerateTracks(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="GenerateTracks",
|
||||
search_aliases=["motion paths", "camera movement", "trajectory"],
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=832, min=16, max=4096, step=16),
|
||||
|
||||
@ -5,7 +5,6 @@ MAX_RESOLUTION = nodes.MAX_RESOLUTION
|
||||
|
||||
|
||||
class WebcamCapture(nodes.LoadImage):
|
||||
SEARCH_ALIASES = ["camera input", "live capture", "camera feed", "snapshot"]
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
|
||||
2
main.py
2
main.py
@ -326,7 +326,7 @@ def setup_database():
|
||||
if dependencies_available():
|
||||
init_db()
|
||||
if not args.disable_assets_autoscan:
|
||||
seed_assets(["models", "input", "output"], enable_logging=True)
|
||||
seed_assets(["models"], enable_logging=True)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
|
||||
|
||||
|
||||
46
nodes.py
46
nodes.py
@ -93,8 +93,6 @@ class ConditioningCombine:
|
||||
return (conditioning_1 + conditioning_2, )
|
||||
|
||||
class ConditioningAverage :
|
||||
SEARCH_ALIASES = ["blend prompts", "interpolate conditioning", "mix prompts", "style fusion", "weighted blend"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", ),
|
||||
@ -161,8 +159,6 @@ class ConditioningConcat:
|
||||
return (out, )
|
||||
|
||||
class ConditioningSetArea:
|
||||
SEARCH_ALIASES = ["regional prompt", "area prompt", "spatial conditioning", "localized prompt"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"conditioning": ("CONDITIONING", ),
|
||||
@ -221,8 +217,6 @@ class ConditioningSetAreaStrength:
|
||||
|
||||
|
||||
class ConditioningSetMask:
|
||||
SEARCH_ALIASES = ["masked prompt", "regional inpaint conditioning", "mask conditioning"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"conditioning": ("CONDITIONING", ),
|
||||
@ -248,8 +242,6 @@ class ConditioningSetMask:
|
||||
return (c, )
|
||||
|
||||
class ConditioningZeroOut:
|
||||
SEARCH_ALIASES = ["null conditioning", "clear conditioning"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"conditioning": ("CONDITIONING", )}}
|
||||
@ -475,8 +467,6 @@ class InpaintModelConditioning:
|
||||
|
||||
|
||||
class SaveLatent:
|
||||
SEARCH_ALIASES = ["export latent"]
|
||||
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
|
||||
@ -528,8 +518,6 @@ class SaveLatent:
|
||||
|
||||
|
||||
class LoadLatent:
|
||||
SEARCH_ALIASES = ["import latent", "open latent"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
@ -566,8 +554,6 @@ class LoadLatent:
|
||||
|
||||
|
||||
class CheckpointLoader:
|
||||
SEARCH_ALIASES = ["load model", "model loader"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ),
|
||||
@ -607,8 +593,6 @@ class CheckpointLoaderSimple:
|
||||
return out[:3]
|
||||
|
||||
class DiffusersLoader:
|
||||
SEARCH_ALIASES = ["load diffusers model"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
paths = []
|
||||
@ -1079,8 +1063,6 @@ class StyleModelLoader:
|
||||
|
||||
|
||||
class StyleModelApply:
|
||||
SEARCH_ALIASES = ["style transfer"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"conditioning": ("CONDITIONING", ),
|
||||
@ -1230,12 +1212,10 @@ class EmptyLatentImage:
|
||||
|
||||
def generate(self, width, height, batch_size=1):
|
||||
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
|
||||
return ({"samples": latent, "downscale_ratio_spacial": 8}, )
|
||||
return ({"samples":latent}, )
|
||||
|
||||
|
||||
class LatentFromBatch:
|
||||
SEARCH_ALIASES = ["select from batch", "pick latent", "batch subset"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
@ -1268,8 +1248,6 @@ class LatentFromBatch:
|
||||
return (s,)
|
||||
|
||||
class RepeatLatentBatch:
|
||||
SEARCH_ALIASES = ["duplicate latent", "clone latent"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
@ -1296,8 +1274,6 @@ class RepeatLatentBatch:
|
||||
return (s,)
|
||||
|
||||
class LatentUpscale:
|
||||
SEARCH_ALIASES = ["enlarge latent", "resize latent"]
|
||||
|
||||
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
|
||||
crop_methods = ["disabled", "center"]
|
||||
|
||||
@ -1332,8 +1308,6 @@ class LatentUpscale:
|
||||
return (s,)
|
||||
|
||||
class LatentUpscaleBy:
|
||||
SEARCH_ALIASES = ["enlarge latent", "resize latent", "scale latent"]
|
||||
|
||||
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
|
||||
|
||||
@classmethod
|
||||
@ -1377,8 +1351,6 @@ class LatentRotate:
|
||||
return (s,)
|
||||
|
||||
class LatentFlip:
|
||||
SEARCH_ALIASES = ["mirror latent"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
@ -1399,8 +1371,6 @@ class LatentFlip:
|
||||
return (s,)
|
||||
|
||||
class LatentComposite:
|
||||
SEARCH_ALIASES = ["overlay latent", "layer latent", "paste latent"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples_to": ("LATENT",),
|
||||
@ -1443,8 +1413,6 @@ class LatentComposite:
|
||||
return (samples_out,)
|
||||
|
||||
class LatentBlend:
|
||||
SEARCH_ALIASES = ["mix latents", "interpolate latents"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
@ -1486,8 +1454,6 @@ class LatentBlend:
|
||||
raise ValueError(f"Unsupported blend mode: {mode}")
|
||||
|
||||
class LatentCrop:
|
||||
SEARCH_ALIASES = ["trim latent", "cut latent"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
@ -1538,7 +1504,7 @@ class SetLatentNoiseMask:
|
||||
|
||||
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
|
||||
latent_image = latent["samples"]
|
||||
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None))
|
||||
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
|
||||
|
||||
if disable_noise:
|
||||
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
||||
@ -1556,7 +1522,6 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
||||
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
|
||||
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
out = latent.copy()
|
||||
out.pop("downscale_ratio_spacial", None)
|
||||
out["samples"] = samples
|
||||
return (out, )
|
||||
|
||||
@ -1774,8 +1739,6 @@ class LoadImage:
|
||||
return True
|
||||
|
||||
class LoadImageMask:
|
||||
SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]
|
||||
|
||||
_color_channels = ["alpha", "red", "green", "blue"]
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -1826,8 +1789,6 @@ class LoadImageMask:
|
||||
|
||||
|
||||
class LoadImageOutput(LoadImage):
|
||||
SEARCH_ALIASES = ["output image", "previous generation"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
@ -1901,7 +1862,6 @@ class ImageScaleBy:
|
||||
return (s,)
|
||||
|
||||
class ImageInvert:
|
||||
SEARCH_ALIASES = ["reverse colors"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -1917,7 +1877,6 @@ class ImageInvert:
|
||||
return (s,)
|
||||
|
||||
class ImageBatch:
|
||||
SEARCH_ALIASES = ["combine images", "merge images", "stack images"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -1963,7 +1922,6 @@ class EmptyImage:
|
||||
return (torch.cat((r, g, b), dim=-1), )
|
||||
|
||||
class ImagePadForOutpaint:
|
||||
SEARCH_ALIASES = ["extend canvas", "expand image"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
|
||||
@ -22,7 +22,6 @@ alembic
|
||||
SQLAlchemy
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.7
|
||||
blake3
|
||||
|
||||
#non essential dependencies:
|
||||
kornia>=0.7.1
|
||||
|
||||
@ -689,7 +689,7 @@ class PromptServer():
|
||||
@routes.get("/object_info")
|
||||
async def get_object_info(request):
|
||||
try:
|
||||
seed_assets(["models", "input", "output"])
|
||||
seed_assets(["models"])
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to seed assets: {e}")
|
||||
with folder_paths.cache_helper:
|
||||
|
||||
@ -1,177 +0,0 @@
|
||||
"""Tests for Assets API endpoints (app/assets/api/routes.py)
|
||||
|
||||
Tests cover:
|
||||
- Schema validation for query parameters and request bodies
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.assets.api import schemas_in, schemas_out
|
||||
|
||||
|
||||
class TestListAssetsQuery:
|
||||
"""Tests for ListAssetsQuery schema."""
|
||||
|
||||
def test_defaults(self):
|
||||
"""Test default values."""
|
||||
q = schemas_in.ListAssetsQuery()
|
||||
assert q.include_tags == []
|
||||
assert q.exclude_tags == []
|
||||
assert q.limit == 20
|
||||
assert q.offset == 0
|
||||
assert q.sort == "created_at"
|
||||
assert q.order == "desc"
|
||||
assert q.include_public == True
|
||||
|
||||
def test_include_public_false(self):
|
||||
"""Test include_public can be set to False."""
|
||||
q = schemas_in.ListAssetsQuery(include_public=False)
|
||||
assert q.include_public == False
|
||||
|
||||
def test_csv_tags_parsing(self):
|
||||
"""Test comma-separated tags are parsed correctly."""
|
||||
q = schemas_in.ListAssetsQuery.model_validate({"include_tags": "a,b,c"})
|
||||
assert q.include_tags == ["a", "b", "c"]
|
||||
|
||||
def test_metadata_filter_json_string(self):
|
||||
"""Test metadata_filter accepts JSON string."""
|
||||
q = schemas_in.ListAssetsQuery.model_validate({"metadata_filter": '{"key": "value"}'})
|
||||
assert q.metadata_filter == {"key": "value"}
|
||||
|
||||
|
||||
class TestTagsListQuery:
|
||||
"""Tests for TagsListQuery schema."""
|
||||
|
||||
def test_defaults(self):
|
||||
"""Test default values."""
|
||||
q = schemas_in.TagsListQuery()
|
||||
assert q.prefix is None
|
||||
assert q.limit == 100
|
||||
assert q.offset == 0
|
||||
assert q.order == "count_desc"
|
||||
assert q.include_zero == True
|
||||
assert q.include_public == True
|
||||
|
||||
def test_include_public_false(self):
|
||||
"""Test include_public can be set to False."""
|
||||
q = schemas_in.TagsListQuery(include_public=False)
|
||||
assert q.include_public == False
|
||||
|
||||
|
||||
class TestUpdateAssetBody:
|
||||
"""Tests for UpdateAssetBody schema."""
|
||||
|
||||
def test_requires_at_least_one_field(self):
|
||||
"""Test that at least one field is required."""
|
||||
with pytest.raises(ValidationError):
|
||||
schemas_in.UpdateAssetBody()
|
||||
|
||||
def test_name_only(self):
|
||||
"""Test updating name only."""
|
||||
body = schemas_in.UpdateAssetBody(name="new name")
|
||||
assert body.name == "new name"
|
||||
assert body.mime_type is None
|
||||
assert body.preview_id is None
|
||||
|
||||
def test_mime_type_only(self):
|
||||
"""Test updating mime_type only."""
|
||||
body = schemas_in.UpdateAssetBody(mime_type="image/png")
|
||||
assert body.mime_type == "image/png"
|
||||
|
||||
def test_preview_id_only(self):
|
||||
"""Test updating preview_id only."""
|
||||
body = schemas_in.UpdateAssetBody(preview_id="550e8400-e29b-41d4-a716-446655440000")
|
||||
assert body.preview_id == "550e8400-e29b-41d4-a716-446655440000"
|
||||
|
||||
def test_preview_id_invalid_uuid(self):
|
||||
"""Test invalid UUID for preview_id."""
|
||||
with pytest.raises(ValidationError):
|
||||
schemas_in.UpdateAssetBody(preview_id="not-a-uuid")
|
||||
|
||||
def test_all_fields(self):
|
||||
"""Test all fields together."""
|
||||
body = schemas_in.UpdateAssetBody(
|
||||
name="test",
|
||||
mime_type="application/json",
|
||||
preview_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
user_metadata={"key": "value"}
|
||||
)
|
||||
assert body.name == "test"
|
||||
assert body.mime_type == "application/json"
|
||||
|
||||
|
||||
class TestUploadAssetFromUrlBody:
|
||||
"""Tests for UploadAssetFromUrlBody schema (JSON URL upload)."""
|
||||
|
||||
def test_valid_url(self):
|
||||
"""Test valid HTTP URL."""
|
||||
body = schemas_in.UploadAssetFromUrlBody(
|
||||
url="https://example.com/model.safetensors",
|
||||
name="model.safetensors"
|
||||
)
|
||||
assert body.url == "https://example.com/model.safetensors"
|
||||
assert body.name == "model.safetensors"
|
||||
|
||||
def test_http_url(self):
|
||||
"""Test HTTP URL (not just HTTPS)."""
|
||||
body = schemas_in.UploadAssetFromUrlBody(
|
||||
url="http://example.com/file.bin",
|
||||
name="file.bin"
|
||||
)
|
||||
assert body.url == "http://example.com/file.bin"
|
||||
|
||||
def test_invalid_url_scheme(self):
|
||||
"""Test invalid URL scheme raises error."""
|
||||
with pytest.raises(ValidationError):
|
||||
schemas_in.UploadAssetFromUrlBody(
|
||||
url="ftp://example.com/file.bin",
|
||||
name="file.bin"
|
||||
)
|
||||
|
||||
def test_tags_normalized(self):
|
||||
"""Test tags are normalized to lowercase."""
|
||||
body = schemas_in.UploadAssetFromUrlBody(
|
||||
url="https://example.com/model.safetensors",
|
||||
name="model",
|
||||
tags=["Models", "LORAS"]
|
||||
)
|
||||
assert body.tags == ["models", "loras"]
|
||||
|
||||
|
||||
class TestTagsRefineQuery:
|
||||
"""Tests for TagsRefineQuery schema."""
|
||||
|
||||
def test_defaults(self):
|
||||
"""Test default values."""
|
||||
q = schemas_in.TagsRefineQuery()
|
||||
assert q.include_tags == []
|
||||
assert q.exclude_tags == []
|
||||
assert q.limit == 100
|
||||
assert q.include_public == True
|
||||
|
||||
def test_include_public_false(self):
|
||||
"""Test include_public can be set to False."""
|
||||
q = schemas_in.TagsRefineQuery(include_public=False)
|
||||
assert q.include_public == False
|
||||
|
||||
|
||||
class TestTagHistogramResponse:
|
||||
"""Tests for TagHistogramResponse schema."""
|
||||
|
||||
def test_empty_response(self):
|
||||
"""Test empty response."""
|
||||
resp = schemas_out.TagHistogramResponse()
|
||||
assert resp.tags == []
|
||||
|
||||
def test_with_entries(self):
|
||||
"""Test response with entries."""
|
||||
resp = schemas_out.TagHistogramResponse(
|
||||
tags=[
|
||||
schemas_out.TagHistogramEntry(name="models", count=10),
|
||||
schemas_out.TagHistogramEntry(name="loras", count=5),
|
||||
]
|
||||
)
|
||||
assert len(resp.tags) == 2
|
||||
assert resp.tags[0].name == "models"
|
||||
assert resp.tags[0].count == 10
|
||||
150
tests-unit/comfy_api_test/video_slice_test.py
Normal file
150
tests-unit/comfy_api_test/video_slice_test.py
Normal file
@ -0,0 +1,150 @@
|
||||
import pytest
|
||||
import torch
|
||||
import tempfile
|
||||
import os
|
||||
import av
|
||||
from fractions import Fraction
|
||||
from comfy_api.input_impl.video_types import (
|
||||
VideoFromFile,
|
||||
VideoFromComponents,
|
||||
SliceOp,
|
||||
)
|
||||
from comfy_api.util.video_types import VideoComponents
|
||||
|
||||
|
||||
def create_test_video(width=4, height=4, frames=10, fps=30):
|
||||
"""Helper to create a temporary video file."""
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
|
||||
with av.open(tmp.name, mode="w") as container:
|
||||
stream = container.add_stream("h264", rate=fps)
|
||||
stream.width = width
|
||||
stream.height = height
|
||||
stream.pix_fmt = "yuv420p"
|
||||
|
||||
for i in range(frames):
|
||||
frame_data = torch.ones(height, width, 3, dtype=torch.uint8) * (i * 25)
|
||||
frame = av.VideoFrame.from_ndarray(frame_data.numpy(), format="rgb24")
|
||||
frame = frame.reformat(format="yuv420p")
|
||||
packet = stream.encode(frame)
|
||||
container.mux(packet)
|
||||
|
||||
packet = stream.encode(None)
|
||||
container.mux(packet)
|
||||
|
||||
return tmp.name
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def video_file_10_frames():
|
||||
file_path = create_test_video(frames=10)
|
||||
yield file_path
|
||||
os.unlink(file_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def video_components_10_frames():
|
||||
images = torch.rand(10, 4, 4, 3)
|
||||
return VideoComponents(images=images, frame_rate=Fraction(30))
|
||||
|
||||
|
||||
class TestSliceOp:
|
||||
def test_apply_slices_correctly(self, video_components_10_frames):
|
||||
op = SliceOp(start_frame=2, frame_count=3)
|
||||
result = op.apply(video_components_10_frames)
|
||||
|
||||
assert result.images.shape[0] == 3
|
||||
assert torch.equal(result.images, video_components_10_frames.images[2:5])
|
||||
|
||||
def test_compute_frame_count(self):
|
||||
op = SliceOp(start_frame=2, frame_count=5)
|
||||
assert op.compute_frame_count(10) == 5
|
||||
|
||||
def test_compute_frame_count_clamps(self):
|
||||
op = SliceOp(start_frame=8, frame_count=5)
|
||||
assert op.compute_frame_count(10) == 2
|
||||
|
||||
|
||||
class TestVideoSliced:
|
||||
def test_sliced_returns_new_instance(self, video_components_10_frames):
|
||||
video = VideoFromComponents(video_components_10_frames)
|
||||
sliced = video.sliced(2, 3)
|
||||
|
||||
assert video is not sliced
|
||||
assert len(video._operations) == 0
|
||||
assert len(sliced._operations) == 1
|
||||
|
||||
def test_get_components_applies_operations(self, video_components_10_frames):
|
||||
video = VideoFromComponents(video_components_10_frames)
|
||||
sliced = video.sliced(2, 3)
|
||||
|
||||
components = sliced.get_components()
|
||||
|
||||
assert components.images.shape[0] == 3
|
||||
assert torch.equal(components.images, video_components_10_frames.images[2:5])
|
||||
|
||||
def test_get_frame_count(self, video_components_10_frames):
|
||||
video = VideoFromComponents(video_components_10_frames)
|
||||
sliced = video.sliced(2, 3)
|
||||
|
||||
assert sliced.get_frame_count() == 3
|
||||
|
||||
def test_get_duration(self, video_components_10_frames):
|
||||
video = VideoFromComponents(video_components_10_frames)
|
||||
sliced = video.sliced(0, 3)
|
||||
|
||||
assert sliced.get_duration() == pytest.approx(0.1)
|
||||
|
||||
def test_chained_slices_compose(self, video_components_10_frames):
|
||||
video = VideoFromComponents(video_components_10_frames)
|
||||
sliced = video.sliced(2, 6).sliced(1, 3)
|
||||
|
||||
components = sliced.get_components()
|
||||
|
||||
assert components.images.shape[0] == 3
|
||||
assert torch.equal(components.images, video_components_10_frames.images[3:6])
|
||||
|
||||
def test_operations_list_is_immutable(self, video_components_10_frames):
|
||||
video = VideoFromComponents(video_components_10_frames)
|
||||
sliced1 = video.sliced(0, 5)
|
||||
sliced2 = sliced1.sliced(1, 2)
|
||||
|
||||
assert len(video._operations) == 0
|
||||
assert len(sliced1._operations) == 1
|
||||
assert len(sliced2._operations) == 2
|
||||
|
||||
def test_from_file(self, video_file_10_frames):
|
||||
video = VideoFromFile(video_file_10_frames)
|
||||
sliced = video.sliced(2, 3)
|
||||
|
||||
components = sliced.get_components()
|
||||
|
||||
assert components.images.shape[0] == 3
|
||||
assert sliced.get_frame_count() == 3
|
||||
|
||||
def test_save_sliced_video(self, video_components_10_frames, tmp_path):
|
||||
video = VideoFromComponents(video_components_10_frames)
|
||||
sliced = video.sliced(2, 3)
|
||||
|
||||
output_path = str(tmp_path / "sliced_output.mp4")
|
||||
sliced.save_to(output_path)
|
||||
|
||||
saved_video = VideoFromFile(output_path)
|
||||
assert saved_video.get_frame_count() == 3
|
||||
|
||||
def test_materialization_clears_ops(self, video_components_10_frames):
|
||||
video = VideoFromComponents(video_components_10_frames)
|
||||
sliced = video.sliced(2, 3)
|
||||
|
||||
assert len(sliced._operations) == 1
|
||||
sliced.get_components()
|
||||
assert len(sliced._operations) == 0
|
||||
|
||||
def test_second_get_components_uses_cache(self, video_components_10_frames):
|
||||
video = VideoFromComponents(video_components_10_frames)
|
||||
sliced = video.sliced(2, 3)
|
||||
|
||||
first = sliced.get_components()
|
||||
second = sliced.get_components()
|
||||
|
||||
assert first.images.shape == second.images.shape
|
||||
assert torch.equal(first.images, second.images)
|
||||
Reference in New Issue
Block a user