Compare commits

..

10 Commits

Author SHA1 Message Date
a8d2519058 ComfyUI v0.22.0 2026-05-20 13:49:36 -04:00
4efe1ddb5c chore: update workflow templates to v0.9.79 (#14011) 2026-05-20 23:46:20 +08:00
f9c84c94b4 Support Stable Audio 3 model. (#14010) 2026-05-20 11:34:22 -04:00
78b5dec6b6 fix: Hunyuan3D 2.1 batch size crashes in attention and forward pass (#13699) 2026-05-20 19:58:49 +08:00
72e3f6081c Add downscale ratio to empty ltxv latent. (#13999) 2026-05-19 20:28:06 -07:00
7ec7b6ffe9 Adding new StringFormat node (#13997) 2026-05-20 10:25:49 +08:00
6887165a9d docs(openapi): tighten workspace API key description field (BE-1004) (#13996)
Aligns the OSS spec with the cloud-side BE-1004 contract:

- createWorkspaceApiKey request body: add maxLength: 5000 to the
  description property (matches cloud's hub_profile.description
  MaxLen(5000) convention; enforced cloud-side via handler check).
- WorkspaceApiKey + WorkspaceApiKeyCreated response schemas:
  mark description as required (cloud's handler always populates
  the field, defaulting to empty string when not supplied on create),
  drop nullable: true, add maxLength: 5000 for symmetry, and clarify
  the doc string ("Always present in responses; empty string when no
  description was supplied on create").

Both schemas are tagged x-runtime: [cloud] at the schema level so the
tightening is correctly scoped — OSS-only implementations are not
required to honor the workspace API keys endpoints at all.

Related cloud PR: Comfy-Org/cloud#3747
2026-05-19 16:55:04 -07:00
cc4d711eb1 feat(openapi): add optional description field to workspace API key schemas (#13993)
* feat(openapi): add optional description field to workspace API key schemas

Add an optional `description` property (type: string) to three
workspace API key schemas in openapi.yaml:

- Inline request body of createWorkspaceApiKey (POST /api/workspace/api-keys)
- WorkspaceApiKey (list/info schema)
- WorkspaceApiKeyCreated (creation response schema)

The field is not added to any `required` array, making it fully
backward-compatible with existing clients.

Refs: BE-1005, BE-1004

Co-authored-by: Matt Miller <mattmillerai@users.noreply.github.com>

* fix(openapi): mark description nullable in workspace API key response schemas

Per CodeRabbit review on PR #13993: the underlying DB column is nullable
varchar (default ''), so the response schemas should permit null to match
stored data reality. Without nullable: true the OpenAPI contract would
require coercion on the handler side or risk a contract violation.

Request schema unchanged — clients shouldn't be sending null on create.
2026-05-19 14:48:47 -07:00
yy
626b082838 Fix typo in ops.py (#11925) 2026-05-20 05:45:04 +08:00
d0328b442d docs(openapi): remove top-level width/height fields on Asset schema (#13973)
These two fields were added recently to the Asset schema as nullable
integers, with the intent of exposing original image dimensions for FE
consumers (cloud-side thumbnailing makes naturalWidth/Height return
the wrong size for an image card's dimension label).

The implementation effort that consumes them subsequently converged on
a different shape — dimensions nested under the existing free-form
`metadata` JSON field as `{kind: "image", width, height}` — to avoid
introducing type-specific flat fields on the canonical Asset shape,
and to leave room for forward-compatible additions (video duration,
fps, etc.) without further schema churn.

This removes the now-unused top-level fields so the spec reflects the
agreed direction. No other schema definitions reference these fields
directly: AssetCreated, AssetUpdated, etc. inherit Asset via allOf and
do not redefine them.

The runtime ingest implementation that would have populated these
fields was not yet shipped, so no clients are relying on the
top-level shape.

Co-authored-by: Alexis Rolland <alexisrolland@hotmail.com>
2026-05-19 10:00:26 -07:00
24 changed files with 1222 additions and 10229 deletions

File diff suppressed because it is too large Load Diff

View File

@ -1,886 +0,0 @@
{
"revision": 0,
"last_node_id": 675,
"last_link_id": 0,
"nodes": [
{
"id": 675,
"type": "01b6a731-fb78-4070-9a38-c87146da9604",
"pos": [
-2480,
3400
],
"size": [
360,
433.3125
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"localized_name": "input",
"name": "input",
"type": "IMAGE,MASK",
"link": null
},
{
"label": "resize_target_longer_size",
"name": "resize_type.longer_size",
"type": "INT",
"widget": {
"name": "resize_type.longer_size"
},
"link": null
},
{
"name": "scale_method",
"type": "COMBO",
"widget": {
"name": "scale_method"
},
"link": null
},
{
"name": "draw_body",
"type": "BOOLEAN",
"widget": {
"name": "draw_body"
},
"link": null
},
{
"name": "draw_hands",
"type": "BOOLEAN",
"widget": {
"name": "draw_hands"
},
"link": null
},
{
"name": "draw_face",
"type": "BOOLEAN",
"widget": {
"name": "draw_face"
},
"link": null
},
{
"name": "draw_feet",
"type": "BOOLEAN",
"widget": {
"name": "draw_feet"
},
"link": null
},
{
"name": "stick_width",
"type": "INT",
"widget": {
"name": "stick_width"
},
"link": null
},
{
"name": "face_point_size",
"type": "INT",
"widget": {
"name": "face_point_size"
},
"link": null
},
{
"name": "score_threshold",
"type": "FLOAT",
"widget": {
"name": "score_threshold"
},
"link": null
},
{
"name": "ckpt_name",
"type": "COMBO",
"widget": {
"name": "ckpt_name"
},
"link": null
},
{
"name": "bboxes",
"shape": 7,
"type": "BOUNDING_BOX",
"link": null
}
],
"outputs": [
{
"localized_name": "IMAGE",
"name": "IMAGE",
"type": "IMAGE",
"links": []
},
{
"name": "keypoints",
"type": "POSE_KEYPOINT",
"links": null
}
],
"properties": {
"proxyWidgets": [
[
"674",
"resize_type.longer_size"
],
[
"674",
"scale_method"
],
[
"672",
"draw_body"
],
[
"672",
"draw_hands"
],
[
"672",
"draw_face"
],
[
"672",
"draw_feet"
],
[
"672",
"stick_width"
],
[
"672",
"face_point_size"
],
[
"672",
"score_threshold"
],
[
"673",
"ckpt_name"
]
],
"cnr_id": "comfy-core",
"ver": "0.15.1",
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [],
"title": "Image to Pose Map (SDPose-OOD)"
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "01b6a731-fb78-4070-9a38-c87146da9604",
"version": 1,
"state": {
"lastGroupId": 0,
"lastNodeId": 676,
"lastLinkId": 1715,
"lastRerouteId": 0
},
"revision": 0,
"config": {},
"name": "Image to Pose Map (SDPose-OOD)",
"inputNode": {
"id": -10,
"bounding": [
-3290,
3590,
190.8984375,
288
]
},
"outputNode": {
"id": -20,
"bounding": [
-1756.2451602089645,
3366,
128,
88
]
},
"inputs": [
{
"id": "e24699c3-1356-4634-9eb4-19bb58e5c0b0",
"name": "input",
"type": "IMAGE,MASK",
"linkIds": [
1700
],
"localized_name": "input",
"pos": [
-3123.1015625,
3614
]
},
{
"id": "088eefc1-cd8a-4573-993f-9e4da008a12d",
"name": "resize_type.longer_size",
"type": "INT",
"linkIds": [
1704
],
"label": "resize_target_longer_size",
"pos": [
-3123.1015625,
3634
]
},
{
"id": "b6449bd3-73d4-41c8-b81f-cf8d33f76a2e",
"name": "scale_method",
"type": "COMBO",
"linkIds": [
1705
],
"pos": [
-3123.1015625,
3654
]
},
{
"id": "4cff52ad-ed07-4c97-8803-fcbd89554fd0",
"name": "draw_body",
"type": "BOOLEAN",
"linkIds": [
1706
],
"pos": [
-3123.1015625,
3674
]
},
{
"id": "7af63dce-f7df-4d7e-8215-d7c7f60bf81c",
"name": "draw_hands",
"type": "BOOLEAN",
"linkIds": [
1707
],
"pos": [
-3123.1015625,
3694
]
},
{
"id": "af3a9bce-61f9-4aca-b530-9f65e028b35e",
"name": "draw_face",
"type": "BOOLEAN",
"linkIds": [
1708
],
"pos": [
-3123.1015625,
3714
]
},
{
"id": "4620f6a3-2c85-4b79-ad8f-35d0326b568f",
"name": "draw_feet",
"type": "BOOLEAN",
"linkIds": [
1709
],
"pos": [
-3123.1015625,
3734
]
},
{
"id": "fee5d0c9-8d4b-4934-81d8-ba2206dc56cb",
"name": "stick_width",
"type": "INT",
"linkIds": [
1710
],
"pos": [
-3123.1015625,
3754
]
},
{
"id": "aafdd060-ba81-4324-a9cc-b656e1ebc133",
"name": "face_point_size",
"type": "INT",
"linkIds": [
1711
],
"pos": [
-3123.1015625,
3774
]
},
{
"id": "514c5503-f9e6-4d23-b1ae-1d3291acb2a3",
"name": "score_threshold",
"type": "FLOAT",
"linkIds": [
1712
],
"pos": [
-3123.1015625,
3794
]
},
{
"id": "ae46de61-2cc6-483e-8ee9-87e4144a2ffa",
"name": "ckpt_name",
"type": "COMBO",
"linkIds": [
1713
],
"pos": [
-3123.1015625,
3814
]
},
{
"id": "41bec0c6-dffa-4c78-9289-ee678715ae54",
"name": "bboxes",
"type": "BOUNDING_BOX",
"linkIds": [
1714
],
"pos": [
-3123.1015625,
3834
]
}
],
"outputs": [
{
"id": "f05ed8cc-9403-4f14-8085-4364b06f8a48",
"name": "IMAGE",
"type": "IMAGE",
"linkIds": [
1701
],
"localized_name": "IMAGE",
"pos": [
-1732.2451602089645,
3390
]
},
{
"id": "29a6584e-4685-4986-8ffd-e6d8539953fd",
"name": "keypoints",
"type": "POSE_KEYPOINT",
"linkIds": [
1715
],
"pos": [
-1732.2451602089645,
3410
]
}
],
"widgets": [],
"nodes": [
{
"id": 671,
"type": "SDPoseKeypointExtractor",
"pos": [
-2470,
3250
],
"size": [
270,
180
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "model",
"name": "model",
"type": "MODEL",
"link": 1696
},
{
"localized_name": "vae",
"name": "vae",
"type": "VAE",
"link": 1697
},
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": 1698
},
{
"localized_name": "bboxes",
"name": "bboxes",
"shape": 7,
"type": "BOUNDING_BOX",
"link": 1714
},
{
"localized_name": "batch_size",
"name": "batch_size",
"type": "INT",
"widget": {
"name": "batch_size"
},
"link": null
}
],
"outputs": [
{
"localized_name": "keypoints",
"name": "keypoints",
"type": "POSE_KEYPOINT",
"links": [
1699,
1715
]
}
],
"properties": {
"Node name for S&R": "SDPoseKeypointExtractor",
"cnr_id": "comfy-core",
"ver": "0.15.0",
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
16
]
},
{
"id": 674,
"type": "ResizeImageMaskNode",
"pos": [
-2960,
3490
],
"size": [
270,
110
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"localized_name": "input",
"name": "input",
"type": "IMAGE,MASK",
"link": 1700
},
{
"localized_name": "resize_type",
"name": "resize_type",
"type": "COMFY_DYNAMICCOMBO_V3",
"widget": {
"name": "resize_type"
},
"link": null
},
{
"localized_name": "resize_type.longer_size",
"name": "resize_type.longer_size",
"type": "INT",
"widget": {
"name": "resize_type.longer_size"
},
"link": 1704
},
{
"localized_name": "scale_method",
"name": "scale_method",
"type": "COMBO",
"widget": {
"name": "scale_method"
},
"link": 1705
}
],
"outputs": [
{
"localized_name": "resized",
"name": "resized",
"type": "*",
"links": [
1698
]
}
],
"properties": {
"Node name for S&R": "ResizeImageMaskNode",
"cnr_id": "comfy-core",
"ver": "0.15.0",
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
"scale longer dimension",
1024,
"area"
]
},
{
"id": 672,
"type": "SDPoseDrawKeypoints",
"pos": [
-2120,
3260
],
"size": [
270,
280
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"localized_name": "keypoints",
"name": "keypoints",
"type": "POSE_KEYPOINT",
"link": 1699
},
{
"localized_name": "draw_body",
"name": "draw_body",
"type": "BOOLEAN",
"widget": {
"name": "draw_body"
},
"link": 1706
},
{
"localized_name": "draw_hands",
"name": "draw_hands",
"type": "BOOLEAN",
"widget": {
"name": "draw_hands"
},
"link": 1707
},
{
"localized_name": "draw_face",
"name": "draw_face",
"type": "BOOLEAN",
"widget": {
"name": "draw_face"
},
"link": 1708
},
{
"localized_name": "draw_feet",
"name": "draw_feet",
"type": "BOOLEAN",
"widget": {
"name": "draw_feet"
},
"link": 1709
},
{
"localized_name": "stick_width",
"name": "stick_width",
"type": "INT",
"widget": {
"name": "stick_width"
},
"link": 1710
},
{
"localized_name": "face_point_size",
"name": "face_point_size",
"type": "INT",
"widget": {
"name": "face_point_size"
},
"link": 1711
},
{
"localized_name": "score_threshold",
"name": "score_threshold",
"type": "FLOAT",
"widget": {
"name": "score_threshold"
},
"link": 1712
}
],
"outputs": [
{
"localized_name": "IMAGE",
"name": "IMAGE",
"type": "IMAGE",
"links": [
1701
]
}
],
"properties": {
"Node name for S&R": "SDPoseDrawKeypoints",
"cnr_id": "comfy-core",
"ver": "0.15.0",
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
true,
true,
true,
true,
4,
2,
0.5
]
},
{
"id": 673,
"type": "CheckpointLoaderSimple",
"pos": [
-2960,
3250
],
"size": [
390,
190
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"localized_name": "ckpt_name",
"name": "ckpt_name",
"type": "COMBO",
"widget": {
"name": "ckpt_name"
},
"link": 1713
}
],
"outputs": [
{
"localized_name": "MODEL",
"name": "MODEL",
"type": "MODEL",
"links": [
1696
]
},
{
"localized_name": "CLIP",
"name": "CLIP",
"type": "CLIP",
"links": []
},
{
"localized_name": "VAE",
"name": "VAE",
"type": "VAE",
"links": [
1697
]
}
],
"properties": {
"Node name for S&R": "CheckpointLoaderSimple",
"cnr_id": "comfy-core",
"ver": "0.15.0",
"models": [
{
"name": "sdpose_wholebody_fp16.safetensors",
"url": "https://huggingface.co/Comfy-Org/SDPose/resolve/main/checkpoints/sdpose_wholebody_fp16.safetensors",
"directory": "checkpoints"
}
],
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
"sdpose_wholebody_fp16.safetensors"
]
}
],
"groups": [],
"links": [
{
"id": 1696,
"origin_id": 673,
"origin_slot": 0,
"target_id": 671,
"target_slot": 0,
"type": "MODEL"
},
{
"id": 1697,
"origin_id": 673,
"origin_slot": 2,
"target_id": 671,
"target_slot": 1,
"type": "VAE"
},
{
"id": 1698,
"origin_id": 674,
"origin_slot": 0,
"target_id": 671,
"target_slot": 2,
"type": "IMAGE"
},
{
"id": 1699,
"origin_id": 671,
"origin_slot": 0,
"target_id": 672,
"target_slot": 0,
"type": "POSE_KEYPOINT"
},
{
"id": 1700,
"origin_id": -10,
"origin_slot": 0,
"target_id": 674,
"target_slot": 0,
"type": "IMAGE,MASK"
},
{
"id": 1701,
"origin_id": 672,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 1704,
"origin_id": -10,
"origin_slot": 1,
"target_id": 674,
"target_slot": 2,
"type": "INT"
},
{
"id": 1705,
"origin_id": -10,
"origin_slot": 2,
"target_id": 674,
"target_slot": 3,
"type": "COMBO"
},
{
"id": 1706,
"origin_id": -10,
"origin_slot": 3,
"target_id": 672,
"target_slot": 1,
"type": "BOOLEAN"
},
{
"id": 1707,
"origin_id": -10,
"origin_slot": 4,
"target_id": 672,
"target_slot": 2,
"type": "BOOLEAN"
},
{
"id": 1708,
"origin_id": -10,
"origin_slot": 5,
"target_id": 672,
"target_slot": 3,
"type": "BOOLEAN"
},
{
"id": 1709,
"origin_id": -10,
"origin_slot": 6,
"target_id": 672,
"target_slot": 4,
"type": "BOOLEAN"
},
{
"id": 1710,
"origin_id": -10,
"origin_slot": 7,
"target_id": 672,
"target_slot": 5,
"type": "INT"
},
{
"id": 1711,
"origin_id": -10,
"origin_slot": 8,
"target_id": 672,
"target_slot": 6,
"type": "INT"
},
{
"id": 1712,
"origin_id": -10,
"origin_slot": 9,
"target_id": 672,
"target_slot": 7,
"type": "FLOAT"
},
{
"id": 1713,
"origin_id": -10,
"origin_slot": 10,
"target_id": 673,
"target_slot": 0,
"type": "COMBO"
},
{
"id": 1714,
"origin_id": -10,
"origin_slot": 11,
"target_id": 671,
"target_slot": 3,
"type": "BOUNDING_BOX"
},
{
"id": 1715,
"origin_id": 671,
"origin_slot": 0,
"target_id": -20,
"target_slot": 1,
"type": "POSE_KEYPOINT"
}
],
"extra": {
"workflowRendererVersion": "LG"
}
}
]
},
"extra": {
"ue_links": []
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,484 +0,0 @@
{
"revision": 0,
"last_node_id": 10,
"last_link_id": 0,
"nodes": [
{
"id": 10,
"type": "3fb7557a-470d-4983-9d8c-6d5caa9788f0",
"pos": [
-250,
8590
],
"size": [
280,
360
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "text_per_line",
"name": "text_per_line",
"type": "STRING",
"widget": {
"name": "text_per_line"
},
"link": null
},
{
"localized_name": "index",
"name": "index",
"type": "INT",
"widget": {
"name": "index"
},
"link": null
}
],
"outputs": [
{
"localized_name": "selected_line",
"name": "selected_line",
"type": "STRING",
"links": []
}
],
"properties": {
"proxyWidgets": [
[
"2",
"string"
],
[
"3",
"value"
]
],
"cnr_id": "comfy-core",
"ver": "0.19.0",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {}
}
},
"widgets_values": [],
"title": "Select Per-Line Text by Index"
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "3fb7557a-470d-4983-9d8c-6d5caa9788f0",
"version": 1,
"state": {
"lastGroupId": 0,
"lastNodeId": 10,
"lastLinkId": 14,
"lastRerouteId": 0
},
"revision": 0,
"config": {},
"name": "Select Per-Line Text by Index",
"inputNode": {
"id": -10,
"bounding": [
-990,
8595,
128,
88
]
},
"outputNode": {
"id": -20,
"bounding": [
710,
8585,
128,
68
]
},
"inputs": [
{
"id": "75417d82-a934-4ac9-b667-d8dcd5a3bfb3",
"name": "text_per_line",
"type": "STRING",
"linkIds": [
13
],
"localized_name": "text_per_line",
"pos": [
-886,
8619
]
},
{
"id": "46e69a73-1804-4ca6-9175-31445bf0be96",
"name": "index",
"type": "INT",
"linkIds": [
14
],
"localized_name": "index",
"pos": [
-886,
8639
]
}
],
"outputs": [
{
"id": "e34e8ad1-84d2-4bd2-a460-eb7de6067c10",
"name": "selected_line",
"type": "STRING",
"linkIds": [
10
],
"localized_name": "selected_line",
"pos": [
734,
8609
]
}
],
"widgets": [],
"nodes": [
{
"id": 1,
"type": "PreviewAny",
"pos": [
-500,
8400
],
"size": [
230,
180
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "source",
"name": "source",
"type": "*",
"link": 1
}
],
"outputs": [
{
"localized_name": "STRING",
"name": "STRING",
"type": "STRING",
"links": [
6
]
}
],
"properties": {
"Node name for S&R": "PreviewAny",
"cnr_id": "comfy-core",
"ver": "0.19.0",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {}
}
},
"widgets_values": [
null,
null,
null
]
},
{
"id": 2,
"type": "RegexExtract",
"pos": [
-240,
8740
],
"size": [
470,
460
],
"flags": {},
"order": 1,
"mode": 0,
"showAdvanced": false,
"inputs": [
{
"localized_name": "string",
"name": "string",
"type": "STRING",
"widget": {
"name": "string"
},
"link": 13
},
{
"localized_name": "regex_pattern",
"name": "regex_pattern",
"type": "STRING",
"widget": {
"name": "regex_pattern"
},
"link": 9
},
{
"localized_name": "mode",
"name": "mode",
"type": "COMBO",
"widget": {
"name": "mode"
},
"link": null
},
{
"localized_name": "case_insensitive",
"name": "case_insensitive",
"type": "BOOLEAN",
"widget": {
"name": "case_insensitive"
},
"link": null
},
{
"localized_name": "multiline",
"name": "multiline",
"type": "BOOLEAN",
"widget": {
"name": "multiline"
},
"link": null
},
{
"localized_name": "dotall",
"name": "dotall",
"type": "BOOLEAN",
"widget": {
"name": "dotall"
},
"link": null
},
{
"localized_name": "group_index",
"name": "group_index",
"type": "INT",
"widget": {
"name": "group_index"
},
"link": null
}
],
"outputs": [
{
"localized_name": "STRING",
"name": "STRING",
"type": "STRING",
"links": [
10
]
}
],
"properties": {
"Node name for S&R": "RegexExtract",
"cnr_id": "comfy-core",
"ver": "0.19.0",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {}
}
},
"widgets_values": [
"",
"",
"First Group",
false,
false,
false,
1
]
},
{
"id": 3,
"type": "PrimitiveInt",
"pos": [
-810,
8400
],
"size": [
270,
110
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"localized_name": "value",
"name": "value",
"type": "INT",
"widget": {
"name": "value"
},
"link": 14
}
],
"outputs": [
{
"localized_name": "INT",
"name": "INT",
"type": "INT",
"links": [
1
]
}
],
"title": "Int (line index)",
"properties": {
"Node name for S&R": "Int (line index)",
"cnr_id": "comfy-core",
"ver": "0.19.0",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {}
}
},
"widgets_values": [
0,
"fixed"
]
},
{
"id": 8,
"type": "StringReplace",
"pos": [
-240,
8400
],
"size": [
400,
280
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"localized_name": "string",
"name": "string",
"type": "STRING",
"widget": {
"name": "string"
},
"link": null
},
{
"localized_name": "find",
"name": "find",
"type": "STRING",
"widget": {
"name": "find"
},
"link": null
},
{
"localized_name": "replace",
"name": "replace",
"type": "STRING",
"widget": {
"name": "replace"
},
"link": 6
}
],
"outputs": [
{
"localized_name": "STRING",
"name": "STRING",
"type": "STRING",
"links": [
9
]
}
],
"properties": {
"Node name for S&R": "StringReplace",
"cnr_id": "comfy-core",
"ver": "0.19.0",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {}
}
},
"widgets_values": [
"(?:[^\\n]*\\n){index}([^\\n]+)",
"index",
""
]
}
],
"groups": [],
"links": [
{
"id": 1,
"origin_id": 3,
"origin_slot": 0,
"target_id": 1,
"target_slot": 0,
"type": "INT"
},
{
"id": 9,
"origin_id": 8,
"origin_slot": 0,
"target_id": 2,
"target_slot": 1,
"type": "STRING"
},
{
"id": 6,
"origin_id": 1,
"origin_slot": 0,
"target_id": 8,
"target_slot": 2,
"type": "STRING"
},
{
"id": 10,
"origin_id": 2,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "STRING"
},
{
"id": 13,
"origin_id": -10,
"origin_slot": 0,
"target_id": 2,
"target_slot": 0,
"type": "STRING"
},
{
"id": 14,
"origin_id": -10,
"origin_slot": 1,
"target_id": 3,
"target_slot": 0,
"type": "INT"
}
],
"extra": {},
"category": "Text Tools"
}
]
},
"extra": {
"ue_links": [],
"links_added_by_ue": []
}
}

