Compare commits


7 Commits

SHA1 Message Date
7ee77ff038 Add name to LoraLoaderModelOnly. (#12078) 2026-01-25 21:01:55 -05:00
26c5bbb875 Move nodes from previous PR into their own file. (#12066) 2026-01-24 23:02:32 -05:00
a97c98068f [Weight-adapter/Trainer] Bypass forward mode in Weight adapter system (#11958)
* Add API of bypass forward module

* bypass implementation

* add bypass fwd into nodes list/trainer
2026-01-24 22:56:22 -05:00
635406e283 Only enable fp16 on z image models that actually support it. (#12065) 2026-01-24 22:32:28 -05:00
ed6002cb60 add support for kwargs inputs to allow arbitrary inputs from frontend (#12063)
used to output selected combo index

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-01-24 17:30:40 -08:00
bc72d7f8d1 [API Nodes] add TencentHunyuan3D nodes (#12026)
* feat(api-nodes): add TencentHunyuan3D nodes

* add "(Pro)" to display name

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-01-24 17:10:09 -08:00
aef4e13588 Make empty latent node work with other models. (#12062) 2026-01-24 19:23:20 -05:00
41 changed files with 2502 additions and 2430 deletions

.gitignore (vendored)

@@ -21,7 +21,6 @@ venv/
*.log
web_custom_versions/
.DS_Store
*:Zone.Identifier
openapi.yaml
filtered-openapi.yaml
uv.lock

PLAN.md

@@ -1,139 +0,0 @@
# Plan: Align Local Asset/Tag Endpoints with Cloud
## Endpoint Comparison
| Endpoint | Cloud (openapi.yaml) | Local (routes.py) |
|----------|---------------------|-------------------|
| `GET /api/assets` | ✅ + `include_public` param | ✅ |
| `POST /api/assets` | ✅ multipart + JSON URL upload | ✅ multipart only |
| `GET /api/assets/{id}` | ✅ | ✅ |
| `PUT /api/assets/{id}` | ✅ (`name`, `mime_type`, `preview_id`, `user_metadata`) | ✅ (`name`, `tags`, `user_metadata`) |
| `DELETE /api/assets/{id}` | ✅ | ✅ |
| `GET /api/assets/{id}/content` | ❌ | ✅ |
| `POST /api/assets/{id}/tags` | ✅ | ✅ |
| `DELETE /api/assets/{id}/tags` | ✅ | ✅ |
| `PUT /api/assets/{id}/preview` | ❌ | ✅ |
| `POST /api/assets/from-hash` | ✅ | ✅ |
| `HEAD /api/assets/hash/{hash}` | ✅ | ✅ |
| `GET /api/assets/remote-metadata` | ✅ | ❌ |
| `POST /api/assets/download` | ✅ (background download) | ❌ |
| `GET /api/assets/tags/refine` | ✅ (tag histogram) | ❌ |
| `GET /api/tags` | ✅ + `include_public` param | ✅ |
| `POST /api/assets/scan/seed` | ❌ | ✅ (local only) |
---
## Phase 1: Add Missing Cloud Endpoints to Local
### 1.1 `GET /api/assets/remote-metadata` *(deferred)*
Fetch metadata from remote URLs (CivitAI, HuggingFace) without downloading the file.
**Status:** Not supported yet. Add stub/placeholder that returns 501 Not Implemented.
**Parameters:**
- `url` (required): Download URL to retrieve metadata from
**Returns:** Asset metadata (name, size, hash if available, etc.)
### 1.2 `POST /api/assets/download` *(deferred)*
Initiate background download job for large files from HuggingFace or CivitAI.
**Status:** Not supported yet. Add stub/placeholder that returns 501 Not Implemented.
**Request body:**
- `source_url` (required): URL to download from
- `tags`: Optional tags for the asset
- `user_metadata`: Optional metadata
- `preview_id`: Optional preview asset ID
**Returns:**
- 200 if file already exists (returns asset immediately)
- 202 with `task_id` for background download tracking via `GET /api/tasks/{task_id}`
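The 200-vs-202 decision above can be sketched as a small pure helper (hypothetical names and task-id scheme; the actual handler in routes.py is deferred to a 501 stub for now):

```python
import uuid


def plan_download_response(asset_exists: bool, source_url: str) -> tuple[int, dict]:
    """Decide the response for POST /api/assets/download (sketch).

    - 200 if the content is already present locally (asset returned immediately)
    - 202 with a task_id the client can poll via GET /api/tasks/{task_id}
    """
    if asset_exists:
        return 200, {"status": "exists", "source_url": source_url}
    task_id = uuid.uuid4().hex  # hypothetical task-id scheme, for illustration only
    return 202, {"status": "queued", "task_id": task_id, "source_url": source_url}
```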
### 1.3 `GET /api/assets/tags/refine`
Get tag histogram for filtered assets (useful for search refinement UI).
**Parameters:**
- `include_tags`: Filter assets with ALL these tags
- `exclude_tags`: Exclude assets with ANY of these tags
- `name_contains`: Filter by name substring
- `metadata_filter`: JSON filter for metadata fields
- `limit`: Max tags to return (default 100)
- `include_public`: Include public/shared assets
**Returns:** List of tags with counts for matching assets
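The filter semantics described above can be sketched in pure Python (hypothetical in-memory data shape; the real endpoint filters in SQL):

```python
from collections import Counter


def refine_tags(assets: list[dict],
                include_tags: frozenset[str] = frozenset(),
                exclude_tags: frozenset[str] = frozenset(),
                limit: int = 100) -> list[tuple[str, int]]:
    """Count tag usage across assets matching the include/exclude filters."""
    counts: Counter[str] = Counter()
    for asset in assets:
        tags = set(asset.get("tags", []))
        if include_tags and not include_tags <= tags:  # must carry ALL include_tags
            continue
        if exclude_tags & tags:  # must carry NONE of exclude_tags
            continue
        counts.update(tags)
    return counts.most_common(limit)
```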
---
## Phase 2: Update Existing Endpoints for Parity
### 2.1 `GET /api/assets`
- Add `include_public` query parameter (boolean, default true)
### 2.2 `POST /api/assets`
- Add JSON body upload path for URL-based uploads:
```json
{
"url": "https://...",
"name": "model.safetensors",
"tags": ["models", "checkpoints"],
"user_metadata": {},
"preview_id": "uuid"
}
```
- Keep existing multipart upload support
### 2.3 `PUT /api/assets/{id}`
- Add `mime_type` field support
- Add `preview_id` field support
- Remove direct `tags` field (recommend using dedicated `POST/DELETE /api/assets/{id}/tags` endpoints instead)
### 2.4 `GET /api/tags`
- Add `include_public` query parameter (boolean, default true)
---
## Phase 3: Local-Only Endpoints
These endpoints exist locally but not in cloud.
### 3.1 `GET /api/assets/{id}/content`
Download asset file content. Cloud uses signed URLs instead. **Keep for local.**
### 3.2 `PUT /api/assets/{id}/preview`
**Remove this endpoint.** Merge functionality into `PUT /api/assets/{id}` by adding `preview_id` field support (aligns with cloud).
### 3.3 `POST /api/assets/scan/seed`
Filesystem seeding/scanning for local asset discovery. Not applicable to cloud. **Keep as local-only.**
---
## Phase 4: Testing
Add tests for all new and modified endpoints to ensure functionality matches cloud behavior.
### 4.1 New Endpoint Tests
- `GET /api/assets/remote-metadata`: test with valid/invalid URLs and various sources (CivitAI, HuggingFace)
- `POST /api/assets/download`: test background download initiation, existing-file detection, and task tracking
- `GET /api/assets/tags/refine`: test histogram generation with various filter combinations
### 4.2 Updated Endpoint Tests
- `GET /api/assets`: test `include_public` param filtering
- `POST /api/assets`: test the JSON URL upload path alongside the existing multipart tests
- `PUT /api/assets/{id}`: test `mime_type` and `preview_id` field updates
- `GET /api/tags`: test `include_public` param filtering
### 4.3 Removed Endpoint Tests
- Remove tests for `PUT /api/assets/{id}/preview`
- Add tests for `preview_id` in `PUT /api/assets/{id}` to cover the merged functionality
---
## Implementation Order
1. Phases 2.1 and 2.4: add `include_public` params (low effort, high compatibility)
2. Phase 2.3: update PUT endpoint fields and remove the preview endpoint
3. Phase 2.2: add JSON URL upload to POST
4. Phase 1.3: add the tags/refine endpoint
5. Phases 1.1 and 1.2: add stub endpoints returning 501 (deferred implementation)
6. Phase 4: add tests for each phase as implemented

routes.py

@@ -1,20 +1,14 @@
import logging
import uuid
import urllib.parse
import os
import contextlib
from aiohttp import web
from pydantic import ValidationError
import app.assets.manager as manager
import app.assets.scanner as scanner
from app import user_manager
from app.assets.api import schemas_in
from app.assets.helpers import get_query_dict
import folder_paths
ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None
@@ -34,18 +28,6 @@ def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
@ROUTES.head("/api/assets/hash/{hash}")
async def head_asset_by_hash(request: web.Request) -> web.Response:
hash_str = request.match_info.get("hash", "").strip().lower()
if not hash_str or ":" not in hash_str:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = hash_str.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
exists = manager.asset_exists(asset_hash=hash_str)
return web.Response(status=200 if exists else 404)
@ROUTES.get("/api/assets")
async def list_assets(request: web.Request) -> web.Response:
"""
@@ -67,29 +49,10 @@ async def list_assets(request: web.Request) -> web.Response:
sort=q.sort,
order=q.order,
owner_id=USER_MANAGER.get_request_user_id(request),
include_public=q.include_public,
)
return web.json_response(payload.model_dump(mode="json"))
@ROUTES.get("/api/assets/remote-metadata")
async def get_remote_asset_metadata(request: web.Request) -> web.Response:
"""
Fetch metadata from remote URLs (CivitAI, HuggingFace) without downloading.
Status: Not implemented yet.
"""
return _error_response(501, "NOT_IMPLEMENTED", "Remote metadata fetching is not yet supported.")
@ROUTES.post("/api/assets/download")
async def create_asset_download(request: web.Request) -> web.Response:
"""
Initiate background download job for large files from HuggingFace or CivitAI.
Status: Not implemented yet.
"""
return _error_response(501, "NOT_IMPLEMENTED", "Background asset download is not yet supported.")
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
async def get_asset(request: web.Request) -> web.Response:
"""
@@ -113,306 +76,6 @@ async def get_asset(request: web.Request) -> web.Response:
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
async def download_asset_content(request: web.Request) -> web.Response:
# question: do we need disposition? could we just stick with one of these?
disposition = request.query.get("disposition", "attachment").lower().strip()
if disposition not in {"inline", "attachment"}:
disposition = "attachment"
try:
abs_path, content_type, filename = manager.resolve_asset_content_for_download(
asset_info_id=str(uuid.UUID(request.match_info["id"])),
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
except NotImplementedError as nie:
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
except FileNotFoundError:
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
    quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
    cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename or "")}'
resp = web.FileResponse(abs_path)
resp.content_type = content_type
resp.headers["Content-Disposition"] = cd
return resp
@ROUTES.post("/api/assets/from-hash")
async def create_asset_from_hash(request: web.Request) -> web.Response:
try:
payload = await request.json()
body = schemas_in.CreateFromHashBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
result = manager.create_asset_from_hash(
hash_str=body.hash,
name=body.name,
tags=body.tags,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
)
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
return web.json_response(result.model_dump(mode="json"), status=201)
@ROUTES.post("/api/assets")
async def upload_asset(request: web.Request) -> web.Response:
"""Asset upload endpoint supporting multipart/form-data (file upload) or application/json (URL-based)."""
content_type = (request.content_type or "").lower()
if content_type.startswith("application/json"):
try:
payload = await request.json()
schemas_in.UploadAssetFromUrlBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
return _error_response(501, "NOT_IMPLEMENTED", "URL-based asset upload is not yet supported. Use multipart/form-data file upload.")
if not content_type.startswith("multipart/"):
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads or application/json for URL-based uploads.")
reader = await request.multipart()
file_present = False
file_client_name: str | None = None
tags_raw: list[str] = []
provided_name: str | None = None
user_metadata_raw: str | None = None
provided_hash: str | None = None
provided_hash_exists: bool | None = None
file_written = 0
tmp_path: str | None = None
while True:
field = await reader.next()
if field is None:
break
fname = getattr(field, "name", "") or ""
if fname == "hash":
try:
s = ((await field.text()) or "").strip().lower()
except Exception:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
if s:
if ":" not in s:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
provided_hash = f"{algo}:{digest}"
try:
provided_hash_exists = manager.asset_exists(asset_hash=provided_hash)
except Exception:
provided_hash_exists = None # do not fail the whole request here
elif fname == "file":
file_present = True
file_client_name = (field.filename or "").strip()
if provided_hash and provided_hash_exists is True:
# If client supplied a hash that we know exists, drain but do not write to disk
try:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
file_written += len(chunk)
except Exception:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
continue # Do not create temp file; we will create AssetInfo from the existing content
# Otherwise, store to temp for hashing/ingest
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
os.makedirs(unique_dir, exist_ok=True)
tmp_path = os.path.join(unique_dir, ".upload.part")
try:
with open(tmp_path, "wb") as f:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
f.write(chunk)
file_written += len(chunk)
except Exception:
try:
if os.path.exists(tmp_path or ""):
os.remove(tmp_path)
finally:
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
elif fname == "tags":
tags_raw.append((await field.text()) or "")
elif fname == "name":
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
# If client did not send file, and we are not doing a from-hash fast path -> error
if not file_present and not (provided_hash and provided_hash_exists):
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
# Empty upload is only acceptable if we are fast-pathing from existing hash
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
try:
spec = schemas_in.UploadAssetSpec.model_validate({
"tags": tags_raw,
"name": provided_name,
"user_metadata": user_metadata_raw,
"hash": provided_hash,
})
except ValidationError as ve:
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _validation_error_response("INVALID_BODY", ve)
# Validate models category against configured folders (consistent with previous behavior)
if spec.tags and spec.tags[0] == "models":
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
return _error_response(
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
)
owner_id = USER_MANAGER.get_request_user_id(request)
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
if spec.hash and provided_hash_exists is True:
try:
result = manager.create_asset_from_hash(
hash_str=spec.hash,
name=spec.name or (spec.hash.split(":", 1)[1]),
tags=spec.tags,
user_metadata=spec.user_metadata or {},
owner_id=owner_id,
)
except Exception:
logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
# Drain temp if we accidentally saved (e.g., hash field came after file)
if tmp_path and os.path.exists(tmp_path):
with contextlib.suppress(Exception):
os.remove(tmp_path)
status = 200 if (not result.created_new) else 201
return web.json_response(result.model_dump(mode="json"), status=status)
# Otherwise, we must have a temp file path to ingest
if not tmp_path or not os.path.exists(tmp_path):
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
try:
created = manager.upload_asset_from_temp_path(
spec,
temp_path=tmp_path,
client_filename=file_client_name,
owner_id=owner_id,
expected_asset_hash=spec.hash,
)
status = 201 if created.created_new else 200
return web.json_response(created.model_dump(mode="json"), status=status)
except ValueError as e:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
msg = str(e)
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
return _error_response(
400,
"HASH_MISMATCH",
"Uploaded file hash does not match provided hash.",
)
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
except Exception:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
return _error_response(500, "INTERNAL", "Unexpected server error.")
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
async def update_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = manager.update_asset(
asset_info_id=asset_info_id,
name=body.name,
mime_type=body.mime_type,
preview_id=body.preview_id,
user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
logging.exception(
"update_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
async def delete_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
delete_content = request.query.get("delete_content")
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
try:
deleted = manager.delete_asset_reference(
asset_info_id=asset_info_id,
owner_id=USER_MANAGER.get_request_user_id(request),
delete_content_if_orphan=delete_content,
)
except Exception:
logging.exception(
"delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
if not deleted:
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
return web.Response(status=204)
@ROUTES.get("/api/tags")
async def get_tags(request: web.Request) -> web.Response:
"""
@@ -435,109 +98,5 @@ async def get_tags(request: web.Request) -> web.Response:
order=query.order,
include_zero=query.include_zero,
owner_id=USER_MANAGER.get_request_user_id(request),
include_public=query.include_public,
)
return web.json_response(result.model_dump(mode="json"))
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def add_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
payload = await request.json()
data = schemas_in.TagsAdd.model_validate(payload)
except ValidationError as ve:
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = manager.add_tags_to_asset(
asset_info_id=asset_info_id,
tags=data.tags,
origin="manual",
owner_id=USER_MANAGER.get_request_user_id(request),
)
except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
logging.exception(
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def delete_asset_tags(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
payload = await request.json()
data = schemas_in.TagsRemove.model_validate(payload)
except ValidationError as ve:
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = manager.remove_tags_from_asset(
asset_info_id=asset_info_id,
tags=data.tags,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
logging.exception(
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.get("/api/assets/tags/refine")
async def get_asset_tag_histogram(request: web.Request) -> web.Response:
"""
GET request to get a tag histogram for filtered assets.
"""
query_dict = get_query_dict(request)
try:
q = schemas_in.TagsRefineQuery.model_validate(query_dict)
except ValidationError as ve:
return _validation_error_response("INVALID_QUERY", ve)
payload = manager.get_tag_histogram(
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
include_public=q.include_public,
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(payload.model_dump(mode="json"))
@ROUTES.post("/api/assets/scan/seed")
async def seed_assets(request: web.Request) -> web.Response:
try:
payload = await request.json()
except Exception:
payload = {}
try:
body = schemas_in.ScheduleAssetScanBody.model_validate(payload)
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
try:
scanner.seed_assets(body.roots)
except Exception:
logging.exception("seed_assets failed for roots=%s", body.roots)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response({"synced": True, "roots": body.roots}, status=200)

schemas_in.py

@@ -8,10 +8,8 @@ from pydantic import (
Field,
conint,
field_validator,
model_validator,
)
from app.assets.helpers import RootType
class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
@@ -27,8 +25,6 @@ class ListAssetsQuery(BaseModel):
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
order: Literal["asc", "desc"] = "desc"
include_public: bool = True
@field_validator("include_tags", "exclude_tags", mode="before")
@classmethod
def _split_csv_tags(cls, v):
@@ -61,73 +57,6 @@ class ListAssetsQuery(BaseModel):
return None
class UpdateAssetBody(BaseModel):
name: str | None = None
mime_type: str | None = None
preview_id: str | None = None
user_metadata: dict[str, Any] | None = None
@field_validator("preview_id", mode="before")
@classmethod
def _norm_uuid(cls, v):
if v is None:
return None
s = str(v).strip()
if not s:
return None
try:
uuid.UUID(s)
except Exception:
raise ValueError("preview_id must be a UUID")
return s
@model_validator(mode="after")
def _at_least_one(self):
if self.name is None and self.mime_type is None and self.preview_id is None and self.user_metadata is None:
raise ValueError("Provide at least one of: name, mime_type, preview_id, user_metadata.")
return self
class CreateFromHashBody(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
hash: str
name: str
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
@field_validator("hash")
@classmethod
def _require_blake3(cls, v):
s = (v or "").strip().lower()
if ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return s
@field_validator("tags", mode="before")
@classmethod
def _tags_norm(cls, v):
if v is None:
return []
if isinstance(v, list):
out = [str(t).strip().lower() for t in v if str(t).strip()]
seen = set()
dedup = []
for t in out:
if t not in seen:
seen.add(t)
dedup.append(t)
return dedup
if isinstance(v, str):
return [t.strip().lower() for t in v.split(",") if t.strip()]
return []
class TagsListQuery(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
@@ -136,7 +65,6 @@ class TagsListQuery(BaseModel):
offset: int = Field(0, ge=0, le=10_000_000)
order: Literal["count_desc", "name_asc"] = "count_desc"
include_zero: bool = True
include_public: bool = True
@field_validator("prefix")
@classmethod
@@ -147,188 +75,10 @@ class TagsListQuery(BaseModel):
return v.lower() or None
class TagsAdd(BaseModel):
model_config = ConfigDict(extra="ignore")
tags: list[str] = Field(..., min_length=1)
@field_validator("tags")
@classmethod
def normalize_tags(cls, v: list[str]) -> list[str]:
out = []
for t in v:
if not isinstance(t, str):
raise TypeError("tags must be strings")
tnorm = t.strip().lower()
if tnorm:
out.append(tnorm)
seen = set()
deduplicated = []
for x in out:
if x not in seen:
seen.add(x)
deduplicated.append(x)
return deduplicated
class TagsRemove(TagsAdd):
pass
class UploadAssetSpec(BaseModel):
"""Upload Asset operation.
- tags: ordered; first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
- name: display name
- user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
and the original extension is preserved when available.
"""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(..., min_length=1)
name: str | None = Field(default=None, max_length=512, description="Display Name")
user_metadata: dict[str, Any] = Field(default_factory=dict)
hash: str | None = Field(default=None)
@field_validator("hash", mode="before")
@classmethod
def _parse_hash(cls, v):
if v is None:
return None
s = str(v).strip().lower()
if not s:
return None
if ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return f"{algo}:{digest}"
@field_validator("tags", mode="before")
@classmethod
def _parse_tags(cls, v):
"""
Accepts a list of strings (possibly multiple form fields),
where each string can be:
- JSON array (e.g., '["models","loras","foo"]')
- comma-separated ('models, loras, foo')
- single token ('models')
Returns a normalized, deduplicated, ordered list.
"""
items: list[str] = []
if v is None:
return []
if isinstance(v, str):
v = [v]
if isinstance(v, list):
for item in v:
if item is None:
continue
s = str(item).strip()
if not s:
continue
if s.startswith("["):
try:
arr = json.loads(s)
if isinstance(arr, list):
items.extend(str(x) for x in arr)
continue
except Exception:
pass # fallback to CSV parse below
items.extend([p for p in s.split(",") if p.strip()])
else:
return []
# normalize + dedupe
norm = []
seen = set()
for t in items:
tnorm = str(t).strip().lower()
if tnorm and tnorm not in seen:
seen.add(tnorm)
norm.append(tnorm)
return norm
@field_validator("user_metadata", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v or {}
if isinstance(v, str):
s = v.strip()
if not s:
return {}
try:
parsed = json.loads(s)
except Exception as e:
raise ValueError(f"user_metadata must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("user_metadata must be a JSON object")
return parsed
return {}
@model_validator(mode="after")
def _validate_order(self):
if not self.tags:
raise ValueError("tags must be provided and non-empty")
root = self.tags[0]
if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError("models uploads require a category tag as the second tag")
return self
class UploadAssetFromUrlBody(BaseModel):
"""JSON body for URL-based asset upload."""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
url: str = Field(..., description="HTTP/HTTPS URL to download the asset from")
name: str = Field(..., max_length=512, description="Display name for the asset")
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
class SetPreviewBody(BaseModel):
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
preview_id: str | None = None
@field_validator("url")
@classmethod
def _validate_url(cls, v):
s = (v or "").strip()
if not s:
raise ValueError("url must not be empty")
if not (s.startswith("http://") or s.startswith("https://")):
raise ValueError("url must start with http:// or https://")
return s
@field_validator("tags", mode="before")
@classmethod
def _parse_tags(cls, v):
if v is None:
return []
if isinstance(v, list):
out = [str(t).strip().lower() for t in v if str(t).strip()]
seen = set()
dedup = []
for t in out:
if t not in seen:
seen.add(t)
dedup.append(t)
return dedup
return []
@field_validator("user_metadata", mode="before")
@classmethod
def _parse_metadata(cls, v):
if v is None or isinstance(v, dict):
return v or {}
return {}
@field_validator("preview_id", mode="before")
@classmethod
def _norm_uuid(cls, v):
@@ -342,49 +92,3 @@ class UploadAssetFromUrlBody(BaseModel):
except Exception:
raise ValueError("preview_id must be a UUID")
return s
class ScheduleAssetScanBody(BaseModel):
roots: list[RootType] = Field(..., min_length=1)
class TagsRefineQuery(BaseModel):
"""Query parameters for tag histogram/refinement endpoint."""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
name_contains: str | None = None
metadata_filter: dict[str, Any] | None = None
limit: conint(ge=1, le=1000) = 100
include_public: bool = True
@field_validator("include_tags", "exclude_tags", mode="before")
@classmethod
def _split_csv_tags(cls, v):
if v is None:
return []
if isinstance(v, str):
return [t.strip() for t in v.split(",") if t.strip()]
if isinstance(v, list):
out: list[str] = []
for item in v:
if isinstance(item, str):
out.extend([t.strip() for t in item.split(",") if t.strip()])
return out
return v
@field_validator("metadata_filter", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v
if isinstance(v, str) and v.strip():
try:
parsed = json.loads(v)
except Exception as e:
raise ValueError(f"metadata_filter must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("metadata_filter must be a JSON object")
return parsed
return None


@@ -29,21 +29,6 @@ class AssetsList(BaseModel):
has_more: bool
class AssetUpdated(BaseModel):
id: str
name: str
asset_hash: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _ser_updated(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetDetail(BaseModel):
id: str
name: str
@@ -63,10 +48,6 @@ class AssetDetail(BaseModel):
return v.isoformat() if v else None
class AssetCreated(AssetDetail):
created_new: bool
class TagUsage(BaseModel):
name: str
count: int
@@ -77,26 +58,3 @@ class TagsList(BaseModel):
tags: list[TagUsage] = Field(default_factory=list)
total: int
has_more: bool
class TagsAdd(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True)
added: list[str] = Field(default_factory=list)
already_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
class TagsRemove(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True)
removed: list[str] = Field(default_factory=list)
not_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
class TagHistogramEntry(BaseModel):
name: str
count: int
class TagHistogramResponse(BaseModel):
tags: list[TagHistogramEntry] = Field(default_factory=list)


@@ -1,17 +1,9 @@
import os
import logging
import sqlalchemy as sa
from collections import defaultdict
from datetime import datetime
from typing import Iterable, Any
from sqlalchemy import select, delete, exists, func
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy import select, exists, func
from sqlalchemy.orm import Session, contains_eager, noload
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import (
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
)
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import escape_like_prefix, normalize_tags
from typing import Sequence
@@ -23,22 +15,6 @@ def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
return AssetInfo.owner_id.in_(["", owner_id])
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
"""
Return the best on-disk path among cache states:
1) Prefer a path that exists with needs_verify == False (already verified).
2) Otherwise, pick the first path that exists.
3) Otherwise return empty string.
"""
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
if not alive:
return ""
for s in alive:
if not getattr(s, "needs_verify", False):
return s.file_path
return alive[0].file_path
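The selection order in `pick_best_live_path` can be exercised without touching the filesystem by injecting the existence check; a minimal sketch (the `exists` predicate stands in for `os.path.isfile`, and `State` is a stand-in for `AssetCacheState`):

```python
from dataclasses import dataclass

@dataclass
class State:
    file_path: str
    needs_verify: bool = False

def pick_best(states, exists):
    """Prefer an existing, already-verified path; else the first existing path; else ""."""
    alive = [s for s in states if s.file_path and exists(s.file_path)]
    if not alive:
        return ""
    for s in alive:
        if not s.needs_verify:
            return s.file_path
    return alive[0].file_path
```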
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
@@ -66,7 +42,6 @@ def apply_tag_filters(
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
@@ -119,11 +94,7 @@ def apply_metadata_filter(
return stmt
def asset_exists_by_hash(
session: Session,
*,
asset_hash: str,
) -> bool:
def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
@@ -134,39 +105,9 @@ def asset_exists_by_hash(
).first()
return row is not None
def asset_info_exists_for_asset_id(
session: Session,
*,
asset_id: str,
) -> bool:
q = (
select(sa.literal(True))
.select_from(AssetInfo)
.where(AssetInfo.asset_id == asset_id)
.limit(1)
)
return (session.execute(q)).first() is not None
def get_asset_by_hash(
session: Session,
*,
asset_hash: str,
) -> Asset | None:
return (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
def get_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
) -> AssetInfo | None:
def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None:
return session.get(AssetInfo, asset_info_id)
def list_asset_infos_page(
session: Session,
owner_id: str = "",
@@ -236,7 +177,6 @@ def list_asset_infos_page(
return infos, tag_map, total
def fetch_asset_info_asset_and_tags(
session: Session,
asset_info_id: str,
@@ -268,489 +208,6 @@ def fetch_asset_info_asset_and_tags(
tags.append(tag_name)
return first_info, first_asset, tags
def fetch_asset_info_and_asset(
session: Session,
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset] | None:
stmt = (
select(AssetInfo, Asset)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.limit(1)
.options(noload(AssetInfo.tags))
)
row = session.execute(stmt)
pair = row.first()
if not pair:
return None
return pair[0], pair[1]
def list_cache_states_by_asset_id(
session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
return (
session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
)
).scalars().all()
def touch_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
ts: datetime | None = None,
only_if_newer: bool = True,
) -> None:
ts = ts or utcnow()
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
if only_if_newer:
stmt = stmt.where(
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
)
session.execute(stmt.values(last_access_time=ts))
def create_asset_info_for_existing_asset(
session: Session,
*,
asset_hash: str,
name: str,
user_metadata: dict | None = None,
tags: Sequence[str] | None = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> AssetInfo:
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
now = utcnow()
asset = get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
raise ValueError(f"Unknown asset hash {asset_hash}")
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
preview_id=None,
created_at=now,
updated_at=now,
last_access_time=now,
)
try:
with session.begin_nested():
session.add(info)
session.flush()
except IntegrityError:
existing = (
session.execute(
select(AssetInfo)
.options(noload(AssetInfo.tags))
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == name,
AssetInfo.owner_id == owner_id,
)
.limit(1)
)
).unique().scalars().first()
if not existing:
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
return existing
    # Keep user_metadata["filename"] in sync with the asset's on-disk location.
new_meta = dict(user_metadata or {})
computed_filename = None
try:
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
replace_asset_info_metadata_projection(
session,
asset_info_id=info.id,
user_metadata=new_meta,
)
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=info.id,
tags=tags,
origin=tag_origin,
)
return info
def set_asset_info_tags(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> dict:
desired = normalize_tags(tags)
current = set(
tag_name for (tag_name,) in (
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
ensure_tags_exist(session, to_add, tag_type="user")
session.add_all([
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
for t in to_add
])
session.flush()
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
)
session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
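`set_asset_info_tags` reduces to a set difference between the desired and current tag links; a minimal sketch of that reconciliation (hypothetical helper name):

```python
def diff_tags(current, desired):
    """Return (to_add, to_remove) for replacing `current` tag links with `desired`."""
    current = set(current)
    want = list(dict.fromkeys(desired))  # de-duplicate, keep order
    to_add = [t for t in want if t not in current]
    to_remove = [t for t in current if t not in set(want)]
    return to_add, to_remove
```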
def replace_asset_info_metadata_projection(
session: Session,
*,
asset_info_id: str,
user_metadata: dict | None = None,
) -> None:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {}
info.updated_at = utcnow()
session.flush()
session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
session.flush()
if not user_metadata:
return
rows: list[AssetInfoMeta] = []
for k, v in user_metadata.items():
for r in project_kv(k, v):
rows.append(
AssetInfoMeta(
asset_info_id=asset_info_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
session.flush()
def ingest_fs_asset(
session: Session,
*,
asset_hash: str,
abs_path: str,
size_bytes: int,
mtime_ns: int,
mime_type: str | None = None,
info_name: str | None = None,
owner_id: str = "",
preview_id: str | None = None,
user_metadata: dict | None = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
require_existing_tags: bool = False,
) -> dict:
"""
Idempotently upsert:
- Asset by content hash (create if missing)
- AssetCacheState(file_path) pointing to asset_id
- Optionally AssetInfo + tag links and metadata projection
Returns flags and ids.
"""
locator = os.path.abspath(abs_path)
now = utcnow()
if preview_id:
if not session.get(Asset, preview_id):
preview_id = None
out: dict[str, Any] = {
"asset_created": False,
"asset_updated": False,
"state_created": False,
"state_updated": False,
"asset_info_id": None,
}
# 1) Asset by hash
asset = (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
if not asset:
vals = {
"hash": asset_hash,
"size_bytes": int(size_bytes),
"mime_type": mime_type,
"created_at": now,
}
res = session.execute(
sqlite.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(index_elements=[Asset.hash])
)
if int(res.rowcount or 0) > 0:
out["asset_created"] = True
asset = (
session.execute(
select(Asset).where(Asset.hash == asset_hash).limit(1)
)
).scalars().first()
if not asset:
raise RuntimeError("Asset row not found after upsert.")
else:
changed = False
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and asset.mime_type != mime_type:
asset.mime_type = mime_type
changed = True
if changed:
out["asset_updated"] = True
# 2) AssetCacheState upsert by file_path (unique)
vals = {
"asset_id": asset.id,
"file_path": locator,
"mtime_ns": int(mtime_ns),
}
ins = (
sqlite.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
res = session.execute(ins)
if int(res.rowcount or 0) > 0:
out["state_created"] = True
else:
upd = (
sa.update(AssetCacheState)
.where(AssetCacheState.file_path == locator)
.where(
sa.or_(
AssetCacheState.asset_id != asset.id,
AssetCacheState.mtime_ns.is_(None),
AssetCacheState.mtime_ns != int(mtime_ns),
)
)
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
)
res2 = session.execute(upd)
if int(res2.rowcount or 0) > 0:
out["state_updated"] = True
# 3) Optional AssetInfo + tags + metadata
if info_name:
try:
with session.begin_nested():
info = AssetInfo(
owner_id=owner_id,
name=info_name,
asset_id=asset.id,
preview_id=preview_id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
out["asset_info_id"] = info.id
except IntegrityError:
pass
existing_info = (
session.execute(
select(AssetInfo)
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == info_name,
(AssetInfo.owner_id == owner_id),
)
.limit(1)
)
).unique().scalar_one_or_none()
if not existing_info:
raise RuntimeError("Failed to update or insert AssetInfo.")
if preview_id and existing_info.preview_id != preview_id:
existing_info.preview_id = preview_id
existing_info.updated_at = now
if existing_info.last_access_time < now:
existing_info.last_access_time = now
session.flush()
out["asset_info_id"] = existing_info.id
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
if norm and out["asset_info_id"] is not None:
if not require_existing_tags:
ensure_tags_exist(session, norm, tag_type="user")
existing_tag_names = set(
name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
)
missing = [t for t in norm if t not in existing_tag_names]
if missing and require_existing_tags:
raise ValueError(f"Unknown tags: {missing}")
existing_links = set(
tag_name
for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
)
).all()
)
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
if to_add:
session.add_all(
[
AssetInfoTag(
asset_info_id=out["asset_info_id"],
tag_name=t,
origin=tag_origin,
added_at=now,
)
for t in to_add
]
)
session.flush()
    # Keep user_metadata["filename"] in sync with the asset's on-disk location.
if out["asset_info_id"] is not None:
primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
computed_filename = compute_relative_filename(primary_path) if primary_path else None
current_meta = existing_info.user_metadata or {}
new_meta = dict(current_meta)
if user_metadata is not None:
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta != current_meta:
replace_asset_info_metadata_projection(
session,
asset_info_id=out["asset_info_id"],
user_metadata=new_meta,
)
try:
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
return out
def update_asset_info_full(
session: Session,
*,
asset_info_id: str,
name: str | None = None,
mime_type: str | None = None,
user_metadata: dict | None = None,
asset_info_row: Any = None,
) -> AssetInfo:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
else:
info = asset_info_row
touched = False
if name is not None and name != info.name:
info.name = name
touched = True
if mime_type is not None and info.asset:
if info.asset.mime_type != mime_type:
info.asset.mime_type = mime_type
touched = True
computed_filename = None
try:
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if user_metadata is not None:
new_meta = dict(user_metadata)
if computed_filename:
new_meta["filename"] = computed_filename
replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
else:
if computed_filename:
current_meta = info.user_metadata or {}
if current_meta.get("filename") != computed_filename:
new_meta = dict(current_meta)
new_meta["filename"] = computed_filename
replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
if touched and user_metadata is None:
info.updated_at = utcnow()
session.flush()
return info
def delete_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
owner_id: str,
) -> bool:
stmt = sa.delete(AssetInfo).where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
return int((session.execute(stmt)).rowcount or 0) > 0
def list_tags_with_usage(
session: Session,
prefix: str | None = None,
@@ -808,163 +265,3 @@ def list_tags_with_usage(
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
session.execute(ins)
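`ensure_tags_exist` leans on SQLite's `INSERT ... ON CONFLICT DO NOTHING` (available in SQLite ≥ 3.24) so repeated calls are idempotent; the same upsert pattern in plain `sqlite3` for illustration:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE tags (name TEXT PRIMARY KEY, tag_type TEXT NOT NULL)")

def ensure_tags(names, tag_type="user"):
    # Insert any missing tag names; rows whose name already exists are skipped.
    conn.executemany(
        "INSERT INTO tags (name, tag_type) VALUES (?, ?) ON CONFLICT(name) DO NOTHING",
        [(n, tag_type) for n in dict.fromkeys(names)],
    )

ensure_tags(["lora", "models"])
ensure_tags(["models", "checkpoints"])  # "models" is already present; no error
```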
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
return [
tag_name for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
]
def add_tags_to_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
asset_info_row: Any = None,
) -> dict:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total}
if create_if_missing:
ensure_tags_exist(session, norm, tag_type="user")
current = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
want = set(norm)
to_add = sorted(want - current)
if to_add:
with session.begin_nested() as nested:
try:
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=utcnow(),
)
for t in to_add
]
)
session.flush()
except IntegrityError:
nested.rollback()
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
return {
"added": sorted(((after - current) & want)),
"already_present": sorted(want & current),
"total_tags": sorted(after),
}
def remove_tags_from_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
) -> dict:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total}
existing = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
session.flush()
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
def remove_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
) -> None:
session.execute(
sa.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)
def set_asset_info_preview(
session: Session,
*,
asset_info_id: str,
preview_asset_id: str | None = None,
) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if preview_asset_id is None:
info.preview_id = None
else:
# validate preview asset exists
if not session.get(Asset, preview_asset_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found")
info.preview_id = preview_asset_id
info.updated_at = utcnow()
session.flush()


@@ -1,6 +1,5 @@
import contextlib
import os
from decimal import Decimal
from aiohttp import web
from datetime import datetime, timezone
from pathlib import Path
@@ -88,40 +87,6 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
targets.append((name, paths))
return targets
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
root = tags[0]
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
else:
base_dir = os.path.abspath(
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
)
raw_subdirs = tags[1:]
for i in raw_subdirs:
if i in (".", ".."):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else []
def ensure_within_base(candidate: str, base: str) -> None:
    cand_abs = os.path.abspath(candidate)
    base_abs = os.path.abspath(base)
    try:
        escapes = os.path.commonpath([cand_abs, base_abs]) != base_abs
    except ValueError:
        # commonpath raises for malformed/mixed paths (e.g. different drives)
        raise ValueError("invalid destination path")
    if escapes:
        raise ValueError("destination escapes base directory")
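The containment test is lexical, and `os.path.commonpath` avoids the prefix-collision trap of a naive `startswith` check (`/srv/models-evil` is not inside `/srv/models`); a boolean sketch of the same check:

```python
import os

def is_within_base(candidate: str, base: str) -> bool:
    """True iff `candidate` resolves to a path at or under `base` (lexical check)."""
    cand_abs, base_abs = os.path.abspath(candidate), os.path.abspath(base)
    try:
        return os.path.commonpath([cand_abs, base_abs]) == base_abs
    except ValueError:  # e.g. different drives on Windows
        return False
```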
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the model's path relative to the last well-known folder (the model category),
@@ -148,6 +113,7 @@ def compute_relative_filename(file_path: str) -> str | None:
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
@@ -249,64 +215,3 @@ def collect_models_files() -> list[str]:
if allowed:
out.append(abs_path)
return out
def is_scalar(v):
if v is None:
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
def project_kv(key: str, value):
"""
Turn a metadata key/value into typed projection rows.
Returns list[dict] with keys:
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
"""
rows: list[dict] = []
def _null_row(ordinal: int) -> dict:
return {
"key": key, "ordinal": ordinal,
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
}
if value is None:
rows.append(_null_row(0))
return rows
if is_scalar(value):
if isinstance(value, bool):
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
elif isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
rows.append({"key": key, "ordinal": 0, "val_num": num})
elif isinstance(value, str):
rows.append({"key": key, "ordinal": 0, "val_str": value})
else:
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
if isinstance(value, list):
if all(is_scalar(x) for x in value):
for i, x in enumerate(value):
if x is None:
rows.append(_null_row(i))
elif isinstance(x, bool):
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
elif isinstance(x, (int, float, Decimal)):
num = x if isinstance(x, Decimal) else Decimal(str(x))
rows.append({"key": key, "ordinal": i, "val_num": num})
elif isinstance(x, str):
rows.append({"key": key, "ordinal": i, "val_str": x})
else:
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
for i, x in enumerate(value):
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
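A condensed, ORM-free sketch of the projection rules above, showing how one key/value fans out into typed rows (unlike the original, every row here carries all four `val_*` columns, with unused ones left `None`):

```python
from decimal import Decimal

def is_scalar(v):
    return v is None or isinstance(v, (bool, int, float, Decimal, str))

def project_kv(key, value):
    """One row per scalar (or list element); non-scalar values fall back to val_json."""
    def row(i, **cols):
        base = {"key": key, "ordinal": i, "val_str": None, "val_num": None,
                "val_bool": None, "val_json": None}
        base.update(cols)
        return base
    def scalar_row(i, x):
        if x is None:
            return row(i)                      # all-NULL row marks an explicit None
        if isinstance(x, bool):                # bool before int: bool is an int subclass
            return row(i, val_bool=x)
        if isinstance(x, (int, float, Decimal)):
            return row(i, val_num=Decimal(str(x)))
        return row(i, val_str=x)
    if is_scalar(value):
        return [scalar_row(0, value)]
    if isinstance(value, list) and all(is_scalar(x) for x in value):
        return [scalar_row(i, x) for i, x in enumerate(value)]
    if isinstance(value, list):
        return [row(i, val_json=x) for i, x in enumerate(value)]
    return [row(0, val_json=value)]
```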


@@ -1,34 +1,13 @@
import os
import mimetypes
import contextlib
from typing import Sequence
from app.database.db import create_session
from app.assets.api import schemas_out, schemas_in
from app.assets.api import schemas_out
from app.assets.database.queries import (
asset_exists_by_hash,
asset_info_exists_for_asset_id,
get_asset_by_hash,
get_asset_info_by_id,
fetch_asset_info_asset_and_tags,
fetch_asset_info_and_asset,
create_asset_info_for_existing_asset,
touch_asset_info_by_id,
update_asset_info_full,
delete_asset_info_by_id,
list_cache_states_by_asset_id,
list_asset_infos_page,
list_tags_with_usage,
get_asset_tags,
add_tags_to_asset_info,
remove_tags_from_asset_info,
pick_best_live_path,
ingest_fs_asset,
set_asset_info_preview,
)
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
from app.assets.database.models import Asset
import app.assets.hashing as hashing
def _safe_sort_field(requested: str | None) -> str:
@@ -40,28 +19,11 @@ def _safe_sort_field(requested: str | None) -> str:
return "created_at"
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
st = os.stat(path, follow_symlinks=True)
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
def _safe_filename(name: str | None, fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
if n:
return n
return fallback
def asset_exists(*, asset_hash: str) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
def asset_exists(asset_hash: str) -> bool:
with create_session() as session:
return asset_exists_by_hash(session, asset_hash=asset_hash)
def list_assets(
*,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
@@ -71,7 +33,6 @@ def list_assets(
sort: str = "created_at",
order: str = "desc",
owner_id: str = "",
include_public: bool = True,
) -> schemas_out.AssetsList:
sort = _safe_sort_field(sort)
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
@@ -115,12 +76,7 @@ def list_assets(
has_more=(offset + len(summaries)) < total,
)
def get_asset(
*,
asset_info_id: str,
owner_id: str = "",
) -> schemas_out.AssetDetail:
def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
with create_session() as session:
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
@@ -141,356 +97,6 @@ def get_asset(
last_access_time=info.last_access_time,
)
def resolve_asset_content_for_download(
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[str, str, str]:
with create_session() as session:
pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not pair:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset = pair
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
abs_path = pick_best_live_path(states)
if not abs_path:
raise FileNotFoundError
touch_asset_info_by_id(session, asset_info_id=asset_info_id)
session.commit()
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
download_name = info.name or os.path.basename(abs_path)
return abs_path, ctype, download_name
def upload_asset_from_temp_path(
spec: schemas_in.UploadAssetSpec,
*,
temp_path: str,
client_filename: str | None = None,
owner_id: str = "",
expected_asset_hash: str | None = None,
) -> schemas_out.AssetCreated:
try:
digest = hashing.blake3_hash(temp_path)
except Exception as e:
raise RuntimeError(f"failed to hash uploaded file: {e}")
asset_hash = "blake3:" + digest
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
raise ValueError("HASH_MISMATCH")
with create_session() as session:
existing = get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
info = create_asset_info_for_existing_asset(
session,
asset_hash=asset_hash,
name=display_name,
user_metadata=spec.user_metadata or {},
tags=spec.tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=existing.hash,
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
mime_type=existing.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)
src_for_ext = (client_filename or spec.name or "").strip()
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
ext = _ext if 0 < len(_ext) <= 16 else ""
hashed_basename = f"{digest}{ext}"
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
ensure_within_base(dest_abs, base_dir)
content_type = (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
)
try:
os.replace(temp_path, dest_abs)
except Exception as e:
raise RuntimeError(f"failed to move uploaded file into place: {e}")
try:
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
except OSError as e:
raise RuntimeError(f"failed to stat destination file: {e}")
with create_session() as session:
result = ingest_fs_asset(
session,
asset_hash=asset_hash,
abs_path=dest_abs,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
owner_id=owner_id,
preview_id=None,
user_metadata=spec.user_metadata or {},
tags=spec.tags,
tag_origin="manual",
require_existing_tags=False,
)
info_id = result["asset_info_id"]
if not info_id:
raise RuntimeError("failed to create asset metadata")
pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
info, asset = pair
tag_names = get_asset_tags(session, asset_info_id=info.id)
session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=result["asset_created"],
)
def update_asset(
*,
asset_info_id: str,
name: str | None = None,
mime_type: str | None = None,
preview_id: str | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> schemas_out.AssetUpdated:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
info = update_asset_info_full(
session,
asset_info_id=asset_info_id,
name=name,
mime_type=mime_type,
user_metadata=user_metadata,
asset_info_row=info_row,
)
if preview_id is not None:
set_asset_info_preview(
session,
asset_info_id=asset_info_id,
preview_asset_id=preview_id if preview_id else None,
)
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
session.commit()
return schemas_out.AssetUpdated(
id=info.id,
name=info.name,
asset_hash=info.asset.hash if info.asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
updated_at=info.updated_at,
)
def set_asset_preview(
*,
asset_info_id: str,
preview_asset_id: str | None = None,
owner_id: str = "",
) -> schemas_out.AssetDetail:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
set_asset_info_preview(
session,
asset_info_id=asset_info_id,
preview_asset_id=preview_asset_id,
)
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise RuntimeError("State changed during preview update")
info, asset, tags = res
session.commit()
return schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
asset_id = info_row.asset_id if info_row else None
deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not deleted:
session.commit()
return False
if not delete_content_if_orphan or not asset_id:
session.commit()
return True
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
if still_exists:
session.commit()
return True
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
asset_row = session.get(Asset, asset_id)
if asset_row is not None:
session.delete(asset_row)
session.commit()
for p in file_paths:
with contextlib.suppress(Exception):
if p and os.path.isfile(p):
os.remove(p)
return True
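The orphan-cleanup rule above can be shown as a stand-alone sketch: the asset's content is removed only when the last AssetInfo reference to it is deleted. Plain dicts stand in for the DB session here; the names are illustrative, not the real schema or query helpers.

```python
# Reference-counted content deletion: drop the info row first, then remove the
# underlying asset only if no other info row still points at it.
def delete_reference(infos, assets, info_id, delete_content_if_orphan=True):
    info = infos.pop(info_id, None)
    if info is None:
        return False                       # nothing to delete
    asset_id = info["asset_id"]
    if delete_content_if_orphan and not any(
        i["asset_id"] == asset_id for i in infos.values()
    ):
        assets.pop(asset_id, None)         # last reference gone: drop content
    return True

infos = {"i1": {"asset_id": "a1"}, "i2": {"asset_id": "a1"}}
assets = {"a1": b"bytes"}
delete_reference(infos, assets, "i1")
print("a1" in assets)  # True -- i2 still references the asset
delete_reference(infos, assets, "i2")
print("a1" in assets)  # False -- orphaned, content removed
```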
def create_asset_from_hash(
*,
hash_str: str,
name: str,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> schemas_out.AssetCreated | None:
canonical = hash_str.strip().lower()
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=canonical)
if not asset:
return None
info = create_asset_info_for_existing_asset(
session,
asset_hash=canonical,
name=_safe_filename(name, fallback=canonical.split(":", 1)[-1]),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
def add_tags_to_asset(
*,
asset_info_id: str,
tags: list[str],
origin: str = "manual",
owner_id: str = "",
) -> schemas_out.TagsAdd:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = add_tags_to_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=origin,
create_if_missing=True,
asset_info_row=info_row,
)
session.commit()
return schemas_out.TagsAdd(**data)
def remove_tags_from_asset(
*,
asset_info_id: str,
tags: list[str],
owner_id: str = "",
) -> schemas_out.TagsRemove:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = remove_tags_from_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
)
session.commit()
return schemas_out.TagsRemove(**data)
def list_tags(
prefix: str | None = None,
limit: int = 100,
@ -498,7 +104,6 @@ def list_tags(
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
include_public: bool = True,
) -> schemas_out.TagsList:
limit = max(1, min(1000, limit))
offset = max(0, offset)
@ -516,17 +121,3 @@ def list_tags(
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)
def get_tag_histogram(
*,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 100,
include_public: bool = True,
owner_id: str = "",
) -> schemas_out.TagHistogramResponse:
# TODO: Implement actual histogram query in queries.py
return schemas_out.TagHistogramResponse(tags=[])

View File

@ -594,6 +594,7 @@ class Wan22(Wan21):
class HunyuanImage21(LatentFormat):
latent_channels = 64
latent_dimensions = 2
spacial_downscale_ratio = 32
scale_factor = 0.75289
latent_rgb_factors = [
@ -727,6 +728,7 @@ class HunyuanVideo15(LatentFormat):
latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
latent_channels = 32
latent_dimensions = 3
spacial_downscale_ratio = 16
scale_factor = 1.03682
taesd_decoder_name = "lighttaehy1_5"

View File

@ -451,6 +451,7 @@ class NextDiT(nn.Module):
device=None,
dtype=None,
operations=None,
**kwargs,
) -> None:
super().__init__()
self.dtype = dtype

View File

@ -444,6 +444,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["ffn_dim_multiplier"] = (8.0 / 3.0)
dit_config["z_image_modulation"] = True
dit_config["time_scale"] = 1000.0
try:
dit_config["allow_fp16"] = torch.std(state_dict['{}layers.{}.ffn_norm1.weight'.format(key_prefix, dit_config["n_layers"] - 2)], unbiased=False).item() < 0.42
except Exception:
pass
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
dit_config["pad_tokens_multiple"] = 32
sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None)

View File

@ -20,6 +20,7 @@ import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.hunyuan_video.vae
import comfy.ldm.mmaudio.vae.autoencoder
import comfy.pixel_space_convert
import comfy.weight_adapter
import yaml
import math
import os
@ -101,6 +102,105 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
return (new_modelpatcher, new_clip)
def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip):
"""
Load LoRA in bypass mode without modifying base model weights.
Instead of patching weights, this injects the LoRA computation into the
forward pass: output = base_forward(x) + lora_path(x)
Non-adapter patches (bias diff, weight diff, etc.) are applied as regular patches.
This is useful for training and when model weights are offloaded.
"""
key_map = {}
if model is not None:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
if clip is not None:
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
logging.debug(f"[BypassLoRA] key_map has {len(key_map)} entries")
lora = comfy.lora_convert.convert_lora(lora)
loaded = comfy.lora.load_lora(lora, key_map)
logging.debug(f"[BypassLoRA] loaded has {len(loaded)} entries")
# Separate adapters (for bypass) from other patches (for regular patching)
bypass_patches = {} # WeightAdapterBase instances -> bypass mode
regular_patches = {} # diff, set, bias patches -> regular weight patching
for key, patch_data in loaded.items():
if isinstance(patch_data, comfy.weight_adapter.WeightAdapterBase):
bypass_patches[key] = patch_data
else:
regular_patches[key] = patch_data
logging.debug(f"[BypassLoRA] {len(bypass_patches)} bypass adapters, {len(regular_patches)} regular patches")
k = set()
k1 = set()
if model is not None:
new_modelpatcher = model.clone()
# Apply regular patches (bias diff, weight diff, etc.) via normal patching
if regular_patches:
patched_keys = new_modelpatcher.add_patches(regular_patches, strength_model)
k.update(patched_keys)
# Apply adapter patches via bypass injection
manager = comfy.weight_adapter.BypassInjectionManager()
model_sd_keys = set(new_modelpatcher.model.state_dict().keys())
for key, adapter in bypass_patches.items():
if key in model_sd_keys:
manager.add_adapter(key, adapter, strength=strength_model)
k.add(key)
else:
logging.warning(f"[BypassLoRA] Adapter key not in model state_dict: {key}")
injections = manager.create_injections(new_modelpatcher.model)
if manager.get_hook_count() > 0:
new_modelpatcher.set_injections("bypass_lora", injections)
else:
new_modelpatcher = None
if clip is not None:
new_clip = clip.clone()
# Apply regular patches to clip
if regular_patches:
patched_keys = new_clip.add_patches(regular_patches, strength_clip)
k1.update(patched_keys)
# Apply adapter patches via bypass injection
clip_manager = comfy.weight_adapter.BypassInjectionManager()
clip_sd_keys = set(new_clip.cond_stage_model.state_dict().keys())
for key, adapter in bypass_patches.items():
if key in clip_sd_keys:
clip_manager.add_adapter(key, adapter, strength=strength_clip)
k1.add(key)
clip_injections = clip_manager.create_injections(new_clip.cond_stage_model)
if clip_manager.get_hook_count() > 0:
new_clip.patcher.set_injections("bypass_lora", clip_injections)
else:
new_clip = None
for x in loaded:
if (x not in k) and (x not in k1):
patch_data = loaded[x]
patch_type = type(patch_data).__name__
if isinstance(patch_data, tuple):
patch_type = f"tuple({patch_data[0]})"
logging.warning(f"NOT LOADED: {x} (type={patch_type})")
return (new_modelpatcher, new_clip)
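The core idea of bypass loading, stripped of the patcher machinery, is that the base weights are never touched: the adapter's contribution is added to the layer's output. A minimal framework-free sketch with 1-D scalar "layers" (toy names, not the real ComfyUI API):

```python
# bypass(f)(x) = f(x) + strength * up(down(x)): the frozen base forward plus a
# scaled low-rank path, composed at call time instead of merged into weights.
def make_base_forward(w):
    return lambda x: w * x                     # f(x): frozen base layer

def make_bypass_forward(f, a, b, strength):
    # a = "down" factor, b = "up" factor of the LoRA path
    return lambda x: f(x) + strength * (b * (a * x))

base = make_base_forward(3.0)
bypassed = make_bypass_forward(base, a=0.5, b=2.0, strength=0.8)

print(base(2.0))      # 6.0 -- base output is unchanged
print(bypassed(2.0))  # 6.0 + 0.8 * (2.0 * 0.5 * 2.0) = 7.6
```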
class CLIP:
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
if no_init:

View File

@ -1093,7 +1093,7 @@ class ZImage(Lumina2):
def __init__(self, unet_config):
super().__init__(unet_config)
if comfy.model_management.extended_fp16_support():
if comfy.model_management.extended_fp16_support() and unet_config.get("allow_fp16", False):
self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
self.supported_inference_dtypes.insert(1, torch.float16)

View File

@ -5,6 +5,11 @@ from .lokr import LoKrAdapter
from .glora import GLoRAAdapter
from .oft import OFTAdapter
from .boft import BOFTAdapter
from .bypass import (
BypassInjectionManager,
BypassForwardHook,
create_bypass_injections_from_patches,
)
adapters: list[type[WeightAdapterBase]] = [
@ -31,4 +36,7 @@ __all__ = [
"WeightAdapterTrainBase",
"adapters",
"adapter_maps",
"BypassInjectionManager",
"BypassForwardHook",
"create_bypass_injections_from_patches",
] + [a.__name__ for a in adapters]

View File

@ -1,4 +1,4 @@
from typing import Optional
from typing import Callable, Optional
import torch
import torch.nn as nn
@ -7,12 +7,35 @@ import comfy.model_management
class WeightAdapterBase:
"""
Base class for weight adapters (LoRA, LoHa, LoKr, OFT, etc.)
Bypass Mode:
All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x))
- h(x): Additive component (LoRA path). Returns delta to add to base output.
- g(y): Output transformation. Applied after base + h(x).
For LoRA/LoHa/LoKr: g = identity, h = adapter(x)
For OFT/BOFT: g = transform, h = 0
"""
name: str
loaded_keys: set[str]
weights: list[torch.Tensor]
# Attributes set by bypass system
multiplier: float = 1.0
shape: Optional[tuple] = None  # (out_features, in_features) or (out_ch, in_ch, *kernel)
@classmethod
def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]:
def load(
cls,
x: str,
lora: dict[str, torch.Tensor],
alpha: float,
dora_scale: torch.Tensor,
) -> Optional["WeightAdapterBase"]:
raise NotImplementedError
def to_train(self) -> "WeightAdapterTrainBase":
@ -39,18 +62,202 @@ class WeightAdapterBase:
):
raise NotImplementedError
# ===== Bypass Mode Methods =====
#
# IMPORTANT: Bypass mode is designed for quantized models where original weights
# may not be accessible in a usable format. Therefore, h() and bypass_forward()
# do NOT take org_weight as a parameter. All necessary information (out_channels,
# in_channels, conv params, etc.) is provided via attributes set by BypassForwardHook.
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
"""
Additive bypass component: h(x, base_out)
Computes the adapter's contribution to be added to base forward output.
For adapters that only transform output (OFT/BOFT), returns zeros.
Note:
This method does NOT access original model weights. Bypass mode is
designed for quantized models where weights may not be in a usable format.
All shape info comes from module attributes set by BypassForwardHook.
Args:
x: Input tensor
base_out: Output from base forward f(x), can be used for shape reference
Returns:
Delta tensor to add to base output. Shape matches base output.
Reference: LyCORIS LoConModule.bypass_forward_diff
"""
# Default: no additive component (for OFT/BOFT)
# Simply return zeros matching base_out shape
return torch.zeros_like(base_out)
def g(self, y: torch.Tensor) -> torch.Tensor:
"""
Output transformation: g(y)
Applied after base forward + h(x). For most adapters this is identity.
OFT/BOFT override this to apply orthogonal transformation.
Args:
y: Combined output (base + h(x))
Returns:
Transformed output
Reference: LyCORIS OFTModule applies orthogonal transform here
"""
# Default: identity (for LoRA/LoHa/LoKr)
return y
def bypass_forward(
self,
org_forward: Callable,
x: torch.Tensor,
*args,
**kwargs,
) -> torch.Tensor:
"""
Full bypass forward: g(f(x) + h(x, f(x)))
Note:
This method does NOT take org_weight/org_bias parameters. Bypass mode
is designed for quantized models where weights may not be accessible.
The original forward function handles weight access internally.
Args:
org_forward: Original module forward function
x: Input tensor
*args, **kwargs: Additional arguments for org_forward
Returns:
Output with adapter applied in bypass mode
Reference: LyCORIS LoConModule.bypass_forward
"""
# Base forward: f(x)
base_out = org_forward(x, *args, **kwargs)
# Additive component: h(x, base_out) - base_out provided for shape reference
h_out = self.h(x, base_out)
# Output transformation: g(base + h)
return self.g(base_out + h_out)
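The g/h decomposition documented above can be exercised with scalar stand-ins: a LoRA-like adapter contributes through h (g is identity), while an OFT-like adapter contributes through g (h is zero). A toy sketch, not the real adapter classes:

```python
# bypass(f)(x) = g(f(x) + h(x)) with the two adapter families side by side.
class LoraLike:
    def __init__(self, delta): self.delta = delta
    def h(self, x, base_out): return self.delta * x   # additive LoRA path
    def g(self, y): return y                          # identity

class OftLike:
    def __init__(self, scale): self.scale = scale
    def h(self, x, base_out): return 0.0              # no additive part
    def g(self, y): return self.scale * y             # output transform

def bypass(adapter, f, x):
    base = f(x)
    return adapter.g(base + adapter.h(x, base))

f = lambda x: 2.0 * x
print(bypass(LoraLike(0.5), f, 4.0))  # 2*4 + 0.5*4 = 10.0
print(bypass(OftLike(3.0), f, 4.0))   # 3 * (2*4)   = 24.0
```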
class WeightAdapterTrainBase(nn.Module):
# We follow the scheme of PR #7032
"""
Base class for trainable weight adapters (LoRA, LoHa, LoKr, OFT, etc.)
Bypass Mode:
All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x))
- h(x): Additive component (LoRA path). Returns delta to add to base output.
- g(y): Output transformation. Applied after base + h(x).
For LoRA/LoHa/LoKr: g = identity, h = adapter(x)
For OFT: g = transform, h = 0
Note:
Unlike WeightAdapterBase, TrainBase classes have simplified weight formats
with fewer branches (e.g., LoKr only has w1/w2, not w1_a/w1_b decomposition).
We follow the scheme of PR #7032
"""
# Attributes set by bypass system (BypassForwardHook)
# These are set before h()/g()/bypass_forward() are called
multiplier: float = 1.0
is_conv: bool = False
conv_dim: int = 0 # 0=linear, 1=conv1d, 2=conv2d, 3=conv3d
kw_dict: dict = {} # Conv kwargs: stride, padding, dilation, groups
kernel_size: tuple = ()
    in_channels: Optional[int] = None
    out_channels: Optional[int] = None
def __init__(self):
super().__init__()
def __call__(self, w):
"""
w: The original weight tensor to be modified.
Weight modification mode: returns modified weight.
Args:
w: The original weight tensor to be modified.
Returns:
Modified weight tensor.
"""
raise NotImplementedError
# ===== Bypass Mode Methods =====
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
"""
Additive bypass component: h(x, base_out)
Computes the adapter's contribution to be added to base forward output.
For adapters that only transform output (OFT), returns zeros.
Args:
x: Input tensor
base_out: Output from base forward f(x), can be used for shape reference
Returns:
Delta tensor to add to base output. Shape matches base output.
Subclasses should override this method.
"""
raise NotImplementedError(
f"{self.__class__.__name__}.h() not implemented. "
"Subclasses must implement h() for bypass mode."
)
def g(self, y: torch.Tensor) -> torch.Tensor:
"""
Output transformation: g(y)
Applied after base forward + h(x). For most adapters this is identity.
OFT overrides this to apply orthogonal transformation.
Args:
y: Combined output (base + h(x))
Returns:
Transformed output
"""
# Default: identity (for LoRA/LoHa/LoKr)
return y
def bypass_forward(
self,
org_forward: Callable,
x: torch.Tensor,
*args,
**kwargs,
) -> torch.Tensor:
"""
Full bypass forward: g(f(x) + h(x, f(x)))
Args:
org_forward: Original module forward function
x: Input tensor
*args, **kwargs: Additional arguments for org_forward
Returns:
Output with adapter applied in bypass mode
"""
# Base forward: f(x)
base_out = org_forward(x, *args, **kwargs)
# Additive component: h(x, base_out) - base_out provided for shape reference
h_out = self.h(x, base_out)
# Output transformation: g(base + h)
return self.g(base_out + h_out)
def passive_memory_usage(self):
raise NotImplementedError("passive_memory_usage is not implemented")
@ -59,8 +266,12 @@ class WeightAdapterTrainBase(nn.Module):
return self.passive_memory_usage()
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
def weight_decompose(
dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function
):
dora_scale = comfy.model_management.cast_to_device(
dora_scale, weight.device, intermediate_dtype
)
lora_diff *= alpha
weight_calc = weight + function(lora_diff).type(weight.dtype)
@ -106,10 +317,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
the original tensor will be truncated in that dimension.
"""
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
raise ValueError(
"The new shape must be larger than the original tensor in all dimensions"
)
if len(new_shape) != len(tensor.shape):
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
raise ValueError(
"The new shape must have the same number of dimensions as the original tensor"
)
# Create a new tensor filled with zeros
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)

View File

@ -62,9 +62,13 @@ class BOFTAdapter(WeightAdapterBase):
alpha = v[2]
dora_scale = v[3]
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
blocks = comfy.model_management.cast_to_device(
blocks, weight.device, intermediate_dtype
)
if rescale is not None:
rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
rescale = comfy.model_management.cast_to_device(
rescale, weight.device, intermediate_dtype
)
boft_m, block_num, boft_b, *_ = blocks.shape
@ -74,7 +78,7 @@ class BOFTAdapter(WeightAdapterBase):
# for Q = -Q^T
q = blocks - blocks.transpose(-1, -2)
normed_q = q
if alpha > 0: # alpha in boft/bboft is for constraint
if alpha > 0: # alpha in boft/bboft is for constraint
q_norm = torch.norm(q) + 1e-8
if q_norm > alpha:
normed_q = q * alpha / q_norm
@ -83,13 +87,13 @@ class BOFTAdapter(WeightAdapterBase):
r = r.to(weight)
inp = org = weight
r_b = boft_b//2
r_b = boft_b // 2
for i in range(boft_m):
bi = r[i]
g = 2
k = 2**i * r_b
if strength != 1:
bi = bi * strength + (1-strength) * I
bi = bi * strength + (1 - strength) * I
inp = (
inp.unflatten(0, (-1, g, k))
.transpose(1, 2)
@ -98,18 +102,117 @@ class BOFTAdapter(WeightAdapterBase):
)
inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp)
inp = (
inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2)
inp.flatten(0, 1)
.unflatten(0, (-1, k, g))
.transpose(1, 2)
.flatten(0, 2)
)
if rescale is not None:
inp = inp * rescale
lora_diff = inp - org
lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype)
lora_diff = comfy.model_management.cast_to_device(
lora_diff, weight.device, intermediate_dtype
)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
weight = weight_decompose(
dora_scale,
weight,
lora_diff,
alpha,
strength,
intermediate_dtype,
function,
)
else:
weight += function((strength * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight
def _get_orthogonal_matrices(self, device, dtype):
"""Compute the orthogonal rotation matrices R from BOFT blocks."""
v = self.weights
blocks = v[0].to(device=device, dtype=dtype)
alpha = v[2]
if alpha is None:
alpha = 0
boft_m, block_num, boft_b, _ = blocks.shape
I = torch.eye(boft_b, device=device, dtype=dtype)
# Q = blocks - blocks^T (skew-symmetric)
q = blocks - blocks.transpose(-1, -2)
normed_q = q
# Apply constraint if alpha > 0
if alpha > 0:
q_norm = torch.norm(q) + 1e-8
if q_norm > alpha:
normed_q = q * alpha / q_norm
# Cayley transform: R = (I + Q)(I - Q)^-1
r = (I + normed_q) @ (I - normed_q).float().inverse()
return r, boft_m, boft_b
def g(self, y: torch.Tensor) -> torch.Tensor:
"""
Output transformation for BOFT: applies butterfly orthogonal transform.
BOFT uses multiple stages of butterfly-structured orthogonal transforms.
Reference: LyCORIS ButterflyOFTModule._bypass_forward
"""
v = self.weights
rescale = v[1]
r, boft_m, boft_b = self._get_orthogonal_matrices(y.device, y.dtype)
r_b = boft_b // 2
# Apply multiplier
multiplier = getattr(self, "multiplier", 1.0)
I = torch.eye(boft_b, device=y.device, dtype=y.dtype)
# Use module info from bypass injection to determine conv vs linear
is_conv = getattr(self, "is_conv", y.dim() > 2)
if is_conv:
# Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C)
y = y.transpose(1, -1)
# Apply butterfly transform stages
inp = y
for i in range(boft_m):
bi = r[i] # (block_num, boft_b, boft_b)
g = 2
k = 2**i * r_b
# Interpolate with identity based on multiplier
if multiplier != 1:
bi = bi * multiplier + (1 - multiplier) * I
# Reshape for butterfly: unflatten last dim, transpose, flatten, unflatten
inp = (
inp.unflatten(-1, (-1, g, k))
.transpose(-2, -1)
.flatten(-3)
.unflatten(-1, (-1, boft_b))
)
# Apply block-diagonal orthogonal transform
inp = torch.einsum("b i j, ... b j -> ... b i", bi, inp)
# Reshape back
inp = (
inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3)
)
# Apply rescale if present
if rescale is not None:
rescale = rescale.to(device=y.device, dtype=y.dtype)
inp = inp * rescale.transpose(0, -1)
if is_conv:
# Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...)
inp = inp.transpose(1, -1)
return inp
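The Cayley transform used by `_get_orthogonal_matrices`, R = (I + Q)(I - Q)^-1 with skew-symmetric Q, always yields an orthogonal matrix, which is why applying it in `g()` only rotates the output. A 2x2 numeric check in plain Python (no torch needed):

```python
# Verify R @ R^T = I for the Cayley transform of a skew-symmetric 2x2 Q.
def matmul2(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]

def inv2(M):
    det = M[0][0] * M[1][1] - M[0][1] * M[1][0]
    return [[ M[1][1] / det, -M[0][1] / det],
            [-M[1][0] / det,  M[0][0] / det]]

q = 0.5
Q = [[0.0, q], [-q, 0.0]]                       # skew-symmetric: Q = -Q^T
I = [[1.0, 0.0], [0.0, 1.0]]
IplusQ  = [[I[i][j] + Q[i][j] for j in range(2)] for i in range(2)]
IminusQ = [[I[i][j] - Q[i][j] for j in range(2)] for i in range(2)]
R = matmul2(IplusQ, inv2(IminusQ))

RT = [[R[j][i] for j in range(2)] for i in range(2)]
RRT = matmul2(R, RT)
print(R)    # [[0.6, 0.8], [-0.8, 0.6]] -- a pure rotation
print(RRT)  # identity up to floating point
```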

View File

@ -0,0 +1,437 @@
"""
Bypass mode implementation for weight adapters (LoRA, LoKr, LoHa, etc.)
Bypass mode applies adapters during forward pass without modifying base weights:
bypass(f)(x) = g(f(x) + h(x))
Where:
- f(x): Original layer forward
- h(x): Additive component from adapter (LoRA path)
- g(y): Output transformation (identity for most adapters)
This is useful for:
- Training with gradient checkpointing
- Avoiding weight modifications when weights are offloaded
- Supporting multiple adapters with different strengths dynamically
"""
import logging
from typing import Optional, Union
import torch
import torch.nn as nn
from .base import WeightAdapterBase, WeightAdapterTrainBase
from comfy.patcher_extension import PatcherInjection
# Type alias for adapters that support bypass mode
BypassAdapter = Union[WeightAdapterBase, WeightAdapterTrainBase]
def get_module_type_info(module: nn.Module) -> dict:
"""
Determine module type and extract conv parameters from module class.
This is more reliable than checking weight.ndim, especially for quantized layers
where weight shape might be different.
Returns:
dict with keys: is_conv, conv_dim, stride, padding, dilation, groups
"""
info = {
"is_conv": False,
"conv_dim": 0,
"stride": (1,),
"padding": (0,),
"dilation": (1,),
"groups": 1,
"kernel_size": (1,),
"in_channels": None,
"out_channels": None,
}
# Determine conv type
if isinstance(module, nn.Conv1d):
info["is_conv"] = True
info["conv_dim"] = 1
elif isinstance(module, nn.Conv2d):
info["is_conv"] = True
info["conv_dim"] = 2
elif isinstance(module, nn.Conv3d):
info["is_conv"] = True
info["conv_dim"] = 3
elif isinstance(module, nn.Linear):
info["is_conv"] = False
info["conv_dim"] = 0
else:
# Try to infer from class name for custom/quantized layers
class_name = type(module).__name__.lower()
if "conv3d" in class_name:
info["is_conv"] = True
info["conv_dim"] = 3
elif "conv2d" in class_name:
info["is_conv"] = True
info["conv_dim"] = 2
elif "conv1d" in class_name:
info["is_conv"] = True
info["conv_dim"] = 1
elif "conv" in class_name:
info["is_conv"] = True
info["conv_dim"] = 2
# Extract conv parameters if it's a conv layer
if info["is_conv"]:
# Try to get stride, padding, dilation, groups, kernel_size from module
info["stride"] = getattr(module, "stride", (1,) * info["conv_dim"])
info["padding"] = getattr(module, "padding", (0,) * info["conv_dim"])
info["dilation"] = getattr(module, "dilation", (1,) * info["conv_dim"])
info["groups"] = getattr(module, "groups", 1)
info["kernel_size"] = getattr(module, "kernel_size", (1,) * info["conv_dim"])
info["in_channels"] = getattr(module, "in_channels", None)
info["out_channels"] = getattr(module, "out_channels", None)
# Ensure they're tuples
if isinstance(info["stride"], int):
info["stride"] = (info["stride"],) * info["conv_dim"]
if isinstance(info["padding"], int):
info["padding"] = (info["padding"],) * info["conv_dim"]
if isinstance(info["dilation"], int):
info["dilation"] = (info["dilation"],) * info["conv_dim"]
if isinstance(info["kernel_size"], int):
info["kernel_size"] = (info["kernel_size"],) * info["conv_dim"]
return info
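The class-name fallback above matters for custom/quantized layers that are not stock `nn.Conv*`/`nn.Linear`. The inference rule can be isolated as a tiny sketch (plain-Python mock modules, no torch required):

```python
# Infer conv dimensionality from a module's class name, most specific first;
# a bare "conv" defaults to 2d, anything else is treated as linear (0).
def infer_conv_dim(module) -> int:
    name = type(module).__name__.lower()
    for dim in (3, 2, 1):
        if f"conv{dim}d" in name:
            return dim
    return 2 if "conv" in name else 0

class QuantConv3dNF4: pass   # hypothetical quantized conv layer
class FP8Linear: pass        # hypothetical quantized linear layer
class MyConvLayer: pass      # custom conv without an explicit dim

print(infer_conv_dim(QuantConv3dNF4()))  # 3
print(infer_conv_dim(FP8Linear()))       # 0
print(infer_conv_dim(MyConvLayer()))     # 2
```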
class BypassForwardHook:
"""
Hook that wraps a layer's forward to apply adapter in bypass mode.
Stores the original forward and replaces it with bypass version.
Supports both:
- WeightAdapterBase: Inference adapters (uses self.weights tuple)
- WeightAdapterTrainBase: Training adapters (nn.Module with parameters)
"""
def __init__(
self,
module: nn.Module,
adapter: BypassAdapter,
multiplier: float = 1.0,
):
self.module = module
self.adapter = adapter
self.multiplier = multiplier
self.original_forward = None
# Determine layer type and conv params from module class (works for quantized layers)
module_info = get_module_type_info(module)
# Set multiplier and layer type info on adapter for use in h()
adapter.multiplier = multiplier
adapter.is_conv = module_info["is_conv"]
adapter.conv_dim = module_info["conv_dim"]
adapter.kernel_size = module_info["kernel_size"]
adapter.in_channels = module_info["in_channels"]
adapter.out_channels = module_info["out_channels"]
# Store kw_dict for conv operations (like LyCORIS extra_args)
if module_info["is_conv"]:
adapter.kw_dict = {
"stride": module_info["stride"],
"padding": module_info["padding"],
"dilation": module_info["dilation"],
"groups": module_info["groups"],
}
else:
adapter.kw_dict = {}
def _bypass_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""Bypass forward: uses adapter's bypass_forward or default g(f(x) + h(x))
Note:
Bypass mode does NOT access original model weights (org_weight).
This is intentional - bypass mode is designed for quantized models
where weights may not be in a usable format. All necessary shape
information is provided via adapter attributes set during inject().
"""
# Check if adapter has custom bypass_forward (e.g., GLoRA)
adapter_bypass = getattr(self.adapter, "bypass_forward", None)
if adapter_bypass is not None:
# Check if it's overridden (not the base class default)
# Need to check both base classes since adapter could be either type
adapter_type = type(self.adapter)
is_default_bypass = (
adapter_type.bypass_forward is WeightAdapterBase.bypass_forward
or adapter_type.bypass_forward is WeightAdapterTrainBase.bypass_forward
)
if not is_default_bypass:
return adapter_bypass(self.original_forward, x, *args, **kwargs)
# Default bypass: g(f(x) + h(x, f(x)))
base_out = self.original_forward(x, *args, **kwargs)
h_out = self.adapter.h(x, base_out)
return self.adapter.g(base_out + h_out)
def inject(self):
"""Replace module forward with bypass version."""
if self.original_forward is not None:
logging.debug(
f"[BypassHook] Already injected for {type(self.module).__name__}"
)
return # Already injected
# Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
device = None
dtype = None
if hasattr(self.module, "weight") and self.module.weight is not None:
device = self.module.weight.device
dtype = self.module.weight.dtype
elif hasattr(self.module, "W_q"): # Quantized layers might use different attr
device = self.module.W_q.device
dtype = self.module.W_q.dtype
if device is not None:
self._move_adapter_weights_to_device(device, dtype)
self.original_forward = self.module.forward
self.module.forward = self._bypass_forward
logging.debug(
f"[BypassHook] Injected bypass forward for {type(self.module).__name__} (adapter={type(self.adapter).__name__})"
)
def _move_adapter_weights_to_device(self, device, dtype=None):
"""Move adapter weights to specified device to avoid per-forward transfers.
Handles both:
- WeightAdapterBase: has self.weights tuple of tensors
- WeightAdapterTrainBase: nn.Module with parameters, uses .to() method
"""
adapter = self.adapter
# Check if adapter is an nn.Module (WeightAdapterTrainBase)
if isinstance(adapter, nn.Module):
# In training mode we don't touch dtype as trainer will handle it
adapter.to(device=device)
logging.debug(
f"[BypassHook] Moved training adapter (nn.Module) to {device}"
)
return
# WeightAdapterBase: handle self.weights tuple
if not hasattr(adapter, "weights") or adapter.weights is None:
return
weights = adapter.weights
if isinstance(weights, (list, tuple)):
new_weights = []
for w in weights:
if isinstance(w, torch.Tensor):
if dtype is not None:
new_weights.append(w.to(device=device, dtype=dtype))
else:
new_weights.append(w.to(device=device))
else:
new_weights.append(w)
adapter.weights = (
tuple(new_weights) if isinstance(weights, tuple) else new_weights
)
elif isinstance(weights, torch.Tensor):
if dtype is not None:
adapter.weights = weights.to(device=device, dtype=dtype)
else:
adapter.weights = weights.to(device=device)
logging.debug(f"[BypassHook] Moved adapter weights to {device}")
def eject(self):
"""Restore original module forward."""
if self.original_forward is None:
logging.debug(f"[BypassHook] Not injected for {type(self.module).__name__}")
return # Not injected
self.module.forward = self.original_forward
self.original_forward = None
logging.debug(
f"[BypassHook] Ejected bypass forward for {type(self.module).__name__}"
)
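The inject/eject mechanics above boil down to shadowing a module's bound `forward` with an instance attribute and restoring it later. A minimal stand-in (no torch, hypothetical names) showing that eject recovers the exact original behavior:

```python
# Swap-and-restore of a forward method, mirroring BypassForwardHook's pattern.
class Layer:
    def forward(self, x):
        return x * 2

class Hook:
    def __init__(self, module, delta):
        self.module, self.delta = module, delta
        self.original_forward = None
    def _bypass_forward(self, x):
        return self.original_forward(x) + self.delta * x
    def inject(self):
        if self.original_forward is None:          # guard double-injection
            self.original_forward = self.module.forward
            self.module.forward = self._bypass_forward
    def eject(self):
        if self.original_forward is not None:
            self.module.forward = self.original_forward
            self.original_forward = None

layer = Layer()
hook = Hook(layer, delta=0.5)
hook.inject()
print(layer.forward(4.0))  # 8.0 + 0.5 * 4.0 = 10.0
hook.eject()
print(layer.forward(4.0))  # 8.0 -- original behavior restored
```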
class BypassInjectionManager:
"""
Manages bypass mode injection for a collection of adapters.
Creates PatcherInjection objects that can be used with ModelPatcher.
Supports both inference adapters (WeightAdapterBase) and training adapters
(WeightAdapterTrainBase).
Usage:
manager = BypassInjectionManager()
manager.add_adapter("model.layers.0.self_attn.q_proj", lora_adapter, strength=0.8)
manager.add_adapter("model.layers.0.self_attn.k_proj", lora_adapter, strength=0.8)
injections = manager.create_injections(model)
model_patcher.set_injections("bypass_lora", injections)
"""
def __init__(self):
self.adapters: dict[str, tuple[BypassAdapter, float]] = {}
self.hooks: list[BypassForwardHook] = []
def add_adapter(
self,
key: str,
adapter: BypassAdapter,
strength: float = 1.0,
):
"""
Add an adapter for a specific weight key.
Args:
key: Weight key (e.g., "model.layers.0.self_attn.q_proj.weight")
adapter: The weight adapter (LoRAAdapter, LoKrAdapter, etc.)
strength: Multiplier for adapter effect
"""
# Remove .weight suffix if present for module lookup
module_key = key
if module_key.endswith(".weight"):
module_key = module_key[:-7]
logging.debug(
f"[BypassManager] Stripped .weight suffix: {key} -> {module_key}"
)
self.adapters[module_key] = (adapter, strength)
logging.debug(
f"[BypassManager] Added adapter: {module_key} (type={type(adapter).__name__}, strength={strength})"
)
def clear_adapters(self):
"""Remove all adapters."""
self.adapters.clear()
def _get_module_by_key(self, model: nn.Module, key: str) -> Optional[nn.Module]:
"""Get a submodule by dot-separated key."""
parts = key.split(".")
module = model
try:
for i, part in enumerate(parts):
if part.isdigit():
module = module[int(part)]
else:
module = getattr(module, part)
logging.debug(
f"[BypassManager] Found module for key {key}: {type(module).__name__}"
)
return module
except (AttributeError, IndexError, KeyError) as e:
logging.error(f"[BypassManager] Failed to find module for key {key}: {e}")
logging.error(
f"[BypassManager] Failed at part index {i}, part={part}, current module type={type(module).__name__}"
)
return None
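The traversal above can be exercised standalone; a minimal sketch (the helper name `get_module_by_key` and the `Block` model are illustrative, not part of the codebase) showing that digit parts index into containers while other parts are attribute lookups:

```python
import torch.nn as nn

def get_module_by_key(model, key):
    # Walk a dot-separated path: digits index containers, names are attributes
    module = model
    for part in key.split("."):
        module = module[int(part)] if part.isdigit() else getattr(module, part)
    return module

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 2)])

model = Block()
proj = get_module_by_key(model, "layers.1")
```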
def create_injections(self, model: nn.Module) -> list[PatcherInjection]:
"""
Create PatcherInjection objects for all registered adapters.
Args:
model: The model to inject into (e.g., model_patcher.model)
Returns:
List of PatcherInjection objects to use with model_patcher.set_injections()
"""
self.hooks.clear()
logging.debug(
f"[BypassManager] create_injections called with {len(self.adapters)} adapters"
)
logging.debug(f"[BypassManager] Model type: {type(model).__name__}")
for key, (adapter, strength) in self.adapters.items():
logging.debug(f"[BypassManager] Looking for module: {key}")
module = self._get_module_by_key(model, key)
if module is None:
logging.warning(f"[BypassManager] Module not found for key {key}")
continue
if not hasattr(module, "weight"):
logging.warning(
f"[BypassManager] Module {key} has no weight attribute (type={type(module).__name__})"
)
continue
logging.debug(
f"[BypassManager] Creating hook for {key} (module type={type(module).__name__}, weight shape={module.weight.shape})"
)
hook = BypassForwardHook(module, adapter, multiplier=strength)
self.hooks.append(hook)
logging.debug(f"[BypassManager] Created {len(self.hooks)} hooks")
# Create single injection that manages all hooks
def inject_all(model_patcher):
logging.debug(
f"[BypassManager] inject_all called, injecting {len(self.hooks)} hooks"
)
for hook in self.hooks:
hook.inject()
logging.debug(
f"[BypassManager] Injected hook for {type(hook.module).__name__}"
)
def eject_all(model_patcher):
logging.debug(
f"[BypassManager] eject_all called, ejecting {len(self.hooks)} hooks"
)
for hook in self.hooks:
hook.eject()
return [PatcherInjection(inject=inject_all, eject=eject_all)]
def get_hook_count(self) -> int:
"""Return number of hooks that will be/are injected."""
return len(self.hooks)
def create_bypass_injections_from_patches(
model: nn.Module,
patches: dict,
strength: float = 1.0,
) -> list[PatcherInjection]:
"""
Convenience function to create bypass injections from a patches dict.
This is useful when you have patches in the format used by model_patcher.add_patches()
and want to apply them in bypass mode instead.
Args:
model: The model to inject into
patches: Dict mapping weight keys to adapter data
strength: Global strength multiplier
Returns:
List of PatcherInjection objects
"""
manager = BypassInjectionManager()
for key, patch_list in patches.items():
if not patch_list:
continue
# patches format: list of (strength_patch, patch_data, strength_model, offset, function)
for patch in patch_list:
patch_strength, patch_data, strength_model, offset, function = patch
# patch_data should be a WeightAdapterBase/WeightAdapterTrainBase or tuple
if isinstance(patch_data, (WeightAdapterBase, WeightAdapterTrainBase)):
adapter = patch_data
else:
# Skip non-adapter patches
continue
combined_strength = strength * patch_strength
manager.add_adapter(key, adapter, strength=combined_strength)
return manager.create_injections(model)
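The inject/eject mechanics these injections drive can be sketched with a plain `nn.Linear` and a hypothetical rank-4 LoRA term (all names, shapes, and scales here are illustrative, not the real hook classes): bypass mode swaps the module's `forward` so the adapter term is added to the output without ever reading or modifying `W` itself.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
linear = nn.Linear(16, 8, bias=False)
down = torch.randn(4, 16) * 0.1   # illustrative rank-4 factors
up = torch.randn(8, 4) * 0.1
scale = 1.0 / 4                   # alpha / rank with alpha = 1.0

org_forward = linear.forward
def bypass_forward(x):
    # f(x) = Wx + up(down(x)) * scale, without touching linear.weight
    return org_forward(x) + F.linear(F.linear(x, down), up) * scale

linear.forward = bypass_forward   # inject
x = torch.randn(2, 16)
out = linear(x)
expected = x @ linear.weight.T + (x @ down.T @ up.T) * scale
linear.forward = org_forward      # eject: original behavior restored
```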

View File

@ -1,7 +1,8 @@
import logging
from typing import Optional
from typing import Callable, Optional
import torch
import torch.nn.functional as F
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose
@ -29,7 +30,14 @@ class GLoRAAdapter(WeightAdapterBase):
b1_name = "{}.b1.weight".format(x)
b2_name = "{}.b2.weight".format(x)
if a1_name in lora:
weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)
weights = (
lora[a1_name],
lora[a2_name],
lora[b1_name],
lora[b2_name],
alpha,
dora_scale,
)
loaded_keys.add(a1_name)
loaded_keys.add(a2_name)
loaded_keys.add(b1_name)
@ -58,16 +66,28 @@ class GLoRAAdapter(WeightAdapterBase):
old_glora = True
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
if (
old_glora
and v[1].shape[0] == weight.shape[0]
and weight.shape[0] == weight.shape[1]
):
pass
else:
old_glora = False
rank = v[1].shape[0]
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
a1 = comfy.model_management.cast_to_device(
v[0].flatten(start_dim=1), weight.device, intermediate_dtype
)
a2 = comfy.model_management.cast_to_device(
v[1].flatten(start_dim=1), weight.device, intermediate_dtype
)
b1 = comfy.model_management.cast_to_device(
v[2].flatten(start_dim=1), weight.device, intermediate_dtype
)
b2 = comfy.model_management.cast_to_device(
v[3].flatten(start_dim=1), weight.device, intermediate_dtype
)
if v[4] is not None:
alpha = v[4] / rank
@ -76,18 +96,195 @@ class GLoRAAdapter(WeightAdapterBase):
try:
if old_glora:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
lora_diff = (
torch.mm(b2, b1)
+ torch.mm(
torch.mm(
weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2
),
a1,
)
).reshape(
weight.shape
) # old lycoris glora
else:
if weight.dim() > 2:
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
lora_diff = torch.einsum(
"o i ..., i j -> o j ...",
torch.einsum(
"o i ..., i j -> o j ...",
weight.to(dtype=intermediate_dtype),
a1,
),
a2,
).reshape(weight.shape)
else:
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
lora_diff = torch.mm(
torch.mm(weight.to(dtype=intermediate_dtype), a1), a2
).reshape(weight.shape)
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
weight = weight_decompose(
dora_scale,
weight,
lora_diff,
alpha,
strength,
intermediate_dtype,
function,
)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight
def _compute_paths(self, x: torch.Tensor):
"""
Compute A path and B path outputs for GLoRA bypass.
GLoRA: f(x) = Wx + WAx + Bx
- A path: a1(a2(x)) - modifies input to base forward
- B path: b1(b2(x)) - additive component
Note:
Does not access original model weights - bypass mode is designed
for quantized models where weights may not be accessible.
Returns: (a_out, b_out)
"""
v = self.weights
# v = (a1, a2, b1, b2, alpha, dora_scale)
a1 = v[0]
a2 = v[1]
b1 = v[2]
b2 = v[3]
alpha = v[4]
dtype = x.dtype
# Cast dtype (weights should already be on correct device from inject())
a1 = a1.to(dtype=dtype)
a2 = a2.to(dtype=dtype)
b1 = b1.to(dtype=dtype)
b2 = b2.to(dtype=dtype)
# Determine rank and scale
# Check for old vs new glora format
old_glora = False
if b2.shape[1] == b1.shape[0] == a1.shape[0] == a2.shape[1]:
rank = a1.shape[0]
old_glora = True
if b2.shape[0] == b1.shape[1] == a1.shape[1] == a2.shape[0]:
# Bypass mode cannot inspect the base weight, so only the input dimension is available for the old-format check
if old_glora and a2.shape[0] == x.shape[-1]:
pass
else:
old_glora = False
rank = a2.shape[0]
if alpha is not None:
scale = alpha / rank
else:
scale = 1.0
# Apply multiplier
multiplier = getattr(self, "multiplier", 1.0)
scale = scale * multiplier
# Use module info from bypass injection, not input tensor shape
is_conv = getattr(self, "is_conv", False)
conv_dim = getattr(self, "conv_dim", 0)
kw_dict = getattr(self, "kw_dict", {})
if is_conv:
# Conv case - conv_dim is 1/2/3 for conv1d/2d/3d
conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]
# Get module's stride/padding for spatial dimension handling
module_stride = kw_dict.get("stride", (1,) * conv_dim)
module_padding = kw_dict.get("padding", (0,) * conv_dim)
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
in_channels = getattr(self, "in_channels", None)
# Ensure weights are in conv shape
# a1, a2, b1 are always 1x1 kernels
if a1.ndim == 2:
a1 = a1.view(*a1.shape, *([1] * conv_dim))
if a2.ndim == 2:
a2 = a2.view(*a2.shape, *([1] * conv_dim))
if b1.ndim == 2:
b1 = b1.view(*b1.shape, *([1] * conv_dim))
# b2 has actual kernel_size (like LoRA down)
if b2.ndim == 2:
if in_channels is not None:
b2 = b2.view(b2.shape[0], in_channels, *kernel_size)
else:
b2 = b2.view(*b2.shape, *([1] * conv_dim))
# A path: a2(x) -> a1(...) - 1x1 convs, no stride/padding needed, a_out is added to x
a2_out = conv_fn(x, a2)
a_out = conv_fn(a2_out, a1) * scale
# B path: b2(x) with kernel/stride/padding -> b1(...) 1x1
b2_out = conv_fn(x, b2, stride=module_stride, padding=module_padding)
b_out = conv_fn(b2_out, b1) * scale
else:
# Linear case
if old_glora:
# Old format: a1 @ a2 @ x, b2 @ b1
a_out = F.linear(F.linear(x, a2), a1) * scale
b_out = F.linear(F.linear(x, b1), b2) * scale
else:
# New format: x @ a1 @ a2, b1 @ b2
a_out = F.linear(F.linear(x, a1), a2) * scale
b_out = F.linear(F.linear(x, b2), b1) * scale
return a_out, b_out
def bypass_forward(
self,
org_forward: Callable,
x: torch.Tensor,
*args,
**kwargs,
) -> torch.Tensor:
"""
GLoRA bypass forward: f(x + a(x)) + b(x)
Unlike standard adapters, GLoRA modifies the input to the base forward
AND adds the B path output.
Note:
Does not access original model weights - bypass mode is designed
for quantized models where weights may not be accessible.
Reference: LyCORIS GLoRAModule._bypass_forward
"""
a_out, b_out = self._compute_paths(x)
# Call base forward with modified input
base_out = org_forward(x + a_out, *args, **kwargs)
# Add B path
return base_out + b_out
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
"""
For GLoRA, h() returns the B path output.
Note:
GLoRA's full bypass requires overriding bypass_forward() since
it also modifies the input to org_forward. This h() is provided for
compatibility but bypass_forward() should be used for correct behavior.
Does not access original model weights - bypass mode is designed
for quantized models where weights may not be accessible.
Args:
x: Input tensor
base_out: Output from base forward (unused, for API consistency)
"""
_, b_out = self._compute_paths(x)
return b_out
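The `bypass_forward` composition above, `org_forward(x + a(x)) + b(x)`, collapses to a single merged weight in the linear case: `W(x + Ax) + Bx = (W + WA + B)x`. A toy check with stand-in matrices `A` and `B` (illustrative, not the real a1/a2/b1/b2 factors):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
W = torch.randn(6, 6)
A = torch.randn(6, 6) * 0.05   # stand-in for the input-modifying A path
B = torch.randn(6, 6) * 0.05   # stand-in for the additive B path
x = torch.randn(3, 6)

base = lambda t: F.linear(t, W)
# GLoRA bypass: base forward on modified input, plus B path
bypassed = base(x + F.linear(x, A)) + F.linear(x, B)
# Equivalent merged weight applied once
merged = F.linear(x, W + W @ A + B)
```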

View File

@ -1,11 +1,22 @@
import logging
from functools import cache
from typing import Optional
import torch
import torch.nn.functional as F
import comfy.model_management
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose
@cache
def _warn_loha_bypass_inefficient():
"""One-time warning about LoHa bypass inefficiency."""
logging.warning(
"LoHa bypass mode is inefficient: full weight diff is computed each forward pass. "
"Consider using LoRA or LoKr for training with bypass mode."
)
class HadaWeight(torch.autograd.Function):
@staticmethod
def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
@ -105,9 +116,19 @@ class LohaDiff(WeightAdapterTrainBase):
scale = self.alpha / self.rank
if self.use_tucker:
diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale)
diff_weight = HadaWeightTucker.apply(
self.hada_t1,
self.hada_w1_a,
self.hada_w1_b,
self.hada_t2,
self.hada_w2_a,
self.hada_w2_b,
scale,
)
else:
diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
diff_weight = HadaWeight.apply(
self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale
)
# Add the scaled difference to the original weight
weight = w.to(diff_weight) + diff_weight.reshape(w.shape)
@ -138,9 +159,7 @@ class LoHaAdapter(WeightAdapterBase):
mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
torch.nn.init.normal_(mat3, 0.1)
torch.nn.init.normal_(mat4, 0.01)
return LohaDiff(
(mat1, mat2, alpha, mat3, mat4, None, None, None)
)
return LohaDiff((mat1, mat2, alpha, mat3, mat4, None, None, None))
def to_train(self):
return LohaDiff(self.weights)
@ -172,7 +191,16 @@ class LoHaAdapter(WeightAdapterBase):
loaded_keys.add(hada_t1_name)
loaded_keys.add(hada_t2_name)
weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)
weights = (
lora[hada_w1_a_name],
lora[hada_w1_b_name],
alpha,
lora[hada_w2_a_name],
lora[hada_w2_b_name],
hada_t1,
hada_t2,
dora_scale,
)
loaded_keys.add(hada_w1_a_name)
loaded_keys.add(hada_w1_b_name)
loaded_keys.add(hada_w2_a_name)
@ -203,30 +231,148 @@ class LoHaAdapter(WeightAdapterBase):
w2a = v[3]
w2b = v[4]
dora_scale = v[7]
if v[5] is not None: #cp decomposition
if v[5] is not None: # cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
m1 = torch.einsum(
"i j k l, j r, i p -> p r k l",
comfy.model_management.cast_to_device(
t1, weight.device, intermediate_dtype
),
comfy.model_management.cast_to_device(
w1b, weight.device, intermediate_dtype
),
comfy.model_management.cast_to_device(
w1a, weight.device, intermediate_dtype
),
)
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
m2 = torch.einsum(
"i j k l, j r, i p -> p r k l",
comfy.model_management.cast_to_device(
t2, weight.device, intermediate_dtype
),
comfy.model_management.cast_to_device(
w2b, weight.device, intermediate_dtype
),
comfy.model_management.cast_to_device(
w2a, weight.device, intermediate_dtype
),
)
else:
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
m1 = torch.mm(
comfy.model_management.cast_to_device(
w1a, weight.device, intermediate_dtype
),
comfy.model_management.cast_to_device(
w1b, weight.device, intermediate_dtype
),
)
m2 = torch.mm(
comfy.model_management.cast_to_device(
w2a, weight.device, intermediate_dtype
),
comfy.model_management.cast_to_device(
w2b, weight.device, intermediate_dtype
),
)
try:
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
weight = weight_decompose(
dora_scale,
weight,
lora_diff,
alpha,
strength,
intermediate_dtype,
function,
)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
"""
Additive bypass component for LoHa: h(x) = diff_weight @ x
WARNING: Inefficient - computes full Hadamard product each forward.
Note:
Does not access original model weights - bypass mode is designed
for quantized models where weights may not be accessible.
Args:
x: Input tensor
base_out: Output from base forward (unused, for API consistency)
Reference: LyCORIS functional/loha.py bypass_forward_diff
"""
_warn_loha_bypass_inefficient()
# FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
v = self.weights
# v[0]=w1a, v[1]=w1b, v[2]=alpha, v[3]=w2a, v[4]=w2b, v[5]=t1, v[6]=t2, v[7]=dora
w1a = v[0]
w1b = v[1]
alpha = v[2]
w2a = v[3]
w2b = v[4]
t1 = v[5]
t2 = v[6]
# Compute scale
rank = w1b.shape[0]
scale = (alpha / rank if alpha is not None else 1.0) * getattr(
self, "multiplier", 1.0
)
# Cast dtype
w1a = w1a.to(dtype=x.dtype)
w1b = w1b.to(dtype=x.dtype)
w2a = w2a.to(dtype=x.dtype)
w2b = w2b.to(dtype=x.dtype)
# Use module info from bypass injection, not weight dimension
is_conv = getattr(self, "is_conv", False)
conv_dim = getattr(self, "conv_dim", 0)
kw_dict = getattr(self, "kw_dict", {})
# Compute diff weight using Hadamard product
if t1 is not None and t2 is not None:
t1 = t1.to(dtype=x.dtype)
t2 = t2.to(dtype=x.dtype)
m1 = torch.einsum("i j k l, j r, i p -> p r k l", t1, w1b, w1a)
m2 = torch.einsum("i j k l, j r, i p -> p r k l", t2, w2b, w2a)
diff_weight = (m1 * m2) * scale
else:
m1 = w1a @ w1b
m2 = w2a @ w2b
diff_weight = (m1 * m2) * scale
if is_conv:
op = FUNC_LIST[conv_dim + 2]
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
in_channels = getattr(self, "in_channels", None)
# Reshape 2D diff_weight to conv format using kernel_size
# diff_weight: [out_channels, in_channels * prod(kernel_size)] -> [out_channels, in_channels, *kernel_size]
if diff_weight.dim() == 2:
if in_channels is not None:
diff_weight = diff_weight.view(
diff_weight.shape[0], in_channels, *kernel_size
)
else:
diff_weight = diff_weight.view(
*diff_weight.shape, *([1] * conv_dim)
)
else:
op = F.linear
kw_dict = {}
return op(x, diff_weight, **kw_dict)
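For the linear, non-Tucker case, `h()` above amounts to applying the Hadamard-product diff weight `(w1a @ w1b) * (w2a @ w2b) * scale` directly; a minimal self-contained check with illustrative shapes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
out_dim, in_dim, rank, alpha = 8, 16, 4, 4.0
w1a = torch.randn(out_dim, rank)
w1b = torch.randn(rank, in_dim)
w2a = torch.randn(out_dim, rank)
w2b = torch.randn(rank, in_dim)
scale = alpha / rank

# Element-wise (Hadamard) product of the two low-rank reconstructions
diff_weight = (w1a @ w1b) * (w2a @ w2b) * scale
x = torch.randn(2, in_dim)
h_out = F.linear(x, diff_weight)
```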

View File

@ -2,6 +2,7 @@ import logging
from typing import Optional
import torch
import torch.nn.functional as F
import comfy.model_management
from .base import (
WeightAdapterBase,
@ -14,7 +15,17 @@ from .base import (
class LokrDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
(lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights
(
lokr_w1,
lokr_w2,
alpha,
lokr_w1_a,
lokr_w1_b,
lokr_w2_a,
lokr_w2_b,
lokr_t2,
dora_scale,
) = weights
self.use_tucker = False
if lokr_w1_a is not None:
_, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1]
@ -57,10 +68,10 @@ class LokrDiff(WeightAdapterTrainBase):
if self.w2_rebuild:
if self.use_tucker:
w2 = torch.einsum(
'i j k l, j r, i p -> p r k l',
"i j k l, j r, i p -> p r k l",
self.lokr_t2,
self.lokr_w2_b,
self.lokr_w2_a
self.lokr_w2_a,
)
else:
w2 = self.lokr_w2_a @ self.lokr_w2_b
@ -69,9 +80,89 @@ class LokrDiff(WeightAdapterTrainBase):
return self.lokr_w2
def __call__(self, w):
diff = torch.kron(self.w1, self.w2)
w1 = self.w1
w2 = self.w2
# Unsqueeze w1 to match w2 dims for proper kron product (like LyCORIS make_kron)
for _ in range(w2.dim() - w1.dim()):
w1 = w1.unsqueeze(-1)
diff = torch.kron(w1, w2)
return w + diff.reshape(w.shape).to(w)
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
"""
Additive bypass component for LoKr training: efficient Kronecker product.
Uses w1/w2 properties which handle both direct and decomposed cases.
For create_train (direct w1/w2), no alpha scaling in properties.
For to_train (decomposed), alpha/rank scaling is in properties.
Args:
x: Input tensor
base_out: Output from base forward (unused, for API consistency)
"""
# Get w1, w2 from properties (handles rebuild vs direct)
w1 = self.w1
w2 = self.w2
# Multiplier from bypass injection
multiplier = getattr(self, "multiplier", 1.0)
# Get module info from bypass injection
is_conv = getattr(self, "is_conv", False)
conv_dim = getattr(self, "conv_dim", 0)
kw_dict = getattr(self, "kw_dict", {})
# Efficient Kronecker application without materializing full weight
# kron(w1, w2) @ x can be computed as nested operations
# w1: [out_l, in_m], w2: [out_k, in_n, *k_size]
# Full weight would be [out_l*out_k, in_m*in_n, *k_size]
uq = w1.size(1) # in_m - inner grouping dimension
if is_conv:
conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]
B, C_in, *spatial = x.shape
# Reshape input for grouped application: [B * uq, C_in // uq, *spatial]
h_in_group = x.reshape(B * uq, -1, *spatial)
# Ensure w2 has conv dims
if w2.dim() == 2:
w2 = w2.view(*w2.shape, *([1] * conv_dim))
# Apply w2 path with stride/padding
hb = conv_fn(h_in_group, w2, **kw_dict)
# Reshape for cross-group operation
hb = hb.view(B, -1, *hb.shape[1:])
h_cross = hb.transpose(1, -1)
# Apply w1 (always 2D, applied as linear on channel dim)
hc = F.linear(h_cross, w1)
hc = hc.transpose(1, -1)
# Reshape to output
out = hc.reshape(B, -1, *hc.shape[3:])
else:
# Linear case
# Reshape input: [..., in_m * in_n] -> [..., uq (in_m), in_n]
h_in_group = x.reshape(*x.shape[:-1], uq, -1)
# Apply w2: [..., uq, in_n] @ [out_k, in_n].T -> [..., uq, out_k]
hb = F.linear(h_in_group, w2)
# Transpose for w1: [..., uq, out_k] -> [..., out_k, uq]
h_cross = hb.transpose(-1, -2)
# Apply w1: [..., out_k, uq] @ [out_l, uq].T -> [..., out_k, out_l]
hc = F.linear(h_cross, w1)
# Transpose back and flatten: [..., out_k, out_l] -> [..., out_l * out_k]
hc = hc.transpose(-1, -2)
out = hc.reshape(*hc.shape[:-2], -1)
return out * multiplier
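The grouped-reshape path in `h()` avoids materializing `torch.kron(w1, w2)`; a self-contained equivalence check for the linear case (shapes illustrative), mirroring the reshape/transpose sequence above:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
out_l, in_m, out_k, in_n, B = 3, 4, 5, 6, 2
w1 = torch.randn(out_l, in_m)
w2 = torch.randn(out_k, in_n)
x = torch.randn(B, in_m * in_n)

# Reference: materialize the full Kronecker weight [out_l*out_k, in_m*in_n]
direct = F.linear(x, torch.kron(w1, w2))

# Efficient path: nested linear ops on a grouped view of x
h = x.reshape(B, in_m, in_n)
hb = F.linear(h, w2)                      # [B, in_m, out_k]
hc = F.linear(hb.transpose(-1, -2), w1)   # [B, out_k, out_l]
efficient = hc.transpose(-1, -2).reshape(B, -1)
```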
def passive_memory_usage(self):
return sum(param.numel() * param.element_size() for param in self.parameters())
@ -86,16 +177,22 @@ class LoKrAdapter(WeightAdapterBase):
@classmethod
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel()
out1, out2 = factorization(out_dim, rank)
in1, in2 = factorization(in_dim, rank)
mat1 = torch.empty(out1, in1, device=weight.device, dtype=torch.float32)
mat2 = torch.empty(out2, in2, device=weight.device, dtype=torch.float32)
in_dim = weight.shape[1] # Just in_channels, not flattened with kernel
k_size = weight.shape[2:] if weight.dim() > 2 else ()
out_l, out_k = factorization(out_dim, rank)
in_m, in_n = factorization(in_dim, rank)
# w1: [out_l, in_m]
mat1 = torch.empty(out_l, in_m, device=weight.device, dtype=torch.float32)
# w2: [out_k, in_n, *k_size] for conv, [out_k, in_n] for linear
mat2 = torch.empty(
out_k, in_n, *k_size, device=weight.device, dtype=torch.float32
)
torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
torch.nn.init.constant_(mat1, 0.0)
return LokrDiff(
(mat1, mat2, alpha, None, None, None, None, None, None)
)
return LokrDiff((mat1, mat2, alpha, None, None, None, None, None, None))
def to_train(self):
return LokrDiff(self.weights)
@ -154,8 +251,23 @@ class LoKrAdapter(WeightAdapterBase):
lokr_t2 = lora[lokr_t2_name]
loaded_keys.add(lokr_t2_name)
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)
if (
(lokr_w1 is not None)
or (lokr_w2 is not None)
or (lokr_w1_a is not None)
or (lokr_w2_a is not None)
):
weights = (
lokr_w1,
lokr_w2,
alpha,
lokr_w1_a,
lokr_w1_b,
lokr_w2_a,
lokr_w2_b,
lokr_t2,
dora_scale,
)
return cls(loaded_keys, weights)
else:
return None
@ -184,23 +296,47 @@ class LoKrAdapter(WeightAdapterBase):
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
w1 = torch.mm(
comfy.model_management.cast_to_device(
w1_a, weight.device, intermediate_dtype
),
comfy.model_management.cast_to_device(
w1_b, weight.device, intermediate_dtype
),
)
else:
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
w1 = comfy.model_management.cast_to_device(
w1, weight.device, intermediate_dtype
)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
w2 = torch.mm(
comfy.model_management.cast_to_device(
w2_a, weight.device, intermediate_dtype
),
comfy.model_management.cast_to_device(
w2_b, weight.device, intermediate_dtype
),
)
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
w2 = torch.einsum(
"i j k l, j r, i p -> p r k l",
comfy.model_management.cast_to_device(
t2, weight.device, intermediate_dtype
),
comfy.model_management.cast_to_device(
w2_b, weight.device, intermediate_dtype
),
comfy.model_management.cast_to_device(
w2_a, weight.device, intermediate_dtype
),
)
else:
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
w2 = comfy.model_management.cast_to_device(
w2, weight.device, intermediate_dtype
)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
@ -212,9 +348,134 @@ class LoKrAdapter(WeightAdapterBase):
try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
weight = weight_decompose(
dora_scale,
weight,
lora_diff,
alpha,
strength,
intermediate_dtype,
function,
)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
"""
Additive bypass component for LoKr: efficient Kronecker product application.
Note:
Does not access original model weights - bypass mode is designed
for quantized models where weights may not be accessible.
Args:
x: Input tensor
base_out: Output from base forward (unused, for API consistency)
Reference: LyCORIS functional/lokr.py bypass_forward_diff
"""
# FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
v = self.weights
# v[0]=w1, v[1]=w2, v[2]=alpha, v[3]=w1_a, v[4]=w1_b, v[5]=w2_a, v[6]=w2_b, v[7]=t2, v[8]=dora
w1 = v[0]
w2 = v[1]
alpha = v[2]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
use_w1 = w1 is not None
use_w2 = w2 is not None
tucker = t2 is not None
# Use module info from bypass injection, not weight dimension
is_conv = getattr(self, "is_conv", False)
conv_dim = getattr(self, "conv_dim", 0)
kw_dict = getattr(self, "kw_dict", {}) if is_conv else {}
if is_conv:
op = FUNC_LIST[conv_dim + 2]
else:
op = F.linear
# Determine rank and scale; when both w1 and w2 are stored directly there is
# no low-rank factor, so rank falls back to alpha and alpha / rank == 1.0
rank = w1_b.size(0) if not use_w1 else (w2_b.size(0) if not use_w2 else alpha)
scale = (alpha / rank if alpha is not None else 1.0) * getattr(
self, "multiplier", 1.0
)
# Build c (w1)
if use_w1:
c = w1.to(dtype=x.dtype)
else:
c = w1_a.to(dtype=x.dtype) @ w1_b.to(dtype=x.dtype)
uq = c.size(1)
# Build w2 components
if use_w2:
ba = w2.to(dtype=x.dtype)
else:
a = w2_b.to(dtype=x.dtype)
b = w2_a.to(dtype=x.dtype)
if is_conv:
if tucker:
# Tucker: a, b get 1s appended (kernel is in t2)
if a.dim() == 2:
a = a.view(*a.shape, *([1] * conv_dim))
if b.dim() == 2:
b = b.view(*b.shape, *([1] * conv_dim))
else:
# Non-tucker conv: b may need 1s appended
if b.dim() == 2:
b = b.view(*b.shape, *([1] * conv_dim))
# Reshape input by uq groups
if is_conv:
B, _, *rest = x.shape
h_in_group = x.reshape(B * uq, -1, *rest)
else:
h_in_group = x.reshape(*x.shape[:-1], uq, -1)
# Apply w2 path
if use_w2:
hb = op(h_in_group, ba, **kw_dict)
else:
if is_conv:
if tucker:
t = t2.to(dtype=x.dtype)
if t.dim() == 2:
t = t.view(*t.shape, *([1] * conv_dim))
ha = op(h_in_group, a)
ht = op(ha, t, **kw_dict)
hb = op(ht, b)
else:
ha = op(h_in_group, a, **kw_dict)
hb = op(ha, b)
else:
ha = op(h_in_group, a)
hb = op(ha, b)
# Reshape and apply c (w1)
if is_conv:
hb = hb.view(B, -1, *hb.shape[1:])
h_cross_group = hb.transpose(1, -1)
else:
h_cross_group = hb.transpose(-1, -2)
hc = F.linear(h_cross_group, c)
if is_conv:
hc = hc.transpose(1, -1)
out = hc.reshape(B, -1, *hc.shape[3:])
else:
hc = hc.transpose(-1, -2)
out = hc.reshape(*hc.shape[:-2], -1)
return out * scale

View File

@ -2,6 +2,7 @@ import logging
from typing import Optional
import torch
import torch.nn.functional as F
import comfy.model_management
from .base import (
WeightAdapterBase,
@ -20,11 +21,7 @@ class LoraDiff(WeightAdapterTrainBase):
rank, in_dim = mat2.shape[0], mat2.shape[1]
if mid is not None:
convdim = mid.ndim - 2
layer = (
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.Conv3d
)[convdim]
layer = (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)[convdim]
else:
layer = torch.nn.Linear
self.lora_up = layer(rank, out_dim, bias=False)
@ -51,6 +48,78 @@ class LoraDiff(WeightAdapterTrainBase):
weight = w + scale * diff.reshape(w.shape)
return weight.to(org_dtype)
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
"""
Additive bypass component for LoRA training: h(x) = up(down(x)) * scale
Simple implementation using the nn.Module weights directly.
No mid/dora/reshape branches (create_train doesn't create them).
Args:
x: Input tensor
base_out: Output from base forward (unused, for API consistency)
"""
# Compute scale = alpha / rank * multiplier
scale = (self.alpha / self.rank) * getattr(self, "multiplier", 1.0)
# Get module info from bypass injection
is_conv = getattr(self, "is_conv", False)
conv_dim = getattr(self, "conv_dim", 0)
kw_dict = getattr(self, "kw_dict", {})
# Get weights (keep in original dtype for numerical stability)
down_weight = self.lora_down.weight
up_weight = self.lora_up.weight
if is_conv:
# Conv path: use functional conv
# conv_dim: 1=conv1d, 2=conv2d, 3=conv3d
conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]
# Reshape 2D weights to conv format if needed
# down: [rank, in_features] -> [rank, in_channels, *kernel_size]
# up: [out_features, rank] -> [out_features, rank, 1, 1, ...]
if down_weight.dim() == 2:
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
in_channels = getattr(self, "in_channels", None)
if in_channels is not None:
down_weight = down_weight.view(
down_weight.shape[0], in_channels, *kernel_size
)
else:
# Fallback: assume 1x1 kernel
down_weight = down_weight.view(
*down_weight.shape, *([1] * conv_dim)
)
if up_weight.dim() == 2:
# up always uses 1x1 kernel
up_weight = up_weight.view(*up_weight.shape, *([1] * conv_dim))
# down conv uses stride/padding from module, up is 1x1
hidden = conv_fn(x, down_weight, **kw_dict)
# mid layer if exists (tucker decomposition)
if self.lora_mid is not None:
mid_weight = self.lora_mid.weight
if mid_weight.dim() == 2:
mid_weight = mid_weight.view(*mid_weight.shape, *([1] * conv_dim))
hidden = conv_fn(hidden, mid_weight)
# up conv is always 1x1 (no stride/padding)
out = conv_fn(hidden, up_weight)
else:
# Linear path: simple matmul chain
hidden = F.linear(x, down_weight)
# mid layer if exists
if self.lora_mid is not None:
mid_weight = self.lora_mid.weight
hidden = F.linear(hidden, mid_weight)
out = F.linear(hidden, up_weight)
return out * scale
def passive_memory_usage(self):
return sum(param.numel() * param.element_size() for param in self.parameters())
@ -70,9 +139,7 @@ class LoRAAdapter(WeightAdapterBase):
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
torch.nn.init.constant_(mat2, 0.0)
return LoraDiff(
(mat1, mat2, alpha, None, None, None)
)
return LoraDiff((mat1, mat2, alpha, None, None, None))
def to_train(self):
return LoraDiff(self.weights)
@ -210,3 +277,85 @@ class LoRAAdapter(WeightAdapterBase):
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
"""
Additive bypass component for LoRA: h(x) = up(down(x)) * scale
Note:
Does not access original model weights - bypass mode is designed
for quantized models where weights may not be accessible.
Args:
x: Input tensor
base_out: Output from base forward (unused, for API consistency)
Reference: LyCORIS functional/locon.py bypass_forward_diff
"""
# FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
v = self.weights
# v[0]=up, v[1]=down, v[2]=alpha, v[3]=mid, v[4]=dora_scale, v[5]=reshape
up = v[0]
down = v[1]
alpha = v[2]
mid = v[3]
# Compute scale = alpha / rank
rank = down.shape[0]
if alpha is not None:
scale = alpha / rank
else:
scale = 1.0
scale = scale * getattr(self, "multiplier", 1.0)
# Cast dtype
up = up.to(dtype=x.dtype)
down = down.to(dtype=x.dtype)
# Use module info from bypass injection, not weight dimension
is_conv = getattr(self, "is_conv", False)
conv_dim = getattr(self, "conv_dim", 0)
kw_dict = getattr(self, "kw_dict", {})
if is_conv:
op = FUNC_LIST[
conv_dim + 2
] # conv_dim 1->conv1d(3), 2->conv2d(4), 3->conv3d(5)
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
in_channels = getattr(self, "in_channels", None)
# Reshape 2D weights to conv format using kernel_size
# down: [rank, in_channels * prod(kernel_size)] -> [rank, in_channels, *kernel_size]
# up: [out_channels, rank] -> [out_channels, rank, 1, 1, ...] (1x1 kernel)
if down.dim() == 2:
# down.shape[1] = in_channels * prod(kernel_size)
if in_channels is not None:
down = down.view(down.shape[0], in_channels, *kernel_size)
else:
# Fallback: assume 1x1 kernel if in_channels unknown
down = down.view(*down.shape, *([1] * conv_dim))
if up.dim() == 2:
# up always uses 1x1 kernel
up = up.view(*up.shape, *([1] * conv_dim))
if mid is not None:
mid = mid.to(dtype=x.dtype)
if mid.dim() == 2:
mid = mid.view(*mid.shape, *([1] * conv_dim))
else:
op = F.linear
kw_dict = {} # linear doesn't take stride/padding
# Simple chain: down -> mid (if tucker) -> up
if mid is not None:
if not is_conv:
mid = mid.to(dtype=x.dtype)
hidden = op(x, down)
hidden = op(hidden, mid, **kw_dict)
out = op(hidden, up)
else:
hidden = op(x, down, **kw_dict)
out = op(hidden, up)
return out * scale
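For the linear path, the additive bypass h(x) = up(down(x)) * scale is algebraically the same output as merging delta_W = scale * (up @ down) into the base weight. A minimal pure-Python sketch with toy shapes (no torch, hypothetical values) checking the equivalence:

```python
def matvec(m, v):
    return [sum(m[i][j] * v[j] for j in range(len(v))) for i in range(len(m))]

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

# rank-1 LoRA on a 2 -> 2 linear layer: down is (1, 2), up is (2, 1)
down = [[1.0, 2.0]]
up = [[3.0], [4.0]]
alpha, rank = 2.0, 1
scale = alpha / rank

x = [0.5, -1.0]

# bypass path: h(x) = up(down(x)) * scale
hidden = matvec(down, x)
h = [scale * v for v in matvec(up, hidden)]

# weight-merge path: delta_W = scale * (up @ down), then delta_W @ x
delta_w = [[scale * c for c in row] for row in matmul(up, down)]
merged = matvec(delta_w, x)
```

Both paths produce the same additive update, which is why bypass mode can skip touching the (possibly quantized) base weight entirely.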

View File

@ -3,13 +3,18 @@ from typing import Optional
import torch
import comfy.model_management
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization
from .base import (
WeightAdapterBase,
WeightAdapterTrainBase,
weight_decompose,
factorization,
)
class OFTDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
# Unpack weights tuple from LoHaAdapter
# Unpack weights tuple from OFTAdapter
blocks, rescale, alpha, _ = weights
# Create trainable parameters
@ -52,6 +57,78 @@ class OFTDiff(WeightAdapterTrainBase):
weight = self.rescale * weight
return weight.to(org_dtype)
def _get_orthogonal_matrix(self, device, dtype):
"""Compute the orthogonal rotation matrix R from OFT blocks."""
blocks = self.oft_blocks.to(device=device, dtype=dtype)
I = torch.eye(self.block_size, device=device, dtype=dtype)
# Q = blocks - blocks^T (skew-symmetric)
q = blocks - blocks.transpose(1, 2)
normed_q = q
# Apply constraint if set
if self.constraint:
q_norm = torch.norm(q) + 1e-8
if q_norm > self.constraint:
normed_q = q * self.constraint / q_norm
# Cayley transform: R = (I + Q)(I - Q)^-1
r = (I + normed_q) @ (I - normed_q).float().inverse()
return r.to(dtype)
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
"""
OFT has no additive component - returns zeros matching base_out shape.
OFT only transforms the output via g(), it doesn't add to it.
"""
return torch.zeros_like(base_out)
def g(self, y: torch.Tensor) -> torch.Tensor:
"""
Output transformation for OFT: applies orthogonal rotation.
OFT transforms output channels using block-diagonal orthogonal matrices.
"""
r = self._get_orthogonal_matrix(y.device, y.dtype)
# Apply multiplier to interpolate between identity and full transform
multiplier = getattr(self, "multiplier", 1.0)
I = torch.eye(self.block_size, device=y.device, dtype=y.dtype)
r = r * multiplier + (1 - multiplier) * I
# Use module info from bypass injection
is_conv = getattr(self, "is_conv", y.dim() > 2)
if is_conv:
# Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C)
y = y.transpose(1, -1)
# y now has channels in last dim
*batch_shape, out_features = y.shape
# Reshape to apply block-diagonal transform
# (*, out_features) -> (*, block_num, block_size)
y_blocked = y.reshape(*batch_shape, self.block_num, self.block_size)
# Apply orthogonal transform: R @ y for each block
# r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size)
out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked)
# Reshape back: (*, block_num, block_size) -> (*, out_features)
out = out_blocked.reshape(*batch_shape, out_features)
# Apply rescale if present
if self.rescaled:
rescale = self.rescale.to(device=y.device, dtype=y.dtype)
out = out * rescale.view(-1)
if is_conv:
# Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...)
out = out.transpose(1, -1)
return out
def passive_memory_usage(self):
"""Calculates memory usage of the trainable parameters."""
return sum(param.numel() * param.element_size() for param in self.parameters())
@ -68,10 +145,10 @@ class OFTAdapter(WeightAdapterBase):
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
block_size, block_num = factorization(out_dim, rank)
block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=torch.float32)
return OFTDiff(
(block, None, alpha, None)
block = torch.zeros(
block_num, block_size, block_size, device=weight.device, dtype=torch.float32
)
return OFTDiff((block, None, alpha, None))
def to_train(self):
return OFTDiff(self.weights)
@ -127,9 +204,13 @@ class OFTAdapter(WeightAdapterBase):
alpha = 0
dora_scale = v[3]
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
blocks = comfy.model_management.cast_to_device(
blocks, weight.device, intermediate_dtype
)
if rescale is not None:
rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
rescale = comfy.model_management.cast_to_device(
rescale, weight.device, intermediate_dtype
)
block_num, block_size, *_ = blocks.shape
@ -139,23 +220,108 @@ class OFTAdapter(WeightAdapterBase):
# for Q = -Q^T
q = blocks - blocks.transpose(1, 2)
normed_q = q
if alpha > 0: # alpha in oft/boft is for constraint
if alpha > 0: # alpha in oft/boft is for constraint
q_norm = torch.norm(q) + 1e-8
if q_norm > alpha:
normed_q = q * alpha / q_norm
# use float() to prevent unsupported type in .inverse()
r = (I + normed_q) @ (I - normed_q).float().inverse()
r = r.to(weight)
# Create I in weight's dtype for the einsum
I_w = torch.eye(block_size, device=weight.device, dtype=weight.dtype)
_, *shape = weight.shape
lora_diff = torch.einsum(
"k n m, k n ... -> k m ...",
(r * strength) - strength * I,
(r * strength) - strength * I_w,
weight.view(block_num, block_size, *shape),
).view(-1, *shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
weight = weight_decompose(
dora_scale,
weight,
lora_diff,
alpha,
strength,
intermediate_dtype,
function,
)
else:
weight += function((strength * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight
def _get_orthogonal_matrix(self, device, dtype):
"""Compute the orthogonal rotation matrix R from OFT blocks."""
v = self.weights
blocks = v[0].to(device=device, dtype=dtype)
alpha = v[2]
if alpha is None:
alpha = 0
block_num, block_size, _ = blocks.shape
I = torch.eye(block_size, device=device, dtype=dtype)
# Q = blocks - blocks^T (skew-symmetric)
q = blocks - blocks.transpose(1, 2)
normed_q = q
# Apply constraint if alpha > 0
if alpha > 0:
q_norm = torch.norm(q) + 1e-8
if q_norm > alpha:
normed_q = q * alpha / q_norm
# Cayley transform: R = (I + Q)(I - Q)^-1
r = (I + normed_q) @ (I - normed_q).float().inverse()
return r, block_num, block_size
def g(self, y: torch.Tensor) -> torch.Tensor:
"""
Output transformation for OFT: applies orthogonal rotation to output.
OFT transforms the output channels using block-diagonal orthogonal matrices.
Reference: LyCORIS DiagOFTModule._bypass_forward
"""
v = self.weights
rescale = v[1]
r, block_num, block_size = self._get_orthogonal_matrix(y.device, y.dtype)
# Apply multiplier to interpolate between identity and full transform
multiplier = getattr(self, "multiplier", 1.0)
I = torch.eye(block_size, device=y.device, dtype=y.dtype)
r = r * multiplier + (1 - multiplier) * I
# Use module info from bypass injection to determine conv vs linear
is_conv = getattr(self, "is_conv", y.dim() > 2)
if is_conv:
# Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C)
y = y.transpose(1, -1)
# y now has channels in last dim
*batch_shape, out_features = y.shape
# Reshape to apply block-diagonal transform
# (*, out_features) -> (*, block_num, block_size)
y_blocked = y.view(*batch_shape, block_num, block_size)
# Apply orthogonal transform: R @ y for each block
# r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size)
out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked)
# Reshape back: (*, block_num, block_size) -> (*, out_features)
out = out_blocked.view(*batch_shape, out_features)
# Apply rescale if present
if rescale is not None:
rescale = rescale.to(device=y.device, dtype=y.dtype)
out = out * rescale.view(-1)
if is_conv:
# Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...)
out = out.transpose(1, -1)
return out
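The correctness of both `g()` paths rests on the Cayley transform: for skew-symmetric Q, R = (I + Q)(I - Q)^-1 is orthogonal, so the rotation preserves output norms. A pure-Python check on a single toy 2x2 block (the real code does this per block with torch):

```python
def mat2_mul(a, b):
    return [[a[0][0]*b[0][0] + a[0][1]*b[1][0], a[0][0]*b[0][1] + a[0][1]*b[1][1]],
            [a[1][0]*b[0][0] + a[1][1]*b[1][0], a[1][0]*b[0][1] + a[1][1]*b[1][1]]]

def mat2_inv(m):
    det = m[0][0]*m[1][1] - m[0][1]*m[1][0]
    return [[ m[1][1]/det, -m[0][1]/det],
            [-m[1][0]/det,  m[0][0]/det]]

def cayley(a):
    # Q = [[0, a], [-a, 0]] satisfies Q = -Q^T (skew-symmetric)
    i_plus_q = [[1.0, a], [-a, 1.0]]
    i_minus_q = [[1.0, -a], [a, 1.0]]
    return mat2_mul(i_plus_q, mat2_inv(i_minus_q))

r = cayley(0.3)
rt = [[r[0][0], r[1][0]], [r[0][1], r[1][1]]]  # R^T
rtr = mat2_mul(rt, r)                           # should be ~identity
```

The constraint clamp (`q * alpha / q_norm`) only rescales Q before the transform; the result stays orthogonal for any skew-symmetric input.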

View File

@ -1383,6 +1383,8 @@ class Schema:
"""Flags a node as not idempotent; when True, the node will run and not reuse the cached outputs when identical inputs are provided on a different node in the graph."""
enable_expand: bool=False
"""Flags a node as expandable, allowing NodeOutput to include 'expand' property."""
accept_all_inputs: bool=False
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
def validate(self):
'''Validate the schema:
@ -1853,6 +1855,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls.GET_SCHEMA()
return cls._NOT_IDEMPOTENT
_ACCEPT_ALL_INPUTS = None
@final
@classproperty
def ACCEPT_ALL_INPUTS(cls): # noqa
if cls._ACCEPT_ALL_INPUTS is None:
cls.GET_SCHEMA()
return cls._ACCEPT_ALL_INPUTS
@final
@classmethod
def INPUT_TYPES(cls) -> dict[str, dict]:
@ -1891,6 +1901,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls._INPUT_IS_LIST = schema.is_input_list
if cls._NOT_IDEMPOTENT is None:
cls._NOT_IDEMPOTENT = schema.not_idempotent
if cls._ACCEPT_ALL_INPUTS is None:
cls._ACCEPT_ALL_INPUTS = schema.accept_all_inputs
if cls._RETURN_TYPES is None:
output = []

View File

@ -0,0 +1,66 @@
from typing import TypedDict
from pydantic import BaseModel, Field, model_validator
class InputGenerateType(TypedDict):
generate_type: str
polygon_type: str
pbr: bool
class Hunyuan3DViewImage(BaseModel):
ViewType: str = Field(..., description="Valid values: back, left, right.")
ViewImageUrl: str = Field(...)
class To3DProTaskRequest(BaseModel):
Model: str = Field(...)
Prompt: str | None = Field(None)
ImageUrl: str | None = Field(None)
MultiViewImages: list[Hunyuan3DViewImage] | None = Field(None)
EnablePBR: bool | None = Field(...)
FaceCount: int | None = Field(...)
GenerateType: str | None = Field(...)
PolygonType: str | None = Field(...)
class RequestError(BaseModel):
Code: str = Field("")
Message: str = Field("")
class To3DProTaskCreateResponse(BaseModel):
JobId: str | None = Field(None)
Error: RequestError | None = Field(None)
@model_validator(mode="before")
@classmethod
def unwrap_data(cls, values: dict) -> dict:
if "Response" in values and isinstance(values["Response"], dict):
return values["Response"]
return values
class ResultFile3D(BaseModel):
Type: str = Field(...)
Url: str = Field(...)
PreviewImageUrl: str = Field("")
class To3DProTaskResultResponse(BaseModel):
ErrorCode: str = Field("")
ErrorMessage: str = Field("")
ResultFile3Ds: list[ResultFile3D] = Field([])
Status: str = Field(...)
@model_validator(mode="before")
@classmethod
def unwrap_data(cls, values: dict) -> dict:
if "Response" in values and isinstance(values["Response"], dict):
return values["Response"]
return values
class To3DProTaskQueryRequest(BaseModel):
JobId: str = Field(...)
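The `mode="before"` validators above exist because the Tencent API nests the payload under a top-level `"Response"` key; unwrapping happens before pydantic parses the fields. The same rule as a plain-dict sketch:

```python
def unwrap_data(values: dict) -> dict:
    # Mirror of the model_validator: flatten {"Response": {...}} one level,
    # pass anything else through untouched.
    if "Response" in values and isinstance(values["Response"], dict):
        return values["Response"]
    return values

wrapped = {"Response": {"JobId": "job-123"}}
flat = {"JobId": "job-456"}
```

Because the validator passes non-wrapped dicts through, the same models also accept already-flattened payloads.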

View File

@ -0,0 +1,297 @@
import os
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.hunyuan3d import (
Hunyuan3DViewImage,
InputGenerateType,
ResultFile3D,
To3DProTaskCreateResponse,
To3DProTaskQueryRequest,
To3DProTaskRequest,
To3DProTaskResultResponse,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_bytesio,
downscale_image_tensor_by_max_side,
poll_op,
sync_op,
upload_image_to_comfyapi,
validate_image_dimensions,
validate_string,
)
from folder_paths import get_output_directory
def get_glb_obj_from_response(response_objs: list[ResultFile3D]) -> ResultFile3D:
for i in response_objs:
if i.Type.lower() == "glb":
return i
raise ValueError("No GLB file found in response. Please report this to the developers.")
class TencentTextToModelNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TencentTextToModelNode",
display_name="Hunyuan3D: Text to Model (Pro)",
category="api node/3d/Tencent",
inputs=[
IO.Combo.Input(
"model",
options=["3.0", "3.1"],
tooltip="The LowPoly option is unavailable for the `3.1` model.",
),
IO.String.Input("prompt", multiline=True, default="", tooltip="Supports up to 1024 characters."),
IO.Int.Input("face_count", default=500000, min=40000, max=1500000),
IO.DynamicCombo.Input(
"generate_type",
options=[
IO.DynamicCombo.Option("Normal", [IO.Boolean.Input("pbr", default=False)]),
IO.DynamicCombo.Option(
"LowPoly",
[
IO.Combo.Input("polygon_type", options=["triangle", "quadrilateral"]),
IO.Boolean.Input("pbr", default=False),
],
),
IO.DynamicCombo.Option("Geometry", []),
],
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[
IO.String.Output(display_name="model_file"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["generate_type", "generate_type.pbr", "face_count"]),
expr="""
(
$base := widgets.generate_type = "normal" ? 25 : widgets.generate_type = "lowpoly" ? 30 : 15;
$pbr := $lookup(widgets, "generate_type.pbr") ? 10 : 0;
$face := widgets.face_count != 500000 ? 10 : 0;
{"type":"usd","usd": ($base + $pbr + $face) * 0.02}
)
""",
),
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
face_count: int,
generate_type: InputGenerateType,
seed: int,
) -> IO.NodeOutput:
_ = seed
validate_string(prompt, field_name="prompt", min_length=1, max_length=1024)
if model == "3.1" and generate_type["generate_type"].lower() == "lowpoly":
raise ValueError("The LowPoly option is currently unavailable for the 3.1 model.")
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro", method="POST"),
response_model=To3DProTaskCreateResponse,
data=To3DProTaskRequest(
Model=model,
Prompt=prompt,
FaceCount=face_count,
GenerateType=generate_type["generate_type"],
EnablePBR=generate_type.get("pbr", None),
PolygonType=generate_type.get("polygon_type", None),
),
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
result = await poll_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"),
data=To3DProTaskQueryRequest(JobId=response.JobId),
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
model_file = f"hunyuan_model_{response.JobId}.glb"
await download_url_to_bytesio(
get_glb_obj_from_response(result.ResultFile3Ds).Url,
os.path.join(get_output_directory(), model_file),
)
return IO.NodeOutput(model_file)
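The JSONata-style `price_badge` expressions in both nodes reduce to a small arithmetic rule. A Python restatement (the credit values are read off the expressions; treating them as internal units with 1 unit = $0.02, and `multiview` applying only to the image node):

```python
def price_usd(generate_type: str, pbr: bool, face_count: int,
              multiview: bool = False) -> float:
    # base credits: Normal = 25, LowPoly = 30, anything else (Geometry) = 15
    base = {"normal": 25, "lowpoly": 30}.get(generate_type.lower(), 15)
    extra = 10 if multiview else 0          # any extra view connected
    extra += 10 if pbr else 0               # PBR materials
    extra += 10 if face_count != 500000 else 0  # non-default face count
    return (base + extra) * 0.02
```

So a default text-to-model call ("Normal", no PBR, default faces) prices at $0.50, and a fully loaded image call ("LowPoly" + PBR + custom faces + extra views) at $1.20.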
class TencentImageToModelNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TencentImageToModelNode",
display_name="Hunyuan3D: Image(s) to Model (Pro)",
category="api node/3d/Tencent",
inputs=[
IO.Combo.Input(
"model",
options=["3.0", "3.1"],
tooltip="The LowPoly option is unavailable for the `3.1` model.",
),
IO.Image.Input("image"),
IO.Image.Input("image_left", optional=True),
IO.Image.Input("image_right", optional=True),
IO.Image.Input("image_back", optional=True),
IO.Int.Input("face_count", default=500000, min=40000, max=1500000),
IO.DynamicCombo.Input(
"generate_type",
options=[
IO.DynamicCombo.Option("Normal", [IO.Boolean.Input("pbr", default=False)]),
IO.DynamicCombo.Option(
"LowPoly",
[
IO.Combo.Input("polygon_type", options=["triangle", "quadrilateral"]),
IO.Boolean.Input("pbr", default=False),
],
),
IO.DynamicCombo.Option("Geometry", []),
],
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[
IO.String.Output(display_name="model_file"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["generate_type", "generate_type.pbr", "face_count"],
inputs=["image_left", "image_right", "image_back"],
),
expr="""
(
$base := widgets.generate_type = "normal" ? 25 : widgets.generate_type = "lowpoly" ? 30 : 15;
$multiview := (
inputs.image_left.connected or inputs.image_right.connected or inputs.image_back.connected
) ? 10 : 0;
$pbr := $lookup(widgets, "generate_type.pbr") ? 10 : 0;
$face := widgets.face_count != 500000 ? 10 : 0;
{"type":"usd","usd": ($base + $multiview + $pbr + $face) * 0.02}
)
""",
),
)
@classmethod
async def execute(
cls,
model: str,
image: Input.Image,
face_count: int,
generate_type: InputGenerateType,
seed: int,
image_left: Input.Image | None = None,
image_right: Input.Image | None = None,
image_back: Input.Image | None = None,
) -> IO.NodeOutput:
_ = seed
if model == "3.1" and generate_type["generate_type"].lower() == "lowpoly":
raise ValueError("The LowPoly option is currently unavailable for the 3.1 model.")
validate_image_dimensions(image, min_width=128, min_height=128)
multiview_images = []
for k, v in {
"left": image_left,
"right": image_right,
"back": image_back,
}.items():
if v is None:
continue
validate_image_dimensions(v, min_width=128, min_height=128)
multiview_images.append(
Hunyuan3DViewImage(
ViewType=k,
ViewImageUrl=await upload_image_to_comfyapi(
cls,
downscale_image_tensor_by_max_side(v, max_side=4900),
mime_type="image/webp",
total_pixels=24_010_000,
),
)
)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro", method="POST"),
response_model=To3DProTaskCreateResponse,
data=To3DProTaskRequest(
Model=model,
FaceCount=face_count,
GenerateType=generate_type["generate_type"],
ImageUrl=await upload_image_to_comfyapi(
cls,
downscale_image_tensor_by_max_side(image, max_side=4900),
mime_type="image/webp",
total_pixels=24_010_000,
),
MultiViewImages=multiview_images if multiview_images else None,
EnablePBR=generate_type.get("pbr", None),
PolygonType=generate_type.get("polygon_type", None),
),
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
result = await poll_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"),
data=To3DProTaskQueryRequest(JobId=response.JobId),
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
model_file = f"hunyuan_model_{response.JobId}.glb"
await download_url_to_bytesio(
get_glb_obj_from_response(result.ResultFile3Ds).Url,
os.path.join(get_output_directory(), model_file),
)
return IO.NodeOutput(model_file)
class TencentHunyuan3DExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
TencentTextToModelNode,
TencentImageToModelNode,
]
async def comfy_entrypoint() -> TencentHunyuan3DExtension:
return TencentHunyuan3DExtension()

View File

@ -249,7 +249,6 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
max_poll_attempts=160,
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))

View File

@ -149,7 +149,6 @@ class OpenAIVideoSora2(IO.ComfyNode):
response_model=Sora2GenerationResponse,
status_extractor=lambda x: x.status,
poll_interval=8.0,
max_poll_attempts=160,
estimated_duration=int(45 * (duration / 4) * model_time_multiplier),
)
return IO.NodeOutput(

View File

@ -203,7 +203,6 @@ class TopazImageEnhance(IO.ComfyNode):
progress_extractor=lambda x: getattr(x, "progress", 0),
price_extractor=lambda x: x.credits * 0.08,
poll_interval=8.0,
max_poll_attempts=160,
estimated_duration=60,
)

View File

@ -13,6 +13,7 @@ from .conversions import (
bytesio_to_image_tensor,
convert_mask_to_image,
downscale_image_tensor,
downscale_image_tensor_by_max_side,
image_tensor_pair_to_batch,
pil_to_bytesio,
resize_mask_to_image,
@ -33,6 +34,7 @@ from .download_helpers import (
from .upload_helpers import (
upload_audio_to_comfyapi,
upload_file_to_comfyapi,
upload_image_to_comfyapi,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
)
@ -61,6 +63,7 @@ __all__ = [
# Upload helpers
"upload_audio_to_comfyapi",
"upload_file_to_comfyapi",
"upload_image_to_comfyapi",
"upload_images_to_comfyapi",
"upload_video_to_comfyapi",
# Download helpers
@ -75,6 +78,7 @@ __all__ = [
"bytesio_to_image_tensor",
"convert_mask_to_image",
"downscale_image_tensor",
"downscale_image_tensor_by_max_side",
"image_tensor_pair_to_batch",
"pil_to_bytesio",
"resize_mask_to_image",

View File

@ -141,7 +141,7 @@ async def poll_op(
queued_statuses: list[str | int] | None = None,
data: BaseModel | None = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120,
max_poll_attempts: int = 160,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0,
@ -238,7 +238,7 @@ async def poll_op_raw(
queued_statuses: list[str | int] | None = None,
data: dict[str, Any] | BaseModel | None = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120,
max_poll_attempts: int = 160,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0,
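Raising the default from 120 to 160 attempts raises the polling ceiling to roughly `max_poll_attempts * poll_interval` (160 * 5 s, about 13 minutes at defaults). A minimal synchronous sketch of the attempt-capped loop; the real `poll_op` is async and additionally handles per-poll timeouts and retries:

```python
def poll(status_fn, completed=("done",), failed=("failed",),
         poll_interval=5.0, max_poll_attempts=160, sleep_fn=lambda s: None):
    # Call status_fn until a terminal status appears or attempts run out.
    for _ in range(max_poll_attempts):
        status = status_fn()
        if status in completed:
            return status
        if status in failed:
            raise RuntimeError(f"task failed with status {status!r}")
        sleep_fn(poll_interval)
    raise TimeoutError(f"no terminal status after {max_poll_attempts} attempts")

statuses = iter(["queued", "running", "done"])
result = poll(lambda: next(statuses))
```

With the per-call `max_poll_attempts=160` overrides removed from the Kling, Sora 2, and Topaz nodes, they now inherit this default instead of pinning their own.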

View File

@ -144,6 +144,21 @@ def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024)
return s
def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) -> torch.Tensor:
"""Downscale input image tensor so the largest dimension is at most max_side pixels."""
samples = image.movedim(-1, 1)
height, width = samples.shape[2], samples.shape[3]
max_dim = max(width, height)
if max_dim <= max_side:
return image
scale_by = max_side / max_dim
new_width = round(width * scale_by)
new_height = round(height * scale_by)
s = common_upscale(samples, new_width, new_height, "lanczos", "disabled")
s = s.movedim(1, -1)
return s
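The dimension math above, isolated into a pure-Python helper (the real function also resamples via `common_upscale` with lanczos):

```python
def max_side_dims(width: int, height: int, max_side: int) -> tuple[int, int]:
    max_dim = max(width, height)
    if max_dim <= max_side:
        return width, height  # already small enough, no-op
    scale_by = max_side / max_dim
    return round(width * scale_by), round(height * scale_by)
```

For example, the Hunyuan3D nodes call the tensor version with `max_side=4900`, so a 9800x4900 input would come out 4900x2450 while a 1000x500 input passes through unchanged.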
def tensor_to_data_uri(
image_tensor: torch.Tensor,
total_pixels: int = 2048 * 2048,

View File

@ -88,6 +88,28 @@ async def upload_images_to_comfyapi(
return download_urls
async def upload_image_to_comfyapi(
cls: type[IO.ComfyNode],
image: torch.Tensor,
*,
mime_type: str | None = None,
wait_label: str | None = "Uploading",
total_pixels: int = 2048 * 2048,
) -> str:
"""Uploads a single image to ComfyUI API and returns its download URL."""
return (
await upload_images_to_comfyapi(
cls,
image,
max_images=1,
mime_type=mime_type,
wait_label=wait_label,
show_batch_index=False,
total_pixels=total_pixels,
)
)[0]
async def upload_audio_to_comfyapi(
cls: type[IO.ComfyNode],
audio: Input.Audio,

View File

@ -104,19 +104,23 @@ class CustomComboNode(io.ComfyNode):
category="utils",
is_experimental=True,
inputs=[io.Combo.Input("choice", options=[])],
outputs=[io.String.Output()]
outputs=[
io.String.Output(display_name="STRING"),
io.Int.Output(display_name="INDEX"),
],
accept_all_inputs=True,
)
@classmethod
def validate_inputs(cls, choice: io.Combo.Type) -> bool:
def validate_inputs(cls, choice: io.Combo.Type, index: int = 0, **kwargs) -> bool:
# NOTE: DO NOT DO THIS unless you want to skip validation entirely on the node's inputs.
# I am doing that here because the widgets (besides the combo dropdown) on this node are fully frontend defined.
# I need to skip checking that the chosen combo option is in the options list, since those are defined by the user.
return True
@classmethod
def execute(cls, choice: io.Combo.Type) -> io.NodeOutput:
return io.NodeOutput(choice)
def execute(cls, choice: io.Combo.Type, index: int = 0, **kwargs) -> io.NodeOutput:
return io.NodeOutput(choice, index)
class DCTestNode(io.ComfyNode):

View File

@ -0,0 +1,79 @@
import folder_paths
import comfy.utils
import comfy.sd
class LoraLoaderBypass:
"""
Apply LoRA in bypass mode without modifying base model weights.
Bypass mode computes: output = base_forward(x) + lora_path(x)
This is useful for training and when model weights are offloaded.
"""
def __init__(self):
self.loaded_lora = None
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
"clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}),
"lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
"strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}),
}
}
RETURN_TYPES = ("MODEL", "CLIP")
OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.")
FUNCTION = "load_lora"
CATEGORY = "loaders"
DESCRIPTION = "Apply LoRA in bypass mode. Unlike regular LoRA, this doesn't modify model weights - instead it injects the LoRA computation during forward pass. Useful for training scenarios."
EXPERIMENTAL = True
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
if strength_model == 0 and strength_clip == 0:
return (model, clip)
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
lora = None
if self.loaded_lora is not None:
if self.loaded_lora[0] == lora_path:
lora = self.loaded_lora[1]
else:
self.loaded_lora = None
if lora is None:
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
self.loaded_lora = (lora_path, lora)
model_lora, clip_lora = comfy.sd.load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip)
return (model_lora, clip_lora)
class LoraLoaderBypassModelOnly(LoraLoaderBypass):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"lora_name": (folder_paths.get_filename_list("loras"), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_lora_model_only"
def load_lora_model_only(self, model, lora_name, strength_model):
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
NODE_CLASS_MAPPINGS = {
"LoraLoaderBypass": LoraLoaderBypass,
"LoraLoaderBypassModelOnly": LoraLoaderBypassModelOnly,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"LoraLoaderBypass": "Load LoRA (Bypass) (For debugging)",
"LoraLoaderBypassModelOnly": "Load LoRA (Bypass, Model Only) (For debugging)",
}

View File

@ -18,6 +18,7 @@ import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
from comfy.weight_adapter import adapters, adapter_maps
from comfy.weight_adapter.bypass import BypassInjectionManager
from comfy_api.latest import ComfyExtension, io, ui
from comfy.utils import ProgressBar
@ -339,6 +340,11 @@ class TrainSampler(comfy.samplers.Sampler):
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
if (i + 1) % self.grad_acc == 0:
for param_groups in self.optimizer.param_groups:
for param in param_groups["params"]:
if param.grad is None:
continue
param.grad.data = param.grad.data.to(param.data.dtype)
self.optimizer.step()
self.optimizer.zero_grad()
ui_pbar.update(1)
@ -498,9 +504,9 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode):
num_images = sum(t.shape[0] for t in latents)
multi_res = False # Not using multi_res path in bucket mode
logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
logging.debug(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
for i, lat in enumerate(latents):
logging.info(f" Bucket {i}: shape {lat.shape}")
logging.debug(f" Bucket {i}: shape {lat.shape}")
return latents, num_images, multi_res
# Non-bucket mode
@ -509,7 +515,7 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode):
latents = [t.to(dtype) for t in latents]
for latent in latents:
all_shapes.add(latent.shape)
logging.info(f"Latent shapes: {all_shapes}")
logging.debug(f"Latent shapes: {all_shapes}")
if len(all_shapes) > 1:
multi_res = True
else:
@ -545,7 +551,7 @@ def _validate_and_expand_conditioning(positive, num_images, bucket_mode):
if bucket_mode:
return positive # Skip validation in bucket mode
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
logging.debug(f"Total Images: {num_images}, Total Captions: {len(positive)}")
if len(positive) == 1 and num_images > 1:
return positive * num_images
elif len(positive) != num_images:
@ -596,6 +602,8 @@ def _create_weight_adapter(
shape = module.weight.shape
lora_params = {}
logging.debug(f"Creating weight adapter for {key} with shape {shape}")
if len(shape) >= 2:
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
dora_scale = existing_weights.get(f"{key}.dora_scale", None)
@ -690,6 +698,61 @@ def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank):
return lora_sd, all_weight_adapters
def _setup_lora_adapters_bypass(mp, existing_weights, algorithm, lora_dtype, rank):
"""Setup LoRA adapters in bypass mode.
In bypass mode:
- Weight adapters (lora/lokr/oft) use bypass injection (forward hook)
- Bias/norm adapters (BiasDiff) still use weight wrapper (direct modification)
This is useful when the base model weights are quantized and cannot be
directly modified.
Args:
mp: Model patcher
existing_weights: Dict of existing LoRA weights
algorithm: Algorithm name for new adapters
lora_dtype: dtype for LoRA weights
rank: Rank for new LoRA adapters
Returns:
tuple: (lora_sd dict, all_weight_adapters list, bypass_manager)
"""
lora_sd = {}
all_weight_adapters = []
bypass_manager = BypassInjectionManager()
for n, m in mp.model.named_modules():
if hasattr(m, "weight_function"):
if m.weight is not None:
adapter, params = _create_weight_adapter(
m, n, existing_weights, algorithm, lora_dtype, rank
)
lora_sd.update(params)
all_weight_adapters.append(adapter)
key = f"{n}.weight"
# BiasDiff (for 1D weights like norm) uses weight wrapper, not bypass
# Only use bypass for adapters that have h() method (lora/lokr/oft)
if isinstance(adapter, BiasDiff):
mp.add_weight_wrapper(key, adapter)
logging.debug(f"[BypassMode] Added 1D weight adapter (weight wrapper) for {key}")
else:
bypass_manager.add_adapter(key, adapter, strength=1.0)
logging.debug(f"[BypassMode] Added weight adapter (bypass) for {key}")
if hasattr(m, "bias") and m.bias is not None:
# Bias adapters still use weight wrapper (bias is usually not quantized)
bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype)
lora_sd.update(bias_params)
key = f"{n}.bias"
mp.add_weight_wrapper(key, bias_adapter)
all_weight_adapters.append(bias_adapter)
logging.debug(f"[BypassMode] Added bias adapter (weight wrapper) for {key}")
return lora_sd, all_weight_adapters, bypass_manager
def _create_optimizer(optimizer_name, parameters, learning_rate):
"""Create optimizer based on name.
@ -884,11 +947,13 @@ class TrainLoraNode(io.ComfyNode):
default=False,
tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
),
io.Boolean.Input(
"bypass_mode",
default=False,
tooltip="Enable bypass mode for training. When enabled, adapters are applied via forward hooks instead of weight modification. Useful for quantized models where weights cannot be directly modified.",
),
],
outputs=[
io.Model.Output(
display_name="model", tooltip="Model with LoRA applied"
),
io.Custom("LORA_MODEL").Output(
display_name="lora", tooltip="LoRA weights"
),
@@ -919,6 +984,7 @@ class TrainLoraNode(io.ComfyNode):
gradient_checkpointing,
existing_lora,
bucket_mode,
bypass_mode,
):
# Extract scalars from lists (due to is_input_list=True)
model = model[0]
@@ -936,6 +1002,7 @@ class TrainLoraNode(io.ComfyNode):
gradient_checkpointing = gradient_checkpointing[0]
existing_lora = existing_lora[0]
bucket_mode = bucket_mode[0]
bypass_mode = bypass_mode[0]
# Process latents based on mode
if bucket_mode:
@@ -968,9 +1035,16 @@ class TrainLoraNode(io.ComfyNode):
existing_weights, existing_steps = _load_existing_lora(existing_lora)
# Setup LoRA adapters
lora_sd, all_weight_adapters = _setup_lora_adapters(
mp, existing_weights, algorithm, lora_dtype, rank
)
bypass_manager = None
if bypass_mode:
logging.debug("Using bypass mode for training")
lora_sd, all_weight_adapters, bypass_manager = _setup_lora_adapters_bypass(
mp, existing_weights, algorithm, lora_dtype, rank
)
else:
lora_sd, all_weight_adapters = _setup_lora_adapters(
mp, existing_weights, algorithm, lora_dtype, rank
)
# Create optimizer and loss function
optimizer = _create_optimizer(
@@ -1029,6 +1103,14 @@ class TrainLoraNode(io.ComfyNode):
guider = TrainGuider(mp)
guider.set_conds(positive)
# Inject bypass hooks if bypass mode is enabled
bypass_injections = None
if bypass_manager is not None:
bypass_injections = bypass_manager.create_injections(mp.model)
for injection in bypass_injections:
injection.inject(mp)
logging.debug(f"[BypassMode] Injected {bypass_manager.get_hook_count()} bypass hooks")
# Run training loop
try:
_run_training_loop(
@@ -1041,6 +1123,11 @@ class TrainLoraNode(io.ComfyNode):
multi_res,
)
finally:
# Eject bypass hooks if they were injected
if bypass_injections is not None:
for injection in bypass_injections:
injection.eject(mp)
logging.debug("[BypassMode] Ejected bypass hooks")
for m in mp.model.modules():
unpatch(m)
del train_sampler, optimizer
@@ -1052,7 +1139,9 @@ class TrainLoraNode(io.ComfyNode):
for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype)
return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)
# mp in the train node is highly specialized for training;
# using it for inference will result in bad behavior, so we don't return it
return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
class LoraModelLoader(io.ComfyNode):
View File
@@ -175,7 +175,7 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
continue
obj = cached.outputs[output_index]
input_data_all[x] = obj
elif input_category is not None:
elif input_category is not None or (is_v3 and class_def.ACCEPT_ALL_INPUTS):
input_data_all[x] = [input_data]
if is_v3:
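The `ACCEPT_ALL_INPUTS` branch above lets a v3 node receive inputs it never declared in its schema, so the frontend can pass arbitrary kwargs (e.g. a selected combo index). A toy dispatcher illustrating the idea — the names here are hypothetical; the real logic lives in `get_input_data`:

```python
def collect_inputs(raw_inputs, declared, accept_all=False):
    """Build the input dict handed to a node's execute function.

    Declared inputs are always passed through; with accept_all=True,
    undeclared keys (arbitrary kwargs from the frontend) are kept too.
    Values are wrapped in lists, mirroring the is_input_list convention.
    """
    input_data_all = {}
    for key, value in raw_inputs.items():
        if key in declared or accept_all:
            input_data_all[key] = [value]
    return input_data_all

# A node declaring only "image" normally drops unknown keys...
strict = collect_inputs({"image": "img.png", "combo_index": 2}, {"image"})
# ...but with ACCEPT_ALL_INPUTS it also receives the frontend's extras.
lenient = collect_inputs({"image": "img.png", "combo_index": 2}, {"image"}, accept_all=True)
```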
View File
@@ -326,7 +326,7 @@ def setup_database():
if dependencies_available():
init_db()
if not args.disable_assets_autoscan:
seed_assets(["models", "input", "output"], enable_logging=True)
seed_assets(["models"], enable_logging=True)
except Exception as e:
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
View File
@@ -2105,7 +2105,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
"CheckpointLoaderSimple": "Load Checkpoint",
"VAELoader": "Load VAE",
"LoraLoader": "Load LoRA",
"LoraLoader": "Load LoRA (Model and CLIP)",
"LoraLoaderModelOnly": "Load LoRA",
"CLIPLoader": "Load CLIP",
"ControlNetLoader": "Load ControlNet Model",
"DiffControlNetLoader": "Load ControlNet Model (diff)",
@@ -2431,6 +2432,7 @@ async def init_builtin_extra_nodes():
"nodes_wanmove.py",
"nodes_image_compare.py",
"nodes_zimage.py",
"nodes_lora_debug.py"
]
import_failed = []
View File
@@ -22,7 +22,6 @@ alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.7
blake3
#non essential dependencies:
kornia>=0.7.1
View File
@@ -689,7 +689,7 @@ class PromptServer():
@routes.get("/object_info")
async def get_object_info(request):
try:
seed_assets(["models", "input", "output"])
seed_assets(["models"])
except Exception as e:
logging.error(f"Failed to seed assets: {e}")
with folder_paths.cache_helper:
View File
@@ -1,177 +0,0 @@
"""Tests for Assets API endpoints (app/assets/api/routes.py)

Tests cover:
- Schema validation for query parameters and request bodies
"""
import pytest
from pydantic import ValidationError

from app.assets.api import schemas_in, schemas_out
class TestListAssetsQuery:
    """Tests for ListAssetsQuery schema."""

    def test_defaults(self):
        """Test default values."""
        q = schemas_in.ListAssetsQuery()
        assert q.include_tags == []
        assert q.exclude_tags == []
        assert q.limit == 20
        assert q.offset == 0
        assert q.sort == "created_at"
        assert q.order == "desc"
        assert q.include_public == True

    def test_include_public_false(self):
        """Test include_public can be set to False."""
        q = schemas_in.ListAssetsQuery(include_public=False)
        assert q.include_public == False

    def test_csv_tags_parsing(self):
        """Test comma-separated tags are parsed correctly."""
        q = schemas_in.ListAssetsQuery.model_validate({"include_tags": "a,b,c"})
        assert q.include_tags == ["a", "b", "c"]

    def test_metadata_filter_json_string(self):
        """Test metadata_filter accepts JSON string."""
        q = schemas_in.ListAssetsQuery.model_validate({"metadata_filter": '{"key": "value"}'})
        assert q.metadata_filter == {"key": "value"}
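The CSV-parsing behavior these tests exercise can be sketched with a Pydantic v2 before-validator. This is a hypothetical reimplementation of how a query schema like `ListAssetsQuery` might coerce comma-separated query strings, not the actual schema code:

```python
from pydantic import BaseModel, field_validator

class ListAssetsQuery(BaseModel):
    # Query-string fields arrive as raw strings; validators coerce them.
    include_tags: list[str] = []
    exclude_tags: list[str] = []
    limit: int = 20
    offset: int = 0

    @field_validator("include_tags", "exclude_tags", mode="before")
    @classmethod
    def split_csv(cls, v):
        # "a,b,c" -> ["a", "b", "c"]; empty segments are dropped
        if isinstance(v, str):
            return [t.strip() for t in v.split(",") if t.strip()]
        return v

q = ListAssetsQuery.model_validate({"include_tags": "models, loras"})
# q.include_tags == ["models", "loras"]
```

`mode="before"` runs the validator on the raw input, so the field still gets normal `list[str]` validation afterwards.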
class TestTagsListQuery:
    """Tests for TagsListQuery schema."""

    def test_defaults(self):
        """Test default values."""
        q = schemas_in.TagsListQuery()
        assert q.prefix is None
        assert q.limit == 100
        assert q.offset == 0
        assert q.order == "count_desc"
        assert q.include_zero == True
        assert q.include_public == True

    def test_include_public_false(self):
        """Test include_public can be set to False."""
        q = schemas_in.TagsListQuery(include_public=False)
        assert q.include_public == False
class TestUpdateAssetBody:
    """Tests for UpdateAssetBody schema."""

    def test_requires_at_least_one_field(self):
        """Test that at least one field is required."""
        with pytest.raises(ValidationError):
            schemas_in.UpdateAssetBody()

    def test_name_only(self):
        """Test updating name only."""
        body = schemas_in.UpdateAssetBody(name="new name")
        assert body.name == "new name"
        assert body.mime_type is None
        assert body.preview_id is None

    def test_mime_type_only(self):
        """Test updating mime_type only."""
        body = schemas_in.UpdateAssetBody(mime_type="image/png")
        assert body.mime_type == "image/png"

    def test_preview_id_only(self):
        """Test updating preview_id only."""
        body = schemas_in.UpdateAssetBody(preview_id="550e8400-e29b-41d4-a716-446655440000")
        assert body.preview_id == "550e8400-e29b-41d4-a716-446655440000"

    def test_preview_id_invalid_uuid(self):
        """Test invalid UUID for preview_id."""
        with pytest.raises(ValidationError):
            schemas_in.UpdateAssetBody(preview_id="not-a-uuid")

    def test_all_fields(self):
        """Test all fields together."""
        body = schemas_in.UpdateAssetBody(
            name="test",
            mime_type="application/json",
            preview_id="550e8400-e29b-41d4-a716-446655440000",
            user_metadata={"key": "value"}
        )
        assert body.name == "test"
        assert body.mime_type == "application/json"
class TestUploadAssetFromUrlBody:
    """Tests for UploadAssetFromUrlBody schema (JSON URL upload)."""

    def test_valid_url(self):
        """Test valid HTTP URL."""
        body = schemas_in.UploadAssetFromUrlBody(
            url="https://example.com/model.safetensors",
            name="model.safetensors"
        )
        assert body.url == "https://example.com/model.safetensors"
        assert body.name == "model.safetensors"

    def test_http_url(self):
        """Test HTTP URL (not just HTTPS)."""
        body = schemas_in.UploadAssetFromUrlBody(
            url="http://example.com/file.bin",
            name="file.bin"
        )
        assert body.url == "http://example.com/file.bin"

    def test_invalid_url_scheme(self):
        """Test invalid URL scheme raises error."""
        with pytest.raises(ValidationError):
            schemas_in.UploadAssetFromUrlBody(
                url="ftp://example.com/file.bin",
                name="file.bin"
            )

    def test_tags_normalized(self):
        """Test tags are normalized to lowercase."""
        body = schemas_in.UploadAssetFromUrlBody(
            url="https://example.com/model.safetensors",
            name="model",
            tags=["Models", "LORAS"]
        )
        assert body.tags == ["models", "loras"]
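The tag normalization the last test expects can likewise be sketched with a Pydantic before-validator — a hypothetical stand-in for the real `UploadAssetFromUrlBody` schema, not ComfyUI's actual code:

```python
from pydantic import BaseModel, field_validator

class UploadBody(BaseModel):
    # Hypothetical subset of the real schema: only name + tags.
    name: str
    tags: list[str] = []

    @field_validator("tags", mode="before")
    @classmethod
    def lowercase_tags(cls, v):
        # ["Models", "LORAS"] -> ["models", "loras"]
        if isinstance(v, list):
            return [str(t).strip().lower() for t in v]
        return v

body = UploadBody(name="model", tags=["Models", "LORAS"])
# body.tags == ["models", "loras"]
```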
class TestTagsRefineQuery:
    """Tests for TagsRefineQuery schema."""

    def test_defaults(self):
        """Test default values."""
        q = schemas_in.TagsRefineQuery()
        assert q.include_tags == []
        assert q.exclude_tags == []
        assert q.limit == 100
        assert q.include_public == True

    def test_include_public_false(self):
        """Test include_public can be set to False."""
        q = schemas_in.TagsRefineQuery(include_public=False)
        assert q.include_public == False
class TestTagHistogramResponse:
    """Tests for TagHistogramResponse schema."""

    def test_empty_response(self):
        """Test empty response."""
        resp = schemas_out.TagHistogramResponse()
        assert resp.tags == []

    def test_with_entries(self):
        """Test response with entries."""
        resp = schemas_out.TagHistogramResponse(
            tags=[
                schemas_out.TagHistogramEntry(name="models", count=10),
                schemas_out.TagHistogramEntry(name="loras", count=5),
            ]
        )
        assert len(resp.tags) == 2
        assert resp.tags[0].name == "models"
        assert resp.tags[0].count == 10