View File

@ -1,713 +0,0 @@
{
"revision": 0,
"last_node_id": 251,
"last_link_id": 0,
"nodes": [
{
"id": 251,
"type": "609e1fd1-b731-4b78-89ac-d19b1156b025",
"pos": [
-1490,
130
],
"size": [
230,
164
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"localized_name": "source_image",
"name": "source_image",
"type": "IMAGE",
"link": null
},
{
"localized_name": "columns",
"name": "columns",
"type": "INT",
"widget": {
"name": "columns"
},
"link": null
},
{
"localized_name": "rows",
"name": "rows",
"type": "INT",
"widget": {
"name": "rows"
},
"link": null
}
],
"outputs": [
{
"localized_name": "tiles",
"name": "tiles",
"type": "IMAGE",
"links": []
}
],
"properties": {
"proxyWidgets": [
[
"228",
"value"
],
[
"252",
"value"
]
],
"cnr_id": "comfy-core",
"ver": "0.20.1",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": [],
"title": "Split Image Grid to Tiles"
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "609e1fd1-b731-4b78-89ac-d19b1156b025",
"version": 1,
"state": {
"lastGroupId": 9,
"lastNodeId": 252,
"lastLinkId": 429,
"lastRerouteId": 0
},
"revision": 0,
"config": {},
"name": "Split Image Grid to Tiles",
"inputNode": {
"id": -10,
"bounding": [
-1690,
260,
128,
108
]
},
"outputNode": {
"id": -20,
"bounding": [
-510,
590,
128,
68
]
},
"inputs": [
{
"id": "866ac798-cfbc-450a-b755-e704f86404d9",
"name": "source_image",
"type": "IMAGE",
"linkIds": [
386,
389
],
"localized_name": "source_image",
"pos": [
-1586,
284
]
},
{
"id": "bc37b1f8-8ab2-4f19-bd00-75d4fbc4feb3",
"name": "columns",
"type": "INT",
"linkIds": [
427
],
"localized_name": "columns",
"pos": [
-1586,
304
]
},
{
"id": "d45915da-e848-43dd-9ccc-e3161e9c99d9",
"name": "rows",
"type": "INT",
"linkIds": [
428
],
"localized_name": "rows",
"pos": [
-1586,
324
]
}
],
"outputs": [
{
"id": "18bc780f-064b-4038-87c6-67dba71deb08",
"name": "tiles",
"type": "IMAGE",
"linkIds": [
394
],
"localized_name": "tiles",
"shape": 6,
"pos": [
-486,
614
]
}
],
"widgets": [],
"nodes": [
{
"id": 225,
"type": "SplitImageToTileList",
"pos": [
-1010,
620
],
"size": [
290,
170
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": 386
},
{
"localized_name": "tile_width",
"name": "tile_width",
"type": "INT",
"widget": {
"name": "tile_width"
},
"link": 403
},
{
"localized_name": "tile_height",
"name": "tile_height",
"type": "INT",
"widget": {
"name": "tile_height"
},
"link": 404
},
{
"localized_name": "overlap",
"name": "overlap",
"type": "INT",
"widget": {
"name": "overlap"
},
"link": null
}
],
"outputs": [
{
"localized_name": "IMAGE",
"name": "IMAGE",
"shape": 6,
"type": "IMAGE",
"links": [
394
]
}
],
"properties": {
"Node name for S&R": "SplitImageToTileList",
"cnr_id": "comfy-core",
"ver": "0.20.1",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": [
1024,
1024,
0
]
},
{
"id": 231,
"type": "ComfyMathExpression",
"pos": [
-1080,
330
],
"size": [
370,
190
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [
{
"label": "a",
"localized_name": "values.a",
"name": "values.a",
"type": "FLOAT,INT,BOOLEAN",
"link": 390
},
{
"label": "b",
"localized_name": "values.b",
"name": "values.b",
"shape": 7,
"type": "FLOAT,INT,BOOLEAN",
"link": 429
},
{
"label": "c",
"localized_name": "values.c",
"name": "values.c",
"shape": 7,
"type": "FLOAT,INT,BOOLEAN",
"link": null
},
{
"localized_name": "expression",
"name": "expression",
"type": "STRING",
"widget": {
"name": "expression"
},
"link": null
}
],
"outputs": [
{
"localized_name": "FLOAT",
"name": "FLOAT",
"type": "FLOAT",
"links": null
},
{
"localized_name": "INT",
"name": "INT",
"type": "INT",
"links": [
404
]
},
{
"localized_name": "BOOL",
"name": "BOOL",
"type": "BOOLEAN",
"links": null
}
],
"title": "Math Expression Height",
"properties": {
"Node name for S&R": "ComfyMathExpression",
"cnr_id": "comfy-core",
"ver": "0.18.1",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
"max(1, (int(a) + int(b) - 1) // int(b))"
]
},
{
"id": 229,
"type": "ComfyMathExpression",
"pos": [
-1090,
-30
],
"size": [
370,
190
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"label": "a",
"localized_name": "values.a",
"name": "values.a",
"type": "FLOAT,INT,BOOLEAN",
"link": 387
},
{
"label": "b",
"localized_name": "values.b",
"name": "values.b",
"shape": 7,
"type": "FLOAT,INT,BOOLEAN",
"link": 388
},
{
"label": "c",
"localized_name": "values.c",
"name": "values.c",
"shape": 7,
"type": "FLOAT,INT,BOOLEAN",
"link": null
},
{
"localized_name": "expression",
"name": "expression",
"type": "STRING",
"widget": {
"name": "expression"
},
"link": null
}
],
"outputs": [
{
"localized_name": "FLOAT",
"name": "FLOAT",
"type": "FLOAT",
"links": null
},
{
"localized_name": "INT",
"name": "INT",
"type": "INT",
"links": [
403
]
},
{
"localized_name": "BOOL",
"name": "BOOL",
"type": "BOOLEAN",
"links": null
}
],
"title": "Math Expression Width",
"properties": {
"Node name for S&R": "ComfyMathExpression",
"cnr_id": "comfy-core",
"ver": "0.18.1",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
"max(1, (int(a) + int(b) - 1) // int(b))"
]
},
{
"id": 228,
"type": "PrimitiveInt",
"pos": [
-1380,
90
],
"size": [
230,
110
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"localized_name": "value",
"name": "value",
"type": "INT",
"widget": {
"name": "value"
},
"link": 427
}
],
"outputs": [
{
"localized_name": "INT",
"name": "INT",
"type": "INT",
"links": [
388
]
}
],
"title": "Int (grid columns)",
"properties": {
"Node name for S&R": "Int (grid columns)",
"cnr_id": "comfy-core",
"ver": "0.18.1",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
2,
"fixed"
]
},
{
"id": 230,
"type": "GetImageSize",
"pos": [
-1380,
290
],
"size": [
230,
100
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": 389
}
],
"outputs": [
{
"localized_name": "width",
"name": "width",
"type": "INT",
"links": [
387
]
},
{
"localized_name": "height",
"name": "height",
"type": "INT",
"links": [
390
]
},
{
"localized_name": "batch_size",
"name": "batch_size",
"type": "INT",
"links": null
}
],
"properties": {
"Node name for S&R": "GetImageSize",
"cnr_id": "comfy-core",
"ver": "0.18.1",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
}
},
{
"id": 252,
"type": "PrimitiveInt",
"pos": [
-1380,
470
],
"size": [
230,
110
],
"flags": {},
"order": 5,
"mode": 0,
"inputs": [
{
"localized_name": "value",
"name": "value",
"type": "INT",
"widget": {
"name": "value"
},
"link": 428
}
],
"outputs": [
{
"localized_name": "INT",
"name": "INT",
"type": "INT",
"links": [
429
]
}
],
"title": "Int (grid rows)",
"properties": {
"Node name for S&R": "Int (grid rows)",
"cnr_id": "comfy-core",
"ver": "0.18.1",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
3,
"fixed"
]
}
],
"groups": [],
"links": [
{
"id": 403,
"origin_id": 229,
"origin_slot": 1,
"target_id": 225,
"target_slot": 1,
"type": "INT"
},
{
"id": 404,
"origin_id": 231,
"origin_slot": 1,
"target_id": 225,
"target_slot": 2,
"type": "INT"
},
{
"id": 390,
"origin_id": 230,
"origin_slot": 1,
"target_id": 231,
"target_slot": 0,
"type": "INT"
},
{
"id": 387,
"origin_id": 230,
"origin_slot": 0,
"target_id": 229,
"target_slot": 0,
"type": "INT"
},
{
"id": 388,
"origin_id": 228,
"origin_slot": 0,
"target_id": 229,
"target_slot": 1,
"type": "INT"
},
{
"id": 386,
"origin_id": -10,
"origin_slot": 0,
"target_id": 225,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 389,
"origin_id": -10,
"origin_slot": 0,
"target_id": 230,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 394,
"origin_id": 225,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 427,
"origin_id": -10,
"origin_slot": 1,
"target_id": 228,
"target_slot": 0,
"type": "INT"
},
{
"id": 428,
"origin_id": -10,
"origin_slot": 2,
"target_id": 252,
"target_slot": 0,
"type": "INT"
},
{
"id": 429,
"origin_id": 252,
"origin_slot": 0,
"target_id": 231,
"target_slot": 1,
"type": "INT"
}
],
"extra": {},
"category": "Image Tools/Crop"
}
]
},
"extra": {}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -152,6 +152,11 @@ class StableAudio1(LatentFormat):
latent_dimensions = 1
temporal_downscale_ratio = 2048
class StableAudio3(LatentFormat):
latent_channels = 256
latent_dimensions = 1
temporal_downscale_ratio = 4096
class Flux(SD3):
latent_channels = 16
def __init__(self):

View File

@ -10,6 +10,17 @@ from torch import nn
from torch.nn import functional as F
import math
import comfy.ops
from .embedders import ExpoFourierFeatures
def _left_pad_to_match(emb, target_len):
emb_len = emb.shape[-2]
if emb_len < target_len:
return F.pad(emb, (0, 0, target_len - emb_len, 0), value=0.)
elif emb_len > target_len:
return emb[:, -target_len:, :]
return emb
class FourierFeatures(nn.Module):
def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
@ -22,6 +33,7 @@ class FourierFeatures(nn.Module):
f = 2 * math.pi * input @ comfy.ops.cast_to_input(self.weight.T, input)
return torch.cat([f.cos(), f.sin()], dim=-1)
# norms
class LayerNorm(nn.Module):
def __init__(self, dim, bias=False, fix_scale=False, dtype=None, device=None):
@ -43,6 +55,16 @@ class LayerNorm(nn.Module):
beta = comfy.ops.cast_to_input(beta, x)
return F.layer_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x), bias=beta)
class RMSNorm(nn.Module):
def __init__(self, dim, dtype=None, device=None):
super().__init__()
self.gamma = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
def forward(self, x):
return F.rms_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x))
class GLU(nn.Module):
def __init__(
self,
@ -236,13 +258,6 @@ class FeedForward(nn.Module):
linear_out = operations.Linear(inner_dim, dim_out, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device)
# # init last linear layer to 0
# if zero_init_output:
# nn.init.zeros_(linear_out.weight)
# if not no_bias:
# nn.init.zeros_(linear_out.bias)
self.ff = nn.Sequential(
linear_in,
rearrange('b d n -> b n d') if use_conv else nn.Identity(),
@ -261,8 +276,10 @@ class Attention(nn.Module):
dim_context = None,
causal = False,
zero_init_output=True,
qk_norm = False,
qk_norm = "none",
differential = False,
natten_kernel_size = None,
feat_scale = False,
dtype=None,
device=None,
operations=None,
@ -271,6 +288,7 @@ class Attention(nn.Module):
self.dim = dim
self.dim_heads = dim_heads
self.causal = causal
self.differential = differential
dim_kv = dim_context if dim_context is not None else dim
@ -278,18 +296,37 @@ class Attention(nn.Module):
self.kv_heads = dim_kv // dim_heads
if dim_context is not None:
self.to_q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.to_kv = operations.Linear(dim_kv, dim_kv * 2, bias=False, dtype=dtype, device=device)
if differential:
self.to_q = operations.Linear(dim, dim * 2, bias=False, dtype=dtype, device=device)
self.to_kv = operations.Linear(dim_kv, dim_kv * 3, bias=False, dtype=dtype, device=device)
else:
self.to_q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.to_kv = operations.Linear(dim_kv, dim_kv * 2, bias=False, dtype=dtype, device=device)
else:
self.to_qkv = operations.Linear(dim, dim * 3, bias=False, dtype=dtype, device=device)
if differential:
self.to_qkv = operations.Linear(dim, dim * 5, bias=False, dtype=dtype, device=device)
else:
self.to_qkv = operations.Linear(dim, dim * 3, bias=False, dtype=dtype, device=device)
self.to_out = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
# if zero_init_output:
# nn.init.zeros_(self.to_out.weight)
# Accept bool for backward compat
if isinstance(qk_norm, bool):
qk_norm = "l2" if qk_norm else "none"
self.qk_norm = qk_norm
if self.qk_norm == "ln":
self.q_norm = operations.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
self.k_norm = operations.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
elif self.qk_norm == "rms":
self.q_norm = RMSNorm(dim_heads, dtype=dtype, device=device)
self.k_norm = RMSNorm(dim_heads, dtype=dtype, device=device)
self.feat_scale = feat_scale
if self.feat_scale:
self.lambda_dc = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
self.lambda_hf = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
def forward(
self,
@ -306,22 +343,51 @@ class Attention(nn.Module):
kv_input = context if has_context else x
if hasattr(self, 'to_q'):
# Use separate linear projections for q and k/v
q = self.to_q(x)
q = rearrange(q, 'b n (h d) -> b h n d', h = h)
if self.differential:
# cross-attention differential: to_q → (q, q_diff), to_kv → (k, k_diff, v)
q, q_diff = self.to_q(x).chunk(2, dim=-1)
q = rearrange(q, 'b n (h d) -> b h n d', h=h)
q_diff = rearrange(q_diff, 'b n (h d) -> b h n d', h=h)
q = torch.stack([q, q_diff], dim=1) # (B, 2, H, N, D)
k, k_diff, v = self.to_kv(kv_input).chunk(3, dim=-1)
k = rearrange(k, 'b n (h d) -> b h n d', h=kv_h)
k_diff = rearrange(k_diff, 'b n (h d) -> b h n d', h=kv_h)
v = rearrange(v, 'b n (h d) -> b h n d', h=kv_h)
k = torch.stack([k, k_diff], dim=1) # (B, 2, H, M, D)
else:
# Use separate linear projections for q and k/v
q = self.to_q(x)
q = rearrange(q, 'b n (h d) -> b h n d', h = h)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
else:
# Use fused linear projection
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
if self.differential:
# self-attention differential: to_qkv → (q, k, v, q_diff, k_diff)
q, k, v, q_diff, k_diff = self.to_qkv(x).chunk(5, dim=-1)
q, k, v, q_diff, k_diff = map(
lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h),
(q, k, v, q_diff, k_diff)
)
q = torch.stack([q, q_diff], dim=1) # (B, 2, H, N, D)
k = torch.stack([k, k_diff], dim=1)
else:
# Use fused linear projection
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
# Normalize q and k for cosine sim attention
if self.qk_norm:
if self.qk_norm == "l2":
q = F.normalize(q, dim=-1)
k = F.normalize(k, dim=-1)
elif self.qk_norm == "rms":
q_type, k_type = q.dtype, k.dtype
q = self.q_norm(q).to(q_type)
k = self.k_norm(k).to(k_type)
elif self.qk_norm != 'none':
q = self.q_norm(q)
k = self.k_norm(k)
if rotary_pos_emb is not None and not has_context:
freqs, _ = rotary_pos_emb
@ -364,9 +430,24 @@ class Attention(nn.Module):
heads_per_kv_head = h // kv_h
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
if self.differential:
q, q_diff = q.unbind(dim=1)
k, k_diff = k.unbind(dim=1)
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, transformer_options=transformer_options)
out = out - out_diff
else:
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
out = self.to_out(out)
if self.feat_scale:
out_dc = out.mean(dim=-2, keepdim=True)
out_hf = out - out_dc
# Selectively modulate DC and high frequency components
out = out + comfy.ops.cast_to_input(self.lambda_dc, out) * out_dc + comfy.ops.cast_to_input(self.lambda_hf, out) * out_hf
if mask is not None:
mask = rearrange(mask, 'b n -> b n 1')
out = out.masked_fill(~mask, 0.)
@ -417,11 +498,14 @@ class TransformerBlock(nn.Module):
cross_attend = False,
dim_context = None,
global_cond_dim = None,
global_cond_shared_embed = False,
local_add_cond_dim = None,
causal = False,
zero_init_branch_outputs = True,
conformer = False,
layer_ix = -1,
remove_norms = False,
norm_type = "layer_norm",
attn_kwargs = {},
ff_kwargs = {},
norm_kwargs = {},
@ -436,8 +520,20 @@ class TransformerBlock(nn.Module):
self.cross_attend = cross_attend
self.dim_context = dim_context
self.causal = causal
self.global_cond_shared_embed = global_cond_shared_embed
self.pre_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
norm_layer_map = {
"layer_norm": LayerNorm,
"rms_norm": RMSNorm,
}
norm_cls = norm_layer_map.get(norm_type, LayerNorm)
def make_norm():
if remove_norms:
return nn.Identity()
return norm_cls(dim, dtype=dtype, device=device, **norm_kwargs)
self.pre_norm = make_norm()
self.self_attn = Attention(
dim,
@ -451,7 +547,7 @@ class TransformerBlock(nn.Module):
)
if cross_attend:
self.cross_attend_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
self.cross_attend_norm = make_norm()
self.cross_attn = Attention(
dim,
dim_heads = dim_heads,
@ -464,37 +560,56 @@ class TransformerBlock(nn.Module):
**attn_kwargs
)
self.ff_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, dtype=dtype, device=device, operations=operations,**ff_kwargs)
self.ff_norm = make_norm()
self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, dtype=dtype, device=device, operations=operations, **ff_kwargs)
self.layer_ix = layer_ix
self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None
self.global_cond_dim = global_cond_dim
# Global conditioning
self.has_global_cond = (global_cond_dim is not None) or global_cond_shared_embed
if global_cond_dim is not None:
if global_cond_shared_embed:
# SA3 style: learnable per-block additive bias; global_cond is pre-projected to (B, dim*6)
self.to_scale_shift_gate = nn.Parameter(torch.empty(dim * 6, device=device, dtype=dtype))
elif global_cond_dim is not None:
# SA1 style: per-block MLP projects global_cond → (B, dim*6)
self.to_scale_shift_gate = nn.Sequential(
nn.SiLU(),
nn.Linear(global_cond_dim, dim * 6, bias=False)
operations.Linear(global_cond_dim, dim * 6, bias=False, device=device, dtype=dtype)
)
nn.init.zeros_(self.to_scale_shift_gate[1].weight)
#nn.init.zeros_(self.to_scale_shift_gate_self[1].bias)
# Local additive conditioning (e.g. inpaint mask + masked latent)
self.local_add_cond_dim = local_add_cond_dim
if local_add_cond_dim is not None:
self.to_local_embed = nn.Sequential(
operations.Linear(local_add_cond_dim, dim, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(dim, dim, bias=True, dtype=dtype, device=device),
)
else:
self.to_local_embed = None
def forward(
self,
x,
context = None,
global_cond=None,
local_add_cond=None,
mask = None,
context_mask = None,
rotary_pos_emb = None,
transformer_options={}
):
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
if self.has_global_cond and global_cond is not None:
if self.global_cond_shared_embed:
# global_cond already has shape (B, dim*6)
ssg = (comfy.ops.cast_to_input(self.to_scale_shift_gate, global_cond) + global_cond).unsqueeze(1)
else:
ssg = self.to_scale_shift_gate(global_cond).unsqueeze(1)
scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1)
scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = ssg.chunk(6, dim = -1)
# self-attention with adaLN
residual = x
@ -510,6 +625,9 @@ class TransformerBlock(nn.Module):
if self.conformer is not None:
x = x + self.conformer(x)
if local_add_cond is not None and self.to_local_embed is not None:
x = x + _left_pad_to_match(self.to_local_embed(local_add_cond), x.shape[-2])
# feedforward with adaLN
residual = x
x = self.ff_norm(x)
@ -527,6 +645,9 @@ class TransformerBlock(nn.Module):
if self.conformer is not None:
x = x + self.conformer(x)
if local_add_cond is not None and self.to_local_embed is not None:
x = x + _left_pad_to_match(self.to_local_embed(local_add_cond), x.shape[-2])
x = x + self.ff(self.ff_norm(x))
return x
@ -543,6 +664,8 @@ class ContinuousTransformer(nn.Module):
cross_attend=False,
cond_token_dim=None,
global_cond_dim=None,
global_cond_shared_embed=False,
local_add_cond_dim=None,
causal=False,
rotary_pos_emb=True,
zero_init_branch_outputs=True,
@ -550,6 +673,7 @@ class ContinuousTransformer(nn.Module):
use_sinusoidal_emb=False,
use_abs_pos_emb=False,
abs_pos_emb_max_length=10000,
num_memory_tokens=0,
dtype=None,
device=None,
operations=None,
@ -562,6 +686,8 @@ class ContinuousTransformer(nn.Module):
self.depth = depth
self.causal = causal
self.layers = nn.ModuleList([])
self.num_memory_tokens = num_memory_tokens
self.global_cond_shared_embed = global_cond_shared_embed
self.project_in = operations.Linear(dim_in, dim, bias=False, dtype=dtype, device=device) if dim_in is not None else nn.Identity()
self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()
@ -577,7 +703,22 @@ class ContinuousTransformer(nn.Module):
self.use_abs_pos_emb = use_abs_pos_emb
if use_abs_pos_emb:
self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)
self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length + num_memory_tokens)
if num_memory_tokens > 0:
self.memory_tokens = nn.Parameter(torch.empty(num_memory_tokens, dim, device=device, dtype=dtype))
# Shared global-cond embedder (SA3 style): projects (B, global_cond_dim) → (B, dim*6)
self.global_cond_embedder = None
if global_cond_shared_embed and global_cond_dim is not None:
self.global_cond_embedder = nn.Sequential(
operations.Linear(global_cond_dim, dim, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(dim, dim * 6, bias=True, dtype=dtype, device=device),
)
# When using shared embed, TransformerBlocks use per-block Parameter (not per-block MLP)
block_global_cond_dim = None if global_cond_shared_embed else global_cond_dim
for i in range(depth):
self.layers.append(
@ -586,7 +727,9 @@ class ContinuousTransformer(nn.Module):
dim_heads = dim_heads,
cross_attend = cross_attend,
dim_context = cond_token_dim,
global_cond_dim = global_cond_dim,
global_cond_dim = block_global_cond_dim,
global_cond_shared_embed = global_cond_shared_embed,
local_add_cond_dim = local_add_cond_dim,
causal = causal,
zero_init_branch_outputs = zero_init_branch_outputs,
conformer=conformer,
@ -605,6 +748,7 @@ class ContinuousTransformer(nn.Module):
prepend_embeds = None,
prepend_mask = None,
global_cond = None,
local_add_cond = None,
return_info = False,
**kwargs
):
@ -632,7 +776,9 @@ class ContinuousTransformer(nn.Module):
mask = torch.cat((prepend_mask, mask), dim = -1)
# Attention layers
if self.num_memory_tokens > 0:
memory_tokens = comfy.ops.cast_to_input(self.memory_tokens, x).expand(batch, -1, -1)
x = torch.cat((memory_tokens, x), dim=1)
if self.rotary_pos_emb is not None:
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=torch.float, device=x.device)
@ -642,6 +788,10 @@ class ContinuousTransformer(nn.Module):
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
x = x + self.pos_emb(x)
# Project global_cond once (SA3 shared-embed path)
if global_cond is not None and self.global_cond_embedder is not None:
global_cond = self.global_cond_embedder(global_cond)
blocks_replace = patches_replace.get("dit", {})
# Iterate over the transformer layers
for i, layer in enumerate(self.layers):
@ -654,12 +804,17 @@ class ContinuousTransformer(nn.Module):
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options)
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
x = layer(x, rotary_pos_emb=rotary_pos_emb, global_cond=global_cond,
local_add_cond=local_add_cond, context=context,
transformer_options=transformer_options)
if return_info:
info["hidden_states"].append(x)
# Strip memory tokens before projecting out
if self.num_memory_tokens > 0:
x = x[:, self.num_memory_tokens:, :]
x = self.project_out(x)
if return_info:
@ -682,6 +837,7 @@ class AudioDiffusionTransformer(nn.Module):
num_heads=24,
transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
timestep_features_type: str = "learned",
audio_model="",
dtype=None,
device=None,
@ -696,7 +852,10 @@ class AudioDiffusionTransformer(nn.Module):
# Timestep embeddings
timestep_features_dim = 256
self.timestep_features = FourierFeatures(1, timestep_features_dim, dtype=dtype, device=device)
if timestep_features_type == "expo":
self.timestep_features = ExpoFourierFeatures(timestep_features_dim, 0.5, 10000.0)
else:
self.timestep_features = FourierFeatures(1, timestep_features_dim, dtype=dtype, device=device)
self.to_timestep_embed = nn.Sequential(
operations.Linear(timestep_features_dim, embed_dim, bias=True, dtype=dtype, device=device),
@ -781,6 +940,7 @@ class AudioDiffusionTransformer(nn.Module):
cross_attn_cond=None,
cross_attn_cond_mask=None,
input_concat_cond=None,
local_add_cond=None,
global_embed=None,
prepend_cond=None,
prepend_cond_mask=None,
@ -802,9 +962,13 @@ class AudioDiffusionTransformer(nn.Module):
prepend_cond = self.to_prepend_embed(prepend_cond)
prepend_inputs = prepend_cond
prepend_length = prepend_cond.shape[1]
if prepend_cond_mask is not None:
prepend_mask = prepend_cond_mask
if local_add_cond is not None and local_add_cond.dim() == 3:
local_add_cond = local_add_cond.permute(0, 2, 1)
if input_concat_cond is not None:
# Interpolate input_concat_cond to the same length as x
@ -850,7 +1014,7 @@ class AudioDiffusionTransformer(nn.Module):
if self.transformer_type == "x-transformers":
output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs)
elif self.transformer_type == "continuous_transformer":
output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)
output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, local_add_cond=local_add_cond, **extra_args, **kwargs)
if return_info:
output, info = output
@ -876,6 +1040,7 @@ class AudioDiffusionTransformer(nn.Module):
context=None,
context_mask=None,
input_concat_cond=None,
local_add_cond=None,
global_embed=None,
negative_global_embed=None,
prepend_cond=None,
@ -890,6 +1055,7 @@ class AudioDiffusionTransformer(nn.Module):
cross_attn_cond=context,
cross_attn_cond_mask=context_mask,
input_concat_cond=input_concat_cond,
local_add_cond=local_add_cond,
global_embed=global_embed,
prepend_cond=prepend_cond,
prepend_cond_mask=prepend_cond_mask,

View File

@ -31,15 +31,39 @@ def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
)
class ExpoFourierFeatures(nn.Module):
"""Exponentially-spaced Fourier features (no learnable parameters)."""
def __init__(self, dim, min_freq=0.5, max_freq=10000.0):
super().__init__()
self.dim = dim
self.min_freq = min_freq
self.max_freq = max_freq
def forward(self, t):
in_dtype = t.dtype
t = t.float()
if t.dim() == 1:
t = t.unsqueeze(-1)
half_dim = self.dim // 2
ramp = torch.linspace(0, 1, half_dim, device=t.device, dtype=torch.float32)
freqs = torch.exp(ramp * (math.log(self.max_freq) - math.log(self.min_freq)) + math.log(self.min_freq))
args = t * freqs * 2 * math.pi
return torch.cat([args.cos(), args.sin()], dim=-1).to(in_dtype)
class NumberEmbedder(nn.Module):
def __init__(
self,
features: int,
dim: int = 256,
fourier_features_type="learned",
):
super().__init__()
self.features = features
self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
if fourier_features_type == "expo":
self.embedding = nn.Sequential(ExpoFourierFeatures(dim=dim), comfy.ops.manual_cast.Linear(in_features=dim, out_features=features))
else:
self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
def forward(self, x: Union[List[float], Tensor]) -> Tensor:
if not torch.is_tensor(x):
@ -77,14 +101,15 @@ class NumberConditioner(Conditioner):
def __init__(self,
output_dim: int,
min_val: float=0,
max_val: float=1
max_val: float=1,
fourier_features_type: str = "learned",
):
super().__init__(output_dim, output_dim)
self.min_val = min_val
self.max_val = max_val
self.embedder = NumberEmbedder(features=output_dim)
self.embedder = NumberEmbedder(features=output_dim, fourier_features_type=fourier_features_type)
def forward(self, floats, device=None):
# Cast the inputs to floats

533
comfy/ldm/audio/vae_sa3.py Normal file
View File

@ -0,0 +1,533 @@
import torch
import torch.nn as nn
import comfy.ops
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.audio.autoencoder import WNConv1d
ops = comfy.ops.disable_weight_init
class Transpose(nn.Module):
def forward(self, x, **kwargs):
return x.transpose(-2, -1)
def _zero_pad_modulo_sequence(x, size, dim=-2):
input_len = x.shape[dim]
pad_len = (size - input_len % size) % size
if pad_len > 0:
pad_shape = list(x.shape)
pad_shape[dim] = pad_len
x = torch.cat([x, torch.zeros(pad_shape, device=x.device, dtype=x.dtype)], dim=dim)
return x
def _sliding_window_mask(seq_len, window, device, dtype):
"""Additive attention mask enforcing a ±window local window (matches flash_attn window_size)."""
i = torch.arange(seq_len, device=device).unsqueeze(1)
j = torch.arange(seq_len, device=device).unsqueeze(0)
out_of_window = (j - i).abs() > window
return torch.where(
out_of_window,
torch.full((1,), torch.finfo(dtype).min / 4, device=device, dtype=dtype),
torch.zeros(1, device=device, dtype=dtype),
)
class DynamicTanh(nn.Module):
def __init__(self, dim, init_alpha=4.0, dtype=None, device=None, **kwargs):
super().__init__()
self.alpha = nn.Parameter(torch.empty(1, dtype=dtype, device=device))
self.gamma = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
self.beta = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
def forward(self, x):
alpha = comfy.ops.cast_to_input(self.alpha, x)
gamma = comfy.ops.cast_to_input(self.gamma, x)
beta = comfy.ops.cast_to_input(self.beta, x)
return gamma * torch.tanh(alpha * x) + beta
class RotaryEmbedding(nn.Module):
def __init__(self, dim, base=10000, base_rescale_factor=1., dtype=None, device=None):
super().__init__()
base = base * base_rescale_factor ** (dim / (dim - 2))
self.register_buffer("inv_freq", torch.empty(dim // 2, dtype=dtype, device=device))
def forward_from_seq_len(self, seq_len, device, dtype=None):
t = torch.arange(seq_len, device=device, dtype=torch.float32)
return self.forward(t)
def forward(self, t):
freqs = torch.outer(t.float(), comfy.model_management.cast_to(self.inv_freq, dtype=torch.float32, device=t.device))
freqs = torch.cat((freqs, freqs), dim=-1)
return freqs, 1.
def _rotate_half(x):
d = x.shape[-1] // 2
return torch.cat((-x[..., d:], x[..., :d]), dim=-1)
def _apply_rotary_pos_emb(t, freqs):
out_dtype = t.dtype
rot_dim = freqs.shape[-1]
seq_len = t.shape[-2]
freqs = freqs[-seq_len:]
t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:]
t_rot = t_rot * freqs.cos() + _rotate_half(t_rot) * freqs.sin()
return torch.cat((t_rot.to(out_dtype), t_pass.to(out_dtype)), dim=-1)
class Attention(nn.Module):
def __init__(self, dim, dim_heads=64, qk_norm="none", qk_norm_eps=1e-6,
differential=False, zero_init_output=True,
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.num_heads = dim // dim_heads
self.differential = differential
self.qk_norm = qk_norm
self.to_qkv = operations.Linear(
dim, dim * (5 if differential else 3), bias=False, dtype=dtype, device=device)
self.to_out = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
if qk_norm == "dyt":
self.q_norm = DynamicTanh(dim_heads, dtype=dtype, device=device)
self.k_norm = DynamicTanh(dim_heads, dtype=dtype, device=device)
elif qk_norm == "rms":
self.q_norm = operations.RMSNorm(dim_heads, eps=qk_norm_eps, dtype=dtype, device=device)
self.k_norm = operations.RMSNorm(dim_heads, eps=qk_norm_eps, dtype=dtype, device=device)
def forward(self, x, rotary_pos_emb=None, mask=None, **kwargs):
B, N, _ = x.shape
h = self.num_heads
qkv = self.to_qkv(x)
if self.differential:
q, k, v, q_diff, k_diff = qkv.chunk(5, dim=-1)
del qkv
q = q.view(B, N, h, -1).transpose(1, 2)
k = k.view(B, N, h, -1).transpose(1, 2)
v = v.view(B, N, h, -1).transpose(1, 2)
q_diff = q_diff.view(B, N, h, -1).transpose(1, 2)
k_diff = k_diff.view(B, N, h, -1).transpose(1, 2)
else:
q, k, v = qkv.chunk(3, dim=-1)
del qkv
q = q.view(B, N, h, -1).transpose(1, 2)
k = k.view(B, N, h, -1).transpose(1, 2)
v = v.view(B, N, h, -1).transpose(1, 2)
if self.qk_norm != "none":
q_dtype, k_dtype = q.dtype, k.dtype
q = self.q_norm(q).to(q_dtype)
k = self.k_norm(k).to(k_dtype)
if self.differential:
q_diff = self.q_norm(q_diff).to(q_dtype)
k_diff = self.k_norm(k_diff).to(k_dtype)
if rotary_pos_emb is not None:
freqs, _ = rotary_pos_emb
q_dtype, k_dtype = q.dtype, k.dtype
q = _apply_rotary_pos_emb(q.float(), freqs).to(q_dtype)
k = _apply_rotary_pos_emb(k.float(), freqs).to(k_dtype)
if self.differential:
q_diff = _apply_rotary_pos_emb(q_diff.float(), freqs).to(q_dtype)
k_diff = _apply_rotary_pos_emb(k_diff.float(), freqs).to(k_dtype)
if self.differential:
out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True)
- optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True))
del q, k, v, q_diff, k_diff
else:
out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True)
del q, k, v
return self.to_out(out)
class _Sin(nn.Module):
def forward(self, x):
return torch.sin(3.14159265359 * x)
class _GLU(nn.Module):
def __init__(self, dim_in, dim_out, activation, dtype=None, device=None, operations=None):
super().__init__()
self.act = activation
self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
def forward(self, x):
x = self.proj(x)
x, gate = x.chunk(2, dim=-1)
return x * self.act(gate)
class FeedForward(nn.Module):
def __init__(self, dim, mult=4, no_bias=False, zero_init_output=True,
sinusoidal=False, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
inner_dim = int(dim * mult)
act = _Sin() if sinusoidal else nn.SiLU()
self.ff = nn.Sequential(
_GLU(dim, inner_dim, act, dtype=dtype, device=device, operations=operations),
nn.Identity(),
operations.Linear(inner_dim, dim, bias=not no_bias, dtype=dtype, device=device),
nn.Identity(),
)
def forward(self, x, **kwargs):
return self.ff(x)
class TransformerBlock(nn.Module):
def __init__(self, dim, dim_heads=64, causal=False, zero_init_branch_outputs=True,
norm_type="dyt", add_rope=False, attn_kwargs=None, ff_kwargs=None,
norm_kwargs=None, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
if attn_kwargs is None:
attn_kwargs = {}
if ff_kwargs is None:
ff_kwargs = {}
if norm_kwargs is None:
norm_kwargs = {}
dim_heads = min(dim_heads, dim)
Norm = DynamicTanh if norm_type == "dyt" else operations.RMSNorm
norm_kw = {**norm_kwargs, "dtype": dtype, "device": device}
self.pre_norm = Norm(dim, **norm_kw)
self.self_attn = Attention(dim, dim_heads=dim_heads,
zero_init_output=zero_init_branch_outputs,
dtype=dtype, device=device, operations=operations,
**attn_kwargs)
self.ff_norm = Norm(dim, **norm_kw)
self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs,
dtype=dtype, device=device, operations=operations, **ff_kwargs)
self.rope = RotaryEmbedding(dim_heads // 2, dtype=dtype, device=device) if add_rope else None
def forward(self, x, mask=None, **kwargs):
rope = self.rope.forward_from_seq_len(x.shape[-2], device=x.device) \
if self.rope is not None else None
x = x + self.self_attn(self.pre_norm(x), rotary_pos_emb=rope, mask=mask)
x = x + self.ff(self.ff_norm(x))
return x
class TransformerResamplingBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride, type="encoder",
transformer_depth=3, dim_heads=128, differential=True,
sliding_window=None, chunk_size=128, chunk_midpoint_shift=False,
dyt=True, ff_mult=3, mapping_bias=True, variable_stride=False,
sinusoidal_blocks=0, conv_mapping=False, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
if type not in ("encoder", "decoder"):
raise ValueError(f"type must be 'encoder' or 'decoder', got {type!r}")
self.type = type
self.stride = stride
self.chunk_size = chunk_size
self.chunk_midpoint_shift = chunk_midpoint_shift
self.variable_stride = variable_stride
self.transformer_depth = transformer_depth
transformer_dim = out_channels if type == "encoder" else in_channels
self.mapping = (WNConv1d(in_channels, out_channels, 3 if conv_mapping else 1, padding="same", bias=mapping_bias)
if in_channels != out_channels else nn.Identity())
self.sliding_window_latents = sliding_window
self.sliding_window_seq = self._get_sliding_window_size(sliding_window, stride)
self.input_seg_size, self.output_seg_size, self.sub_chunk_size = self._get_seg_sizes(stride)
token_seq = 1 if variable_stride else self.output_seg_size
self.new_tokens = nn.Parameter(torch.empty(1, token_seq, transformer_dim, dtype=dtype, device=device))
norm_type = "dyt" if dyt else "rms_norm"
attn_kwargs = {"qk_norm": "dyt" if dyt else "rms", "qk_norm_eps": 1e-3,
"differential": differential}
norm_kwargs = {"eps": 1e-3}
transformers = []
for i in range(transformer_depth):
sinusoidal = (transformer_depth - i) < sinusoidal_blocks
transformers.append(TransformerBlock(
transformer_dim,
dim_heads=dim_heads,
causal=False,
zero_init_branch_outputs=True,
norm_type=norm_type,
add_rope=True,
attn_kwargs=attn_kwargs,
ff_kwargs={"mult": ff_mult, "no_bias": False, "sinusoidal": sinusoidal},
norm_kwargs=norm_kwargs,
dtype=dtype, device=device, operations=operations,
))
self.transformers = nn.ModuleList(transformers)
def _get_sliding_window_size(self, window, stride, prepend_cond_length=0):
if window is None:
return None
return [w * (stride + 1 + prepend_cond_length) for w in window]
def _get_seg_sizes(self, stride, prepend_cond_length=0):
sub_chunk_size = stride + 1 + prepend_cond_length
input_seg_size = stride if self.type == "encoder" else 1
output_seg_size = 1 if self.type == "encoder" else stride
return input_seg_size, output_seg_size, sub_chunk_size
def forward(self, x, stride=None, **kwargs):
B = x.shape[0]
if stride is None:
input_seg = self.input_seg_size
output_seg = self.output_seg_size
sub_chunk = self.sub_chunk_size
sliding_window = self.sliding_window_seq
else:
input_seg, output_seg, sub_chunk = self._get_seg_sizes(stride)
sliding_window = self._get_sliding_window_size(self.sliding_window_latents, stride)
if self.type == "encoder":
if self.transformer_depth > 0:
pad_mod = self.chunk_size if sliding_window is None else input_seg
x = _zero_pad_modulo_sequence(x, pad_mod, dim=-1)
x = self.mapping(x)
if self.transformer_depth > 0:
x = x.permute(0, 2, 1)
if self.type != "encoder":
pad_mod = 1 if sliding_window is not None else (
self.chunk_size // (stride if stride is not None else self.stride))
x = _zero_pad_modulo_sequence(x, pad_mod)
C = x.shape[2]
x = x.reshape(-1, input_seg, C)
new_tokens = self.new_tokens.expand(x.shape[0], output_seg, -1)
x = torch.cat([x, comfy.ops.cast_to_input(new_tokens, x)], dim=-2)
del new_tokens
x = x.reshape(B, -1, C)
if sliding_window is None:
eff_chunk = self.chunk_size + self.chunk_size // (stride if stride is not None else self.stride)
if sliding_window is None and self.chunk_midpoint_shift:
split = self.transformer_depth // 2
shift = eff_chunk // 2
x = x.reshape(-1, eff_chunk, C)
for layer in self.transformers[:split]:
x = layer(x)
x = x.reshape(B, -1, C)
shifted = torch.cat([x[:, :shift, :], x, x[:, -shift:, :]], dim=1)
del x
x = shifted.reshape(-1, eff_chunk, C)
del shifted
for layer in self.transformers[split:]:
x = layer(x)
x = x.reshape(B, -1, C)
x = x[:, shift:-shift, :]
elif sliding_window is None:
x = x.reshape(-1, eff_chunk, C)
for layer in self.transformers:
x = layer(x)
x = x.reshape(B, -1, C)
else:
attn_mask = _sliding_window_mask(x.shape[1], sliding_window[0], x.device, x.dtype)
for layer in self.transformers:
x = layer(x, mask=attn_mask)
x = x.reshape(-1, sub_chunk, C)
x = x[:, -output_seg:, :]
x = x.reshape(B, -1, C).transpose(1, 2)
if self.type == "decoder":
x = self.mapping(x)
return x
class SAMEEncoder(nn.Module):
def __init__(self, in_channels=2, channels=128, latent_dim=32,
c_mults=(1, 2, 4, 8), strides=(2, 4, 8, 8),
transformer_depths=(3, 3, 3, 3),
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
channel_dims = [in_channels] + [channels * c for c in c_mults]
layers = []
for i in range(len(c_mults)):
layers.append(TransformerResamplingBlock(
in_channels=channel_dims[i], out_channels=channel_dims[i + 1],
stride=strides[i], type="encoder",
transformer_depth=transformer_depths[i],
dtype=dtype, device=device, operations=operations, **kwargs))
layers += [
Transpose(),
operations.Linear(channel_dims[-1], latent_dim, dtype=dtype, device=device),
Transpose(),
]
self.layers = nn.ModuleList(layers)
def forward(self, x, **kwargs):
for layer in self.layers:
x = layer(x)
return x
class SAMEDecoder(nn.Module):
def __init__(self, out_channels=2, channels=128, latent_dim=32,
c_mults=(1, 2, 4, 8), strides=(2, 4, 8, 8),
transformer_depths=(3, 3, 3, 3), sinusoidal_blocks=None,
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
if sinusoidal_blocks is None:
sinusoidal_blocks = [0] * len(c_mults)
channel_dims = [out_channels] + [channels * c for c in c_mults]
layers = [
Transpose(),
operations.Linear(latent_dim, channel_dims[-1], dtype=dtype, device=device),
Transpose(),
]
for i in range(len(c_mults) - 1, -1, -1):
layers.append(TransformerResamplingBlock(
in_channels=channel_dims[i + 1], out_channels=channel_dims[i],
stride=strides[i], type="decoder",
transformer_depth=transformer_depths[i],
sinusoidal_blocks=sinusoidal_blocks[i],
dtype=dtype, device=device, operations=operations, **kwargs))
self.layers = nn.ModuleList(layers)
def forward(self, x, **kwargs):
for layer in self.layers:
x = layer(x)
return x
class SoftNormBottleneck(nn.Module):
def __init__(self, dim=32, noise_augment_dim=0, noise_regularize=False,
auto_scale=False, freeze=False, dtype=None, device=None, **kwargs):
super().__init__()
self.noise_augment_dim = noise_augment_dim
self.noise_regularize = noise_regularize
self.scaling_factor = nn.Parameter(torch.empty(1, dim, 1, dtype=dtype, device=device))
self.bias = nn.Parameter(torch.empty(1, dim, 1, dtype=dtype, device=device))
self.noise_scaling_factor = nn.Parameter(torch.empty(1, noise_augment_dim, 1, dtype=dtype, device=device))
if auto_scale:
self.register_parameter("running_std", nn.Parameter(
torch.empty(1, dtype=dtype, device=device), requires_grad=False))
if freeze:
for p in self.parameters():
p.requires_grad = False
def encode(self, x, return_info=False, **kwargs):
x = x * comfy.ops.cast_to_input(self.scaling_factor, x) \
+ comfy.ops.cast_to_input(self.bias, x)
if hasattr(self, "running_std"):
x = x / comfy.ops.cast_to_input(self.running_std, x)
if return_info:
return x, {}
return x
def decode(self, x, **kwargs):
if hasattr(self, "running_std"):
x = x * comfy.ops.cast_to_input(self.running_std, x)
if self.noise_regularize:
scaling = self.running_std if hasattr(self, "running_std") \
else x.std(dim=-1, keepdim=True)
noise = torch.randn_like(x) * comfy.ops.cast_to_input(scaling, x) * 1e-3
x = x + noise
if self.noise_augment_dim > 0:
noise = comfy.ops.cast_to_input(self.noise_scaling_factor, x) * torch.randn(
x.shape[0], self.noise_augment_dim, x.shape[-1], device=x.device, dtype=x.dtype)
x = torch.cat([x, noise], dim=1)
return x
class PatchedPretransform(nn.Module):
def __init__(self, channels, patch_size, **kwargs):
super().__init__()
self.channels = channels
self.patch_size = patch_size
self.enable_grad = False
def _pad(self, x):
pad_len = (self.patch_size - x.shape[-1] % self.patch_size) % self.patch_size
if pad_len > 0:
x = torch.cat([x, torch.zeros_like(x[:, :, :pad_len])], dim=-1)
return x
def encode(self, x):
x = self._pad(x)
B, C, T = x.shape
h = self.patch_size
L = T // h
# b c (l h) -> b (c h) l
return x.reshape(B, C, L, h).permute(0, 1, 3, 2).reshape(B, C * h, L)
def decode(self, x):
B, Ch, L = x.shape
h = self.patch_size
C = Ch // h
# b (c h) l -> b c (l h)
return x.reshape(B, C, h, L).permute(0, 1, 3, 2).reshape(B, C, L * h)
class SA3AudioVAE(nn.Module):
"""SA3 VAE. State dict keys match checkpoint after stripping 'pretransform.model.'"""
def __init__(self, channels=256, transformer_depths=12, sinusoidal_blocks=8,
sliding_window=None, decoder_conv_mapping=False,
chunk_size=128, chunk_midpoint_shift=False,
dtype=None, device=None, operations=None):
super().__init__()
if operations is None:
operations = ops
self.pretransform = PatchedPretransform(channels=2, patch_size=256)
common_kwargs = dict(
differential=True, dyt=True, dim_heads=64,
sliding_window=sliding_window, variable_stride=True,
chunk_size=chunk_size, chunk_midpoint_shift=chunk_midpoint_shift,
dtype=dtype, device=device, operations=operations,
)
self.encoder = SAMEEncoder(
in_channels=512, channels=channels, c_mults=[6], strides=[16],
latent_dim=256, transformer_depths=[transformer_depths],
conv_mapping=False, **common_kwargs,
)
self.decoder = SAMEDecoder(
out_channels=512, channels=channels, c_mults=[6], strides=[16],
latent_dim=256, transformer_depths=[transformer_depths], sinusoidal_blocks=[sinusoidal_blocks],
conv_mapping=decoder_conv_mapping, **common_kwargs,
)
self.bottleneck = SoftNormBottleneck(
dim=256, noise_augment_dim=0, noise_regularize=True,
auto_scale=True, freeze=True,
dtype=dtype, device=device,
)
@torch.no_grad()
def _pretransform_encode(self, x):
return self.pretransform.encode(x)
@torch.no_grad()
def _pretransform_decode(self, x):
return self.pretransform.decode(x)
def encode(self, x):
x = self._pretransform_encode(x)
x = self.encoder(x)
x = self.bottleneck.encode(x)
return x
def decode(self, x):
x = self.bottleneck.decode(x)
x = self.decoder(x)
x = self._pretransform_decode(x)
return x

View File

@ -328,7 +328,7 @@ class CrossAttention(nn.Module):
kv = torch.cat((k, v), dim=-1)
split_size = kv.shape[-1] // self.num_heads // 2
kv = kv.view(1, -1, self.num_heads, split_size * 2)
kv = kv.view(b, -1, self.num_heads, split_size * 2)
k, v = torch.split(kv, split_size, dim=-1)
q = q.view(b, s1, self.num_heads, self.head_dim)
@ -398,7 +398,7 @@ class Attention(nn.Module):
qkv_combined = torch.cat((query, key, value), dim=-1)
split_size = qkv_combined.shape[-1] // self.num_heads // 3
qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3)
qkv = qkv_combined.view(B, -1, self.num_heads, split_size * 3)
query, key, value = torch.split(qkv, split_size, dim=-1)
query = query.reshape(B, N, self.num_heads, self.head_dim)
@ -607,9 +607,9 @@ class HunYuanDiTPlain(nn.Module):
def forward(self, x, t, context, transformer_options = {}, **kwargs):
x = x.movedim(-1, -2)
uncond_emb, cond_emb = context.chunk(2, dim = 0)
context = torch.cat([cond_emb, uncond_emb], dim = 0)
if context.shape[0] >= 2:
uncond_emb, cond_emb = context.chunk(2, dim = 0)
context = torch.cat([cond_emb, uncond_emb], dim = 0)
main_condition = context
t = 1.0 - t
@ -657,5 +657,8 @@ class HunYuanDiTPlain(nn.Module):
output = self.final_layer(combined)
output = output.movedim(-2, -1) * (-1.0)
cond_emb, uncond_emb = output.chunk(2, dim = 0)
return torch.cat([uncond_emb, cond_emb])
if output.shape[0] >= 2:
cond_emb, uncond_emb = output.chunk(2, dim = 0)
return torch.cat([uncond_emb, cond_emb])
else:
return output

View File

@ -813,6 +813,85 @@ class StableAudio1(BaseModel):
sd["{}{}".format(k, l)] = s[l]
return sd
class StableAudio3(BaseModel):
def __init__(self, model_config, seconds_total_embedder_weights, padding_embedding=None, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer)
self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=384, fourier_features_type=model_config.unet_config["timestep_features_type"])
self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights)
if padding_embedding is not None:
self.padding_embedding = torch.nn.Parameter(padding_embedding, requires_grad=False)
else:
self.padding_embedding = None
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
image = kwargs.get("concat_latent_image", None)
if image is None:
shape_image = list(noise.shape)
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
else:
image = self.process_latent_in(image)
# TODO: scale if not match
image = utils.resize_to_batch_size(image, noise.shape[0])
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.zeros_like(noise)[:, :1]
else:
if mask.shape[1] != 1:
mask = torch.mean(mask, dim=1, keepdim=True)
mask = 1.0 - mask
# TODO: scale if not match
mask = utils.resize_to_batch_size(mask, noise.shape[0])
return torch.cat((mask, image), dim=1)
def extra_conds(self, **kwargs):
out = {}
concat_cond = self.concat_cond(**kwargs)
if concat_cond is not None:
out['local_add_cond'] = comfy.conds.CONDNoiseShape(concat_cond)
noise = kwargs.get("noise", None)
device = kwargs["device"]
seconds_total = kwargs.get("seconds_total", int(noise.shape[-1] / 10.7666))
seconds_total_embed = self.seconds_total_embedder([seconds_total])[0].to(device)
global_embed = seconds_total_embed.reshape((1, -1))
out['global_embed'] = comfy.conds.CONDRegular(global_embed)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
cross_attn = cross_attn.to(device)
if self.padding_embedding is not None:
pe = self.padding_embedding.to(device=device, dtype=cross_attn.dtype)
max_text_tokens = self.model_config.unet_config.get("max_text_tokens", 256)
n_text = cross_attn.shape[1]
if n_text < max_text_tokens:
pad = pe.view(1, 1, -1).expand(cross_attn.shape[0], max_text_tokens - n_text, -1)
cross_attn = torch.cat([cross_attn, pad], dim=1)
cross_attn = torch.cat([cross_attn, seconds_total_embed.repeat((cross_attn.shape[0], 1, 1))], dim=1)
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
d = {"conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()}
for k in d:
s = d[k]
for l in s:
sd["{}{}".format(k, l)] = s[l]
if self.padding_embedding is not None:
sd["conditioner.conditioners.prompt.padding_embedding"] = self.padding_embedding.data
return sd
class HunyuanDiT(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):

View File

@ -116,6 +116,45 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
unet_config = {}
unet_config["audio_model"] = "dit1.0"
unet_config["global_cond_dim"] = state_dict['{}to_global_embed.0.weight'.format(key_prefix)].shape[1]
cond_embed = state_dict['{}to_cond_embed.0.weight'.format(key_prefix)]
unet_config["project_cond_tokens"] = cond_embed.shape[0] != cond_embed.shape[1]
unet_config["embed_dim"] = state_dict['{}to_timestep_embed.0.weight'.format(key_prefix)].shape[0]
mem_tokens = state_dict.get('{}transformer.memory_tokens'.format(key_prefix), None)
to_qkv = state_dict.get('{}transformer.layers.0.self_attn.to_qkv.weight'.format(key_prefix), None)
differential = False
if to_qkv is not None:
if to_qkv.shape[0] == to_qkv.shape[1] * 5:
differential = True
if mem_tokens is not None:
unet_config["num_memory_tokens"] = mem_tokens.shape[0]
if '{}transformer.layers.0.self_attn.q_norm.weight'.format(key_prefix) in state_dict:
unet_config["attn_kwargs"] = {"qk_norm": "ln", "feat_scale": True}
rms_norm = state_dict.get('{}transformer.layers.0.self_attn.q_norm.gamma'.format(key_prefix), None)
if rms_norm is not None:
unet_config["attn_kwargs"] = {"qk_norm": "rms", "differential": differential}
unet_config["norm_type"] = "rms_norm"
unet_config["num_heads"] = unet_config["embed_dim"] // rms_norm.shape[0]
if '{}timestep_features.weight'.format(key_prefix) in state_dict:
unet_config["timestep_features_type"] = "learned"
else:
unet_config["timestep_features_type"] = "expo"
io_channels = state_dict['{}postprocess_conv.weight'.format(key_prefix)].shape[0]
unet_config["io_channels"] = io_channels
unet_config["input_concat_dim"] = state_dict['{}transformer.project_in.weight'.format(key_prefix)].shape[1] - io_channels
local_add_cond = state_dict.get('{}transformer.layers.0.to_local_embed.0.weight'.format(key_prefix), None)
if local_add_cond is not None:
unet_config["local_add_cond_dim"] = local_add_cond.shape[1]
global_cond_embed = state_dict.get('{}transformer.global_cond_embedder.0.weight'.format(key_prefix), None)
if global_cond_embed is not None:
unet_config["global_cond_shared_embed"] = True
unet_config["global_cond_type"] = "adaLN"
unet_config["depth"] = count_blocks(state_dict_keys, '{}transformer.layers.'.format(key_prefix) + '{}.')
return unet_config
if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit

View File

@ -260,7 +260,7 @@ def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, w
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
# NOTE: offloadable=False is a legacy mode and if you are a custom node author reading this please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance.
if input is not None:

View File

@ -21,6 +21,7 @@ import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.cogvideo.vae
import comfy.ldm.hunyuan_video.vae
import comfy.ldm.mmaudio.vae.autoencoder
import comfy.ldm.audio.vae_sa3
import comfy.pixel_space_convert
import comfy.weight_adapter
import yaml
@ -67,6 +68,7 @@ import comfy.text_encoders.qwen35
import comfy.text_encoders.ernie
import comfy.text_encoders.gemma4
import comfy.text_encoders.cogvideo
import comfy.text_encoders.sa3
import comfy.model_patcher
import comfy.lora
@ -854,6 +856,34 @@ class VAE:
self.working_dtypes = [torch.float32]
self.disable_offload = True
self.extra_1d_channel = 16
elif "decoder.layers.3.transformers.0.pre_norm.alpha" in sd: # Stable Audio 3 VAE
if "decoder.layers.3.transformers.11.self_attn.to_out.weight" in sd:
config = {"channels": 256, "transformer_depths": 12, "sinusoidal_blocks": 8,
"sliding_window": [1, 1], "decoder_conv_mapping": False,
"chunk_size": 128, "chunk_midpoint_shift": False}
self.memory_used_encode = lambda shape, dtype: (1500 * shape[2]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * 4096) * model_management.dtype_size(dtype)
else:
config = {"channels": 128, "transformer_depths": 6, "sinusoidal_blocks": 0,
"sliding_window": None, "decoder_conv_mapping": True,
"chunk_size": 32, "chunk_midpoint_shift": True}
self.memory_used_encode = lambda shape, dtype: (72 * shape[2]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (72 * shape[2] * 4096) * model_management.dtype_size(dtype)
self.first_stage_model = comfy.ldm.audio.vae_sa3.SA3AudioVAE(**config)
self.latent_channels = 256
self.output_channels = 2
self.upscale_ratio = 4096
self.downscale_ratio = 4096
self.latent_dim = 1
self.audio_sample_rate = 44100
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
#This VAE has Parameters and Buffers the non-dynamic caster cannot handle
#Force cast it for --disable-dynamic-vram users until there is a true core fix.
if not comfy.memory_management.aimdo_enabled:
self.disable_offload = True
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@ -1290,6 +1320,7 @@ class TEModel(Enum):
GEMMA_4_E4B = 29
GEMMA_4_E2B = 30
GEMMA_4_31B = 31
T5_GEMMA = 32
def detect_te_model(sd):
@ -1314,6 +1345,8 @@ def detect_te_model(sd):
if weight.shape[0] == 384:
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if "model.encoder.layers.0.pre_self_attn_layernorm.weight" in sd:
return TEModel.T5_GEMMA
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
if 'model.layers.59.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_4_31B
@ -1463,6 +1496,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else:
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
elif te_model == TEModel.T5_GEMMA:
clip_target.clip = comfy.text_encoders.sa3.SAT5GemmaModel
clip_target.tokenizer = comfy.text_encoders.sa3.SAT5GemmaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B):
variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,

View File

@ -7,6 +7,7 @@ from . import sdxl_clip
import comfy.text_encoders.sd2_clip
import comfy.text_encoders.sd3_clip
import comfy.text_encoders.sa_t5
import comfy.text_encoders.sa3
import comfy.text_encoders.aura_t5
import comfy.text_encoders.pixart_t5
import comfy.text_encoders.hydit
@ -603,6 +604,29 @@ class StableAudio(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.sa_t5.SAT5Tokenizer, comfy.text_encoders.sa_t5.SAT5Model)
class StableAudio3(StableAudio):
unet_config = {
"audio_model": "dit1.0",
"global_cond_shared_embed": True,
}
sampling_settings = {
"multiplier": 1.0,
"shift": 2.0,
}
latent_format = latent_formats.StableAudio3
memory_usage_factor = 7
def get_model(self, state_dict, prefix="", device=None):
seconds_total_sd = utils.state_dict_prefix_replace(state_dict, {"conditioner.conditioners.seconds_total.": ""}, filter_keys=True)
padding_embedding = state_dict.get("conditioner.conditioners.prompt.padding_embedding", None)
return model_base.StableAudio3(self, seconds_total_embedder_weights=seconds_total_sd, padding_embedding=padding_embedding, device=device)
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.sa3.SAT5GemmaTokenizer, comfy.text_encoders.sa3.SAT5GemmaModel)
class AuraFlow(supported_models_base.BASE):
unet_config = {
"cond_seq_dim": 2048,
@ -2018,6 +2042,7 @@ models = [
SV3D_u,
SV3D_p,
SD3,
StableAudio3,
StableAudio,
AuraFlow,
PixArtAlpha,

207
comfy/text_encoders/sa3.py Normal file
View File

@ -0,0 +1,207 @@
import torch
import torch.nn as nn
from comfy import sd1_clip
from comfy.text_encoders.llama import Attention as LlamaAttention, RMSNorm, MLP, precompute_freqs_cis, apply_rope, _make_scaled_embedding
from comfy.text_encoders.spiece_tokenizer import SPieceTokenizer
class T5GemmaEncoderConfig:
def __init__(self):
self.vocab_size = 256000
self.hidden_size = 768
self.intermediate_size = 2048
self.num_hidden_layers = 12
self.num_attention_heads = 12
self.num_key_value_heads = 12
self.head_dim = 64
self.rms_norm_eps = 1e-6
self.rms_norm_add = False
self.rope_theta = 10000.0
self.attn_logit_softcapping = 50.0
self.query_pre_attn_scalar = 64
self.sliding_window = 4096
self.mlp_activation = "gelu_pytorch_tanh"
self.layer_types = ["sliding_attention", "full_attention"] * 6
self.qkv_bias = False
self.q_norm = None
self.k_norm = None
self.rms_norm_add = True
class T5GemmaAttention(LlamaAttention):
"""Reuses LlamaAttention projection setup; overrides forward for softcap attention.
T5Gemma applies tanh(QK^T * scale / cap) * cap between the matmul and softmax.
This nonlinearity is incompatible with fused SDPA kernels, so attention is
computed manually. Everything else (projections, RoPE, GQA expansion) is identical
to LlamaAttention so __init__ is inherited unchanged.
"""
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__(config, device=device, dtype=dtype, ops=ops)
self.scale = config.query_pre_attn_scalar ** -0.5
self.softcap = config.attn_logit_softcapping
def forward(self, hidden_states, attention_mask=None, freqs_cis=None, **kwargs):
B, S, _ = hidden_states.shape
xq = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
xk = self.k_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
xv = self.v_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
xq, xk = apply_rope(xq, xk, freqs_cis)
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
attn = torch.matmul(xq * self.scale, xk.transpose(-2, -1))
attn = torch.tanh(attn / self.softcap) * self.softcap
if attention_mask is not None:
attn = attn + attention_mask
attn = torch.nn.functional.softmax(attn.float(), dim=-1).to(xq.dtype)
out = torch.matmul(attn, xv).transpose(1, 2).reshape(B, S, self.inner_size)
return self.o_proj(out), None
class T5GemmaBlock(nn.Module):
def __init__(self, config, layer_type, device=None, dtype=None, ops=None):
super().__init__()
self.self_attn = T5GemmaAttention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
# Names match checkpoint keys: model.encoder.layers.X.<name>.weight
self.pre_self_attn_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=True, device=device, dtype=dtype)
self.post_self_attn_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=True, device=device, dtype=dtype)
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=True, device=device, dtype=dtype)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=True, device=device, dtype=dtype)
self.is_sliding = (layer_type == "sliding_attention")
self.sliding_window = config.sliding_window
def forward(self, x, attention_mask=None, freqs_cis=None):
attn_mask = attention_mask
if self.is_sliding and x.shape[1] > self.sliding_window:
S = x.shape[1]
pos = torch.arange(S, device=x.device)
dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs()
sw_mask = torch.zeros(S, S, dtype=x.dtype, device=x.device)
sw_mask.masked_fill_(dist > self.sliding_window, -torch.finfo(x.dtype).max)
sw_mask = sw_mask.unsqueeze(0).unsqueeze(0)
attn_mask = (attention_mask + sw_mask) if attention_mask is not None else sw_mask
residual = x
x = self.pre_self_attn_layernorm(x)
x, _ = self.self_attn(x, attention_mask=attn_mask, freqs_cis=freqs_cis)
x = self.post_self_attn_layernorm(x)
x = residual + x
residual = x
x = self.pre_feedforward_layernorm(x)
x = self.mlp(x)
x = self.post_feedforward_layernorm(x)
x = residual + x
return x
class T5GemmaEncoder(nn.Module):
"""Encoder stack: embed_tokens, layers, norm.
Keys: embed_tokens.*, layers.X.*, norm.*"""
def __init__(self, config, device, dtype, ops):
super().__init__()
self.config = config
# Gemma-style scaled embedding: output *= sqrt(hidden_size)
self.embed_tokens = _make_scaled_embedding(
ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
self.layers = nn.ModuleList([
T5GemmaBlock(config, config.layer_types[i], device=device, dtype=dtype, ops=ops)
for i in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=True, device=device, dtype=dtype)
def forward(self, input_ids, attention_mask=None, embeds=None, intermediate_output=None,
final_layer_norm_intermediate=True, dtype=None, num_layers=None):
x = embeds if embeds is not None else self.embed_tokens(input_ids, out_dtype=dtype or torch.float32)
seq_len = x.shape[1]
position_ids = torch.arange(seq_len, device=x.device).unsqueeze(0)
freqs_cis = precompute_freqs_cis(self.config.head_dim, position_ids, self.config.rope_theta, device=x.device)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape(
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
intermediate = None
for i, layer in enumerate(self.layers):
x = layer(x, attention_mask=mask, freqs_cis=freqs_cis)
if i == intermediate_output:
intermediate = x.clone()
x = self.norm(x)
if intermediate is not None and final_layer_norm_intermediate:
intermediate = self.norm(intermediate)
return x, intermediate
class T5GemmaBody(nn.Module):
"""Provides the 'encoder' sub-module.
Keys: encoder.*"""
def __init__(self, config, device, dtype, ops):
super().__init__()
self.encoder = T5GemmaEncoder(config, device, dtype, ops)
class T5GemmaModel(nn.Module):
"""Top-level model class passed to SDClipModel as model_class.
Module layout: self.model.encoder.* → matches checkpoint keys model.encoder.*"""
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = T5GemmaEncoderConfig()
self.num_layers = config.num_hidden_layers
self.dtype = dtype
self.model = T5GemmaBody(config, device, dtype, operations)
def get_input_embeddings(self):
return self.model.encoder.embed_tokens
def set_input_embeddings(self, embeddings):
self.model.encoder.embed_tokens = embeddings
def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None,
intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, **kwargs):
if intermediate_output is not None and intermediate_output < 0:
intermediate_output = self.num_layers + intermediate_output
return self.model.encoder(
input_ids, attention_mask=attention_mask, embeds=embeds,
intermediate_output=intermediate_output,
final_layer_norm_intermediate=final_layer_norm_intermediate,
dtype=dtype, num_layers=self.num_layers)
class T5GemmaSDClipModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx,
textmodel_json_config={}, dtype=dtype,
special_tokens={"pad": 0},
model_class=T5GemmaModel,
enable_attention_masks=True, zero_out_masked=True,
model_options=model_options)
class T5GemmaSDTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_model = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer_model, pad_with_end=False, embedding_size=768,
embedding_key="t5gemma", tokenizer_class=SPieceTokenizer,
has_start_token=False, has_end_token=False, pad_to_max_length=False,
max_length=99999999, min_length=1, pad_token=0,
tokenizer_data=tokenizer_data,
tokenizer_args={"add_bos": False, "add_eos": False})
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class SAT5GemmaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory,
tokenizer_data=tokenizer_data, clip_name="t5gemma", tokenizer=T5GemmaSDTokenizer)
class SAT5GemmaModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options,
name="t5gemma", clip_model=T5GemmaSDClipModel, **kwargs)

View File

@ -77,7 +77,7 @@ class EmptyLTXVLatentVideo(io.ComfyNode):
@classmethod
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent})
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 32})
generate = execute # TODO: remove

View File

@ -1,10 +1,41 @@
import re
import json
import string
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class StringFormat(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
autogrow = io.Autogrow.TemplateNames(
input=io.AnyType.Input("value"),
names=list(string.ascii_lowercase),
min=0,
)
return io.Schema(
node_id="StringFormat",
display_name="Format Text",
category="text",
search_aliases=["string", "format"],
description="Same as Python's string format method. Supports all of Python's format options and features.",
inputs=[
io.Autogrow.Input("values", template=autogrow),
io.String.Input("f_string", default="{a}", multiline=True),
],
outputs=[
io.String.Output(),
],
)
@classmethod
def execute(
cls, values: io.Autogrow.Type, f_string: str
) -> io.NodeOutput:
return io.NodeOutput(f_string.format(**values))
class StringConcatenate(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -413,6 +444,7 @@ class StringExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
StringFormat,
StringConcatenate,
StringSubstring,
StringLength,

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.21.1"
__version__ = "0.22.0"

View File

@ -4160,6 +4160,10 @@ paths:
name:
type: string
description: Display name for the API key
description:
type: string
description: User-provided description of the key's purpose
maxLength: 5000
responses:
"201":
description: API key created
@ -6351,14 +6355,6 @@ components:
type: integer
format: int64
description: Size of the asset in bytes
width:
type: integer
nullable: true
description: "Original image width in pixels. Null for non-image assets or assets ingested before dimension extraction."
height:
type: integer
nullable: true
description: "Original image height in pixels. Null for non-image assets or assets ingested before dimension extraction."
mime_type:
type: string
description: MIME type of the asset
@ -7685,11 +7681,16 @@ components:
required:
- id
- name
- description
properties:
id:
type: string
name:
type: string
description:
type: string
maxLength: 5000
description: User-provided description of the key's purpose. Always present in responses; empty string when no description was supplied on create.
prefix:
type: string
description: First few characters of the key for identification
@ -7710,12 +7711,17 @@ components:
required:
- id
- name
- description
- key
properties:
id:
type: string
name:
type: string
description:
type: string
maxLength: 5000
description: User-provided description of the key's purpose. Always present in responses; empty string when no description was supplied on create.
key:
type: string
description: Full API key value (only returned on creation)

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.21.1"
version = "0.22.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.43.18
comfyui-workflow-templates==0.9.77
comfyui-workflow-templates==0.9.79
comfyui-embedded-docs==0.5.0
torch
torchsde