mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-23 05:25:50 +08:00
Compare commits
779 Commits
node-memor
...
autogrow-v
| Author | SHA1 | Date | |
|---|---|---|---|
| 9c74c82ddf | |||
| 3adb63a098 | |||
| 12256ff5af | |||
| a498556d0d | |||
| f7ca41ff62 | |||
| ac26065e61 | |||
| 190c4416cc | |||
| 0fd10ffa09 | |||
| 00c775950a | |||
| 7ac999bf30 | |||
| 0c6b36c6ac | |||
| 9125613b53 | |||
| 732b707397 | |||
| 4c816d5c69 | |||
| 6125b3a5e7 | |||
| 12918a5f78 | |||
| 8f40b43e02 | |||
| 3b832231bb | |||
| be518db5a7 | |||
| 80441eb15e | |||
| 07f2462eae | |||
| d150440466 | |||
| 6165c38cb5 | |||
| 712cca36a1 | |||
| ac4d8ea9b3 | |||
| c9196f355e | |||
| 7eb959ce93 | |||
| 469dd9c16a | |||
| eff2b9d412 | |||
| 15b312de7a | |||
| 1419047fdb | |||
| 79f6bb5e4f | |||
| e4b4fb3479 | |||
| d9dc02a7d6 | |||
| c543ad81c3 | |||
| 5ac1372533 | |||
| 1dcbd9efaf | |||
| db9e6edfa1 | |||
| 8af13b439b | |||
| acd0e53653 | |||
| 117e7a5853 | |||
| b3c0e4de57 | |||
| ecaeeb990d | |||
| c2b65e2fce | |||
| fd5c0755af | |||
| c881a1d689 | |||
| a3b5d4996a | |||
| c6238047ee | |||
| 5cd1113236 | |||
| 2f642d5d9b | |||
| cd912963f1 | |||
| 6e4b1f9d00 | |||
| dc202a2e51 | |||
| 153bc524bf | |||
| 393d2880dd | |||
| 4484b93d61 | |||
| bd0e6825e8 | |||
| ec0a832acb | |||
| 04c49a29b4 | |||
| 4609fcd260 | |||
| 6207f86c18 | |||
| 1dc3da6314 | |||
| 114fc73685 | |||
| b48d6a83d4 | |||
| 027042db68 | |||
| 1a20656448 | |||
| 0f11869d55 | |||
| 5943fbf457 | |||
| a60b7b86c5 | |||
| 2e9d51680a | |||
| 50d6e1caf4 | |||
| ac12f77bed | |||
| fcd9a236b0 | |||
| 21e8425087 | |||
| b6c79a648a | |||
| 25bc1b5b57 | |||
| 3cd19e99c1 | |||
| 007b87e7ac | |||
| 34751fe9f9 | |||
| 1c705f7bfb | |||
| 48e5ea1dfd | |||
| 3cd7b32f1b | |||
| c0c9720d77 | |||
| fc0cb10bcb | |||
| b7d7cc1d49 | |||
| 79e94544bd | |||
| ce0000c4f2 | |||
| c5cfb34c07 | |||
| edee33f55e | |||
| 2c03884f5f | |||
| 6e9ee55cdd | |||
| 023cf13721 | |||
| c3566c0d76 | |||
| c3c3e93c5b | |||
| 6ffc159bdd | |||
| 96e0d0924e | |||
| e14f3b6610 | |||
| 1618002411 | |||
| 6ef85c4915 | |||
| 6da00dd899 | |||
| 4f3f9e72a9 | |||
| d157c3299d | |||
| d1b9822f74 | |||
| f2b002372b | |||
| 38d0493825 | |||
| acbf08cd60 | |||
| 53e762a3af | |||
| 9a552df898 | |||
| f2fda021ab | |||
| 303b1735f8 | |||
| 9e5f677746 | |||
| 65cfcf5b1b | |||
| 1bdc9a947f | |||
| d622a61874 | |||
| 236b9e211d | |||
| 6ca3d5c011 | |||
| 0be8a76c93 | |||
| 0357ed7ec4 | |||
| f59f71cf34 | |||
| 178bdc5e14 | |||
| 25a1bfab4e | |||
| d7111e426a | |||
| 0e6221cc79 | |||
| 9ca7e143af | |||
| 8fd07170f1 | |||
| 2943093a53 | |||
| 36deef2c57 | |||
| 0d2e4bdd44 | |||
| eff4ea0b62 | |||
| 865568b7fc | |||
| 1e4e342f54 | |||
| 16fb6849d2 | |||
| d9a76cf66e | |||
| 532e285079 | |||
| 4f067b07fb | |||
| 650e716dda | |||
| e4c61d7555 | |||
| 22ff1bbfcb | |||
| f4f44bb807 | |||
| 33aa808713 | |||
| eb0e10aec4 | |||
| c176b214cc | |||
| 91bf6b6aa3 | |||
| 807538fe6c | |||
| bbb11e2608 | |||
| 0899012ad6 | |||
| fb478f679a | |||
| 4c432c11ed | |||
| 31e961736a | |||
| 767ee30f21 | |||
| 3ab9748903 | |||
| 0aa7fa464e | |||
| 514c24d756 | |||
| 809ce68749 | |||
| cc4ddba1b6 | |||
| 8376ff6831 | |||
| 5b4d0664c8 | |||
| 894802b0f9 | |||
| 28eaab608b | |||
| 6a2678ac65 | |||
| e4fb3a3572 | |||
| e8ebbe668e | |||
| 1ca89b810e | |||
| bf7dc63bd6 | |||
| 86dbb89fc9 | |||
| ba6080bbab | |||
| 16d85ea133 | |||
| 5d9ad0c6bf | |||
| c08f97f344 | |||
| 887143854b | |||
| 3a5f239cb6 | |||
| 827bb1512b | |||
| ffdd53b327 | |||
| 65e2103b09 | |||
| 9304e47351 | |||
| bc606d7d64 | |||
| 645ee1881e | |||
| 3d082c3206 | |||
| 683569de55 | |||
| ea2c117bc3 | |||
| fc4af86068 | |||
| 41bcf0619d | |||
| d02d0e5744 | |||
| 70541d4e77 | |||
| 77b2f7c228 | |||
| 43e0d4e3cc | |||
| dbd330454a | |||
| 33c7f1179d | |||
| af91eb6c99 | |||
| 5cb1e0c9a0 | |||
| 51347f9fb8 | |||
| a5e85017d8 | |||
| 5ac3b26a7d | |||
| 6592bffc60 | |||
| 971cefe7d4 | |||
| da2bfb5b0a | |||
| c5a47a1692 | |||
| 908fd7d749 | |||
| 5495589db3 | |||
| 982876d59a | |||
| 338d9ae3bb | |||
| eeb020b9b7 | |||
| ae65433a60 | |||
| fdebe18296 | |||
| f8321eb57b | |||
| 93948e3fc5 | |||
| e711aaf1a7 | |||
| 57ddb7fd13 | |||
| 17c92a9f28 | |||
| 36357bbcc3 | |||
| f668c2e3c9 | |||
| fc657f471a | |||
| 791e30ff50 | |||
| e2a800e7ef | |||
| 9d252f3b70 | |||
| b9fb542703 | |||
| cabc4d351f | |||
| e136b6dbb0 | |||
| d50f342c90 | |||
| 3b0368aa34 | |||
| 935493f6c1 | |||
| 60ee574748 | |||
| 8e889c535d | |||
| fd271dedfd | |||
| c3c6313fc7 | |||
| 85c4b4ae26 | |||
| 058f084371 | |||
| ec7f65187d | |||
| 56fa7dbe38 | |||
| 329480da5a | |||
| 4086acf3c2 | |||
| 50ca97e776 | |||
| 7ac7d69d94 | |||
| 76f18e955d | |||
| d7a0aef650 | |||
| 913f86b727 | |||
| 117bf3f2bd | |||
| ae676ed105 | |||
| fd109325db | |||
| bed12674a1 | |||
| 092ee8a500 | |||
| 79d17ba233 | |||
| 6fd463aec9 | |||
| 43071e3de3 | |||
| 0ec05b1481 | |||
| 35fa091340 | |||
| 3c8456223c | |||
| 9bc893c5bb | |||
| f4bdf5f830 | |||
| 6be85c7920 | |||
| ea17add3c6 | |||
| ecdc8697d5 | |||
| dce518c2b4 | |||
| 440268d394 | |||
| 87c104bfc1 | |||
| 19f2192d69 | |||
| 519c941165 | |||
| 861817d22d | |||
| c120eee5ba | |||
| 73f5649196 | |||
| 3f512f5659 | |||
| b94d394a64 | |||
| 277237ccc1 | |||
| daaceac769 | |||
| 33d6aec3b7 | |||
| 44baa0b7f3 | |||
| a17cf1c387 | |||
| b4a20acc54 | |||
| c55dc857d5 | |||
| 878db3a727 | |||
| 30c259cac8 | |||
| 1cb7e22a95 | |||
| 2640acb31c | |||
| 7dbd5dfe91 | |||
| f8b981ae9a | |||
| 4967f81778 | |||
| 0a6746898d | |||
| 5151cff293 | |||
| af96d9812d | |||
| 52a32e2b32 | |||
| b907085709 | |||
| 065a2fbbec | |||
| 0ff0457892 | |||
| 6484ac89dc | |||
| f55c98a89f | |||
| ca7808f240 | |||
| 52e778fff3 | |||
| 9d8a817985 | |||
| b59750a86a | |||
| 3f382a4f98 | |||
| f17251bec6 | |||
| c38e7d6599 | |||
| eaf68c9b5b | |||
| cc6a8dcd1a | |||
| a2d60aad0f | |||
| d8433c63fd | |||
| dd41b74549 | |||
| 55f654db3d | |||
| 58c6ed541d | |||
| 234c3dc85f | |||
| 8908ee2628 | |||
| 1105e0d139 | |||
| 8938aa3f30 | |||
| f16219e3aa | |||
| 8402c8700a | |||
| 58b8574661 | |||
| 90b3995ec8 | |||
| bdb10a583f | |||
| 0e24dbb19f | |||
| e9aae31fa2 | |||
| 0c18842acb | |||
| d196a905bb | |||
| 18b79acba9 | |||
| dff996ca39 | |||
| 828b1b9953 | |||
| af81cb962d | |||
| 5c7b08ca58 | |||
| 6b573ae0cb | |||
| 015a0599d0 | |||
| acfaa5c4a1 | |||
| b6805429b9 | |||
| 25022e0b09 | |||
| 22a2644e57 | |||
| b2ef58e2b1 | |||
| 6a6d456c88 | |||
| 3d1fdaf9f4 | |||
| 1286fcfe40 | |||
| 3bd71554a2 | |||
| f66183a541 | |||
| cbd68e3d58 | |||
| d89c29f259 | |||
| a9c35256bc | |||
| 532938b16b | |||
| ecb683b057 | |||
| c55fd74816 | |||
| 3398123752 | |||
| 943b3b615d | |||
| 10e90a5757 | |||
| b75d349f25 | |||
| 7b8389578e | |||
| 9e00ce5b76 | |||
| f5e66d5e47 | |||
| 87b0359392 | |||
| cb96d4d18c | |||
| 394348f5ca | |||
| 7601e89255 | |||
| 6a1d3a1ae1 | |||
| 65ee24c978 | |||
| 17027f2a6a | |||
| b5c8be8b1d | |||
| 24fdb92edf | |||
| d526974576 | |||
| e1ab6bb394 | |||
| 048f49adbd | |||
| 47bfd5a33f | |||
| fdf49a2861 | |||
| f41e5f398d | |||
| 27cbac865e | |||
| 3d0003c24c | |||
| 7d6103325e | |||
| 2d4a08b717 | |||
| 9a02382568 | |||
| bd01d9f7fd | |||
| 443056c401 | |||
| f60923590c | |||
| 1ef328c007 | |||
| 94c298f962 | |||
| 2fde9597f4 | |||
| f91078b1ff | |||
| 3b3ef9a77a | |||
| 8b0b93df51 | |||
| 1c7eaeca10 | |||
| 18e7d6dba5 | |||
| e1d85e7577 | |||
| 1199411747 | |||
| 5ebcab3c7d | |||
| c350009236 | |||
| dea899f221 | |||
| e632e5de28 | |||
| 2abd2b5c20 | |||
| a1a70362ca | |||
| cf97b033ee | |||
| eb1c42f649 | |||
| e05c907126 | |||
| 09dc24c8a9 | |||
| 1d69245981 | |||
| 97f198e421 | |||
| bda0eb2448 | |||
| c4a6b389de | |||
| 4cd881866b | |||
| 265adad858 | |||
| 7f3e4d486c | |||
| a389ee01bb | |||
| 9c71a66790 | |||
| af4b7b5edb | |||
| 0f4ef3afa0 | |||
| 6b88478f9f | |||
| e199c8cc67 | |||
| 0652cb8e2d | |||
| 958a17199a | |||
| e974e554ca | |||
| 4e2110c794 | |||
| e617cddf24 | |||
| 1f3f7a2823 | |||
| 88df172790 | |||
| 6d6a18b0b7 | |||
| 97ff9fae7e | |||
| 135fa49ec2 | |||
| 44869ff786 | |||
| 20182a393f | |||
| 5f109fe6a0 | |||
| c58c13b2ba | |||
| 7f374e42c8 | |||
| 27d1bd8829 | |||
| 614cf9805e | |||
| 513b0c46fb | |||
| dfac94695b | |||
| 163b629c70 | |||
| 998bf60beb | |||
| 906c089957 | |||
| 25de7b1bfa | |||
| ab7ab5be23 | |||
| ec4fc2a09a | |||
| 1a58087ac2 | |||
| 6c14f3afac | |||
| e525673f72 | |||
| 3fa7a5c04a | |||
| 210f7a1ba5 | |||
| d202c2ba74 | |||
| 8817f8fc14 | |||
| 22e40d2ace | |||
| 3bea4efc6b | |||
| 8cf2ba4ba6 | |||
| b61a40cbc9 | |||
| f2bb3230b7 | |||
| 614b8d3345 | |||
| 6abc30aae9 | |||
| 55bad30375 | |||
| c305deed56 | |||
| 601ee1775a | |||
| c170fd2db5 | |||
| 9d529e5308 | |||
| f6bbc1ac84 | |||
| 098a352f13 | |||
| e86b79ab9e | |||
| 426cde37f1 | |||
| dd5af0c587 | |||
| 388b306a2b | |||
| 24188b3141 | |||
| 1bcda6df98 | |||
| a1864c01f2 | |||
| 4739d7717f | |||
| f13cff0be6 | |||
| 9cdc64998f | |||
| 560b1bdfca | |||
| b7992f871a | |||
| 2c2aa409b0 | |||
| a4787ac83b | |||
| b5c59b763c | |||
| b4f30bd408 | |||
| dad076aee6 | |||
| 0cf33953a7 | |||
| 5b80addafd | |||
| 9da397ea2f | |||
| 92d97380bd | |||
| 99ce2a1f66 | |||
| b1467da480 | |||
| d8d60b5609 | |||
| b1293d50ef | |||
| 19b466160c | |||
| bc0ad9bb49 | |||
| 4054b4bf38 | |||
| 55ac7d333c | |||
| afa8a24fe1 | |||
| 493b81e48f | |||
| 6b035bfce2 | |||
| 74b7f0b04b | |||
| f72c6616b2 | |||
| 1c10b33f9b | |||
| ddfce1af4f | |||
| 7a883849ea | |||
| 84867067ea | |||
| 3374e900d0 | |||
| 51696e3fdc | |||
| dfff7e5332 | |||
| e4ea393666 | |||
| c8674bc6e9 | |||
| 3dfdcf66b6 | |||
| 95ca2e56c8 | |||
| 27ffd12c45 | |||
| e693e4db6a | |||
| d68ece7301 | |||
| 894837de9a | |||
| fdc92863b6 | |||
| a125cd84b0 | |||
| 84e9ce32c6 | |||
| f43b8ab2a2 | |||
| 14d642acd6 | |||
| aa895db7e8 | |||
| cdfc25a160 | |||
| 81e4dac107 | |||
| 90853fb9cd | |||
| f1dd6e50f8 | |||
| fc0fbf141c | |||
| f3d5d328a3 | |||
| 139addd53c | |||
| cbee7d3390 | |||
| 6732014a0a | |||
| 989f715d92 | |||
| 2ba8d7cce8 | |||
| 51fb505ffa | |||
| 72c2071972 | |||
| 6e59934089 | |||
| 3e0eb8d33f | |||
| 637221995f | |||
| 51697d50dc | |||
| 19f595b788 | |||
| 8a15568f10 | |||
| 9e984c48bc | |||
| fc34c3d112 | |||
| 8aea746212 | |||
| 8c19910427 | |||
| e77e0a8f8f | |||
| a49007a7b0 | |||
| 6ae3515801 | |||
| 6bd3f8eb9f | |||
| 7326e46dee | |||
| 195e0b0639 | |||
| 187f43696d | |||
| caf07331ff | |||
| b1fa1922df | |||
| 2ed74f7ac7 | |||
| 22f99fb97e | |||
| bbd683098e | |||
| 08726b64fe | |||
| 93d859cfaa | |||
| 4614ee09ca | |||
| 5c8e986e27 | |||
| 8c26d7bbe6 | |||
| d7aa414141 | |||
| 3e68bc342c | |||
| c2c5a7d5f8 | |||
| 8a293372ec | |||
| ed3ca78e08 | |||
| 4ffea0e864 | |||
| 1395bce9f7 | |||
| e9364ee279 | |||
| f6e3e9a456 | |||
| 8f4ee9984c | |||
| 0e9d1724be | |||
| 4965c0e2ac | |||
| 911331c06c | |||
| bb32d4ec31 | |||
| a6f83a4a1a | |||
| e4f99b479a | |||
| d9c0a4053d | |||
| 11bab7be76 | |||
| 3af1881455 | |||
| e0210ce0a7 | |||
| 7eb7160db4 | |||
| 638097829d | |||
| c4a8cf60ab | |||
| bab8ba20bf | |||
| b682a73c55 | |||
| 631b9ae861 | |||
| f48d7230de | |||
| 6e079abc3a | |||
| 977a4ed8c5 | |||
| 414a178fb6 | |||
| 447884b657 | |||
| bed4b49d08 | |||
| 342cf644ce | |||
| 3758848423 | |||
| 0db6aabed3 | |||
| 1673ace19b | |||
| 7f38e4c538 | |||
| 8accf50908 | |||
| ed0f4a609b | |||
| 041b8824f5 | |||
| b1111c2062 | |||
| 05a258efd8 | |||
| c8276f8c6b | |||
| 6ec1cfe101 | |||
| b60dc31627 | |||
| 555f902fc1 | |||
| 1364548c72 | |||
| 2dadb34860 | |||
| 1cf86f5ae5 | |||
| a1127b232d | |||
| 896f2e653c | |||
| 40ae495ddc | |||
| 653ceab414 | |||
| 160698eb41 | |||
| 7eca95657c | |||
| ad5aef2d0c | |||
| bcfd80dd79 | |||
| 6b4b671ce7 | |||
| a9cf1cd249 | |||
| 255572188f | |||
| 0572029fee | |||
| 196954ab8c | |||
| 1e098d6132 | |||
| cd66d72b46 | |||
| 2103e39335 | |||
| d20576e6a3 | |||
| a061b06321 | |||
| 80718908a9 | |||
| 7ea173c187 | |||
| 76eb1d72c3 | |||
| c4a46e943c | |||
| 2b7f9a8196 | |||
| ce4cb2389c | |||
| c8d2117f02 | |||
| fccab99ec0 | |||
| fd79d32f38 | |||
| 341b4adefd | |||
| b8730510db | |||
| e808790799 | |||
| 145b0e4f79 | |||
| 707b2638ec | |||
| 8a5ac527e6 | |||
| e3206351b0 | |||
| 1fee8827cb | |||
| 27bc181c49 | |||
| d1d9eb94b1 | |||
| 7be2b49b6b | |||
| 9ed3c5cc09 | |||
| 66241cef31 | |||
| e8df53b764 | |||
| 852704c81a | |||
| 9fdf8c25ab | |||
| dc95b6acc0 | |||
| 711bcf33ee | |||
| 24b0fce099 | |||
| 1ea8c54064 | |||
| 8d6653fca6 | |||
| dd611a7700 | |||
| 9288c78fc5 | |||
| e42682b24e | |||
| a39ac59c3e | |||
| 1a85483da1 | |||
| 47a9cde5d3 | |||
| 4f1f26ac6c | |||
| f228367c5e | |||
| 80b7c9455b | |||
| c1297f4eb3 | |||
| e5e70636e7 | |||
| 29bf807b0e | |||
| 2559dee492 | |||
| a3b04de700 | |||
| d7f40442f9 | |||
| b149e2e1e3 | |||
| 581bae2af3 | |||
| af99928f22 | |||
| 53c9c7d39a | |||
| ba68e83f1c | |||
| dcb8834983 | |||
| f9d2e4b742 | |||
| 45bc1f5c00 | |||
| 0aa074a420 | |||
| 7757d5a657 | |||
| e600520f8a | |||
| fd2b820ec2 | |||
| d6b977b2e6 | |||
| 15ec9ea958 | |||
| 33bd9ed9cb | |||
| 18de0b2830 | |||
| df6850fae8 | |||
| e01e99d075 | |||
| 72212fef66 | |||
| df34f1549a | |||
| 9b0553809c | |||
| 8d7c930246 | |||
| de44b95db6 | |||
| 543888d3d8 | |||
| 70fc0425b3 | |||
| 85e34643f8 | |||
| 5c33872e2f | |||
| 206595f854 | |||
| b288fb0db8 | |||
| f73b176abd | |||
| 103a12cb66 | |||
| 97652d26b8 | |||
| bd1d9bcd5f | |||
| fb763d4333 | |||
| bcbd7884e3 | |||
| 27a0fcccc3 | |||
| ea6cdd2631 | |||
| 2ee7879a0b | |||
| 3493b9cb1f | |||
| c9ebe70072 | |||
| 261421e218 | |||
| a9f1bb10a5 | |||
| b0338e930b | |||
| b71f9bcb71 | |||
| 72855db715 | |||
| f48d05a2d1 | |||
| 4368d8f87f | |||
| 22da0a83e9 | |||
| 50333f1715 | |||
| 26d5b86da8 | |||
| 4f5812b937 | |||
| 1bcb469089 | |||
| 464ba1d614 | |||
| e3018c2a5a | |||
| 3412d53b1d | |||
| e2d1e5dad9 | |||
| 27e067ce50 | |||
| 9b15155972 | |||
| 32a627bf1f | |||
| fe442fac2e | |||
| d2c502e629 | |||
| fea9ea8268 | |||
| f949094b3c | |||
| 4449e14769 | |||
| 885015eecf | |||
| a86aaa4301 | |||
| 2efb2cbc38 | |||
| 15aa9222c4 | |||
| c7bb3e2bce | |||
| e80a14ad50 | |||
| d28b39d93d | |||
| 1c184c29eb | |||
| edde0b5043 | |||
| 0063610177 | |||
| ce0052c087 | |||
| 0eb821a7b6 | |||
| 4aa79dbf2c | |||
| 38f697d953 | |||
| 3aad339b63 | |||
| 491755325c | |||
| 496888fd68 | |||
| b5ac6ed7ce | |||
| b20ba1f27c | |||
| 31a37686d0 | |||
| 88aee596a3 | |||
| 6a193ac557 | |||
| 47f4db3e84 | |||
| 5352abc6d3 | |||
| 39aa06bd5d | |||
| 914c2a2973 | |||
| e633a47ad1 | |||
| f6b93d41a0 | |||
| 95ac7794b7 | |||
| 71ed4a399e | |||
| 3e316c6338 | |||
| 8be0d22ab7 | |||
| 59eddda900 | |||
| 41048c69b4 | |||
| fc247150fe | |||
| fe31ad0276 | |||
| ca4e96a8ae | |||
| 050c67323c | |||
| 497d41fb50 | |||
| ff57793659 | |||
| f7bd5e58dd | |||
| 7ed73d12d1 | |||
| eb39019daa | |||
| bab08f40d1 | |||
| bc49106837 | |||
| 1b2de2642d | |||
| 9fa1036f60 | |||
| 0737b7e0d2 | |||
| 0963493a9c | |||
| e73a9dbe30 | |||
| fe01885acf | |||
| 7139d6d93f | |||
| 2f52e8f05f | |||
| 8d38ea3bbf | |||
| 5a8f502db5 | |||
| 7cd2c4bd6a | |||
| dfa791eb4b | |||
| bddd69618b | |||
| 54d8fdbed0 | |||
| d844d8b13b | |||
| 07a927517c | |||
| f16a70ba67 | |||
| 36b5127fd3 | |||
| 4977f203fa |
@ -53,6 +53,16 @@ try:
|
||||
repo.stash(ident)
|
||||
except KeyError:
|
||||
print("nothing to stash") # noqa: T201
|
||||
except:
|
||||
print("Could not stash, cleaning index and trying again.") # noqa: T201
|
||||
repo.state_cleanup()
|
||||
repo.index.read_tree(repo.head.peel().tree)
|
||||
repo.index.write()
|
||||
try:
|
||||
repo.stash(ident)
|
||||
except KeyError:
|
||||
print("nothing to stash.") # noqa: T201
|
||||
|
||||
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
|
||||
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
|
||||
try:
|
||||
@ -66,8 +76,10 @@ if branch is None:
|
||||
try:
|
||||
ref = repo.lookup_reference('refs/remotes/origin/master')
|
||||
except:
|
||||
print("pulling.") # noqa: T201
|
||||
pull(repo)
|
||||
print("fetching.") # noqa: T201
|
||||
for remote in repo.remotes:
|
||||
if remote.name == "origin":
|
||||
remote.fetch()
|
||||
ref = repo.lookup_reference('refs/remotes/origin/master')
|
||||
repo.checkout(ref)
|
||||
branch = repo.lookup_branch('master')
|
||||
@ -149,3 +161,4 @@ try:
|
||||
shutil.copy(stable_update_script, stable_update_script_to)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
28
.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt
Executable file
28
.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt
Executable file
@ -0,0 +1,28 @@
|
||||
As of the time of writing this you need this driver for best results:
|
||||
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html
|
||||
|
||||
HOW TO RUN:
|
||||
|
||||
If you have a AMD gpu:
|
||||
|
||||
run_amd_gpu.bat
|
||||
|
||||
If you have memory issues you can try disabling the smart memory management by running comfyui with:
|
||||
|
||||
run_amd_gpu_disable_smart_memory.bat
|
||||
|
||||
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
||||
|
||||
You can download the stable diffusion XL one from: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors
|
||||
|
||||
|
||||
RECOMMENDED WAY TO UPDATE:
|
||||
To update the ComfyUI code: update\update_comfyui.bat
|
||||
|
||||
|
||||
TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI:
|
||||
In the ComfyUI directory you will find a file: extra_model_paths.yaml.example
|
||||
Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.
|
||||
|
||||
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory
|
||||
pause
|
||||
@ -0,0 +1,3 @@
|
||||
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
|
||||
pause
|
||||
3
.ci/windows_nvidia_base_files/run_nvidia_gpu.bat
Executable file
3
.ci/windows_nvidia_base_files/run_nvidia_gpu.bat
Executable file
@ -0,0 +1,3 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
|
||||
pause
|
||||
@ -0,0 +1,3 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
|
||||
pause
|
||||
8
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
8
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@ -8,13 +8,15 @@ body:
|
||||
Before submitting a **Bug Report**, please ensure the following:
|
||||
|
||||
- **1:** You are running the latest version of ComfyUI.
|
||||
- **2:** You have looked at the existing bug reports and made sure this isn't already reported.
|
||||
- **2:** You have your ComfyUI logs and relevant workflow on hand and will post them in this bug report.
|
||||
- **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing
|
||||
`--disable-all-custom-nodes` command line argument.
|
||||
`--disable-all-custom-nodes` command line argument. If you have custom node try updating them to the latest version.
|
||||
- **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact
|
||||
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
|
||||
|
||||
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
|
||||
## Very Important
|
||||
|
||||
Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
|
||||
- type: checkboxes
|
||||
id: custom-nodes-test
|
||||
attributes:
|
||||
|
||||
21
.github/PULL_REQUEST_TEMPLATE/api-node.md
vendored
Normal file
21
.github/PULL_REQUEST_TEMPLATE/api-node.md
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
<!-- API_NODE_PR_CHECKLIST: do not remove -->
|
||||
|
||||
## API Node PR Checklist
|
||||
|
||||
### Scope
|
||||
- [ ] **Is API Node Change**
|
||||
|
||||
### Pricing & Billing
|
||||
- [ ] **Need pricing update**
|
||||
- [ ] **No pricing update**
|
||||
|
||||
If **Need pricing update**:
|
||||
- [ ] Metronome rate cards updated
|
||||
- [ ] Auto‑billing tests updated and passing
|
||||
|
||||
### QA
|
||||
- [ ] **QA done**
|
||||
- [ ] **QA not required**
|
||||
|
||||
### Comms
|
||||
- [ ] Informed **Kosinkadink**
|
||||
58
.github/workflows/api-node-template.yml
vendored
Normal file
58
.github/workflows/api-node-template.yml
vendored
Normal file
@ -0,0 +1,58 @@
|
||||
name: Append API Node PR template
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [opened, reopened, synchronize, ready_for_review]
|
||||
paths:
|
||||
- 'comfy_api_nodes/**' # only run if these files changed
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
inject:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Ensure template exists and append to PR body
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const { owner, repo } = context.repo;
|
||||
const number = context.payload.pull_request.number;
|
||||
const templatePath = '.github/PULL_REQUEST_TEMPLATE/api-node.md';
|
||||
const marker = '<!-- API_NODE_PR_CHECKLIST: do not remove -->';
|
||||
|
||||
const { data: pr } = await github.rest.pulls.get({ owner, repo, pull_number: number });
|
||||
|
||||
let templateText;
|
||||
try {
|
||||
const res = await github.rest.repos.getContent({
|
||||
owner,
|
||||
repo,
|
||||
path: templatePath,
|
||||
ref: pr.base.ref
|
||||
});
|
||||
const buf = Buffer.from(res.data.content, res.data.encoding || 'base64');
|
||||
templateText = buf.toString('utf8');
|
||||
} catch (e) {
|
||||
core.setFailed(`Required PR template not found at "${templatePath}" on ${pr.base.ref}. Please add it to the repo.`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Enforce the presence of the marker inside the template (for idempotence)
|
||||
if (!templateText.includes(marker)) {
|
||||
core.setFailed(`Template at "${templatePath}" does not contain the required marker:\n${marker}\nAdd it so we can detect duplicates safely.`);
|
||||
return;
|
||||
}
|
||||
|
||||
// If the PR already contains the marker, do not append again.
|
||||
const body = pr.body || '';
|
||||
if (body.includes(marker)) {
|
||||
core.info('Template already present in PR body; nothing to inject.');
|
||||
return;
|
||||
}
|
||||
|
||||
const newBody = (body ? body + '\n\n' : '') + templateText + '\n';
|
||||
await github.rest.pulls.update({ owner, repo, pull_number: number, body: newBody });
|
||||
core.notice('API Node template appended to PR description.');
|
||||
78
.github/workflows/release-stable-all.yml
vendored
Normal file
78
.github/workflows/release-stable-all.yml
vendored
Normal file
@ -0,0 +1,78 @@
|
||||
name: "Release Stable All Portable Versions"
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
git_tag:
|
||||
description: 'Git tag'
|
||||
required: true
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
release_nvidia_default:
|
||||
permissions:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release NVIDIA Default (cu130)"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu130"
|
||||
python_minor: "13"
|
||||
python_patch: "9"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: ""
|
||||
test_release: true
|
||||
secrets: inherit
|
||||
|
||||
release_nvidia_cu128:
|
||||
permissions:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release NVIDIA cu128"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu128"
|
||||
python_minor: "12"
|
||||
python_patch: "10"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: "_cu128"
|
||||
test_release: true
|
||||
secrets: inherit
|
||||
|
||||
release_nvidia_cu126:
|
||||
permissions:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release NVIDIA cu126"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu126"
|
||||
python_minor: "12"
|
||||
python_patch: "10"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: "_cu126"
|
||||
test_release: true
|
||||
secrets: inherit
|
||||
|
||||
release_amd_rocm:
|
||||
permissions:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release AMD ROCm 7.1.1"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "rocm711"
|
||||
python_minor: "12"
|
||||
python_patch: "10"
|
||||
rel_name: "amd"
|
||||
rel_extra_name: ""
|
||||
test_release: false
|
||||
secrets: inherit
|
||||
25
.github/workflows/ruff.yml
vendored
25
.github/workflows/ruff.yml
vendored
@ -21,3 +21,28 @@ jobs:
|
||||
|
||||
- name: Run Ruff
|
||||
run: ruff check .
|
||||
|
||||
pylint:
|
||||
name: Run Pylint
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
|
||||
- name: Install Pylint
|
||||
run: pip install pylint
|
||||
|
||||
- name: Run Pylint
|
||||
run: pylint comfy_api_nodes
|
||||
|
||||
100
.github/workflows/stable-release.yml
vendored
100
.github/workflows/stable-release.yml
vendored
@ -2,17 +2,17 @@
|
||||
name: "Release Stable Version"
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
inputs:
|
||||
git_tag:
|
||||
description: 'Git tag'
|
||||
required: true
|
||||
type: string
|
||||
cu:
|
||||
description: 'CUDA version'
|
||||
cache_tag:
|
||||
description: 'Cached dependencies tag'
|
||||
required: true
|
||||
type: string
|
||||
default: "129"
|
||||
default: "cu129"
|
||||
python_minor:
|
||||
description: 'Python minor version'
|
||||
required: true
|
||||
@ -23,7 +23,57 @@ on:
|
||||
required: true
|
||||
type: string
|
||||
default: "6"
|
||||
|
||||
rel_name:
|
||||
description: 'Release name'
|
||||
required: true
|
||||
type: string
|
||||
default: "nvidia"
|
||||
rel_extra_name:
|
||||
description: 'Release extra name'
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
test_release:
|
||||
description: 'Test Release'
|
||||
required: true
|
||||
type: boolean
|
||||
default: true
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
git_tag:
|
||||
description: 'Git tag'
|
||||
required: true
|
||||
type: string
|
||||
cache_tag:
|
||||
description: 'Cached dependencies tag'
|
||||
required: true
|
||||
type: string
|
||||
default: "cu129"
|
||||
python_minor:
|
||||
description: 'Python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "13"
|
||||
python_patch:
|
||||
description: 'Python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "6"
|
||||
rel_name:
|
||||
description: 'Release name'
|
||||
required: true
|
||||
type: string
|
||||
default: "nvidia"
|
||||
rel_extra_name:
|
||||
description: 'Release extra name'
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
test_release:
|
||||
description: 'Test Release'
|
||||
required: true
|
||||
type: boolean
|
||||
default: true
|
||||
|
||||
jobs:
|
||||
package_comfy_windows:
|
||||
@ -42,15 +92,15 @@ jobs:
|
||||
id: cache
|
||||
with:
|
||||
path: |
|
||||
cu${{ inputs.cu }}_python_deps.tar
|
||||
${{ inputs.cache_tag }}_python_deps.tar
|
||||
update_comfyui_and_python_dependencies.bat
|
||||
key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
|
||||
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}
|
||||
- shell: bash
|
||||
run: |
|
||||
mv cu${{ inputs.cu }}_python_deps.tar ../
|
||||
mv ${{ inputs.cache_tag }}_python_deps.tar ../
|
||||
mv update_comfyui_and_python_dependencies.bat ../
|
||||
cd ..
|
||||
tar xf cu${{ inputs.cu }}_python_deps.tar
|
||||
tar xf ${{ inputs.cache_tag }}_python_deps.tar
|
||||
pwd
|
||||
ls
|
||||
|
||||
@ -65,12 +115,19 @@ jobs:
|
||||
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
|
||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||
./python.exe get-pip.py
|
||||
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
|
||||
./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
|
||||
|
||||
grep comfy ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
|
||||
./python.exe -s -m pip install -r requirements_comfyui.txt
|
||||
rm requirements_comfyui.txt
|
||||
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
|
||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
||||
rm ./Lib/site-packages/torch/lib/libprotoc.lib
|
||||
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
|
||||
if test -f ./Lib/site-packages/torch/lib/dnnl.lib; then
|
||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
||||
rm ./Lib/site-packages/torch/lib/libprotoc.lib
|
||||
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
|
||||
fi
|
||||
|
||||
cd ..
|
||||
|
||||
@ -85,14 +142,18 @@ jobs:
|
||||
|
||||
mkdir update
|
||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||
cp -r ComfyUI/.ci/windows_${{ inputs.rel_name }}_base_files/* ./
|
||||
cp ../update_comfyui_and_python_dependencies.bat ./update/
|
||||
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
|
||||
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
|
||||
|
||||
- shell: bash
|
||||
if: ${{ inputs.test_release }}
|
||||
run: |
|
||||
cd ..
|
||||
cd ComfyUI_windows_portable
|
||||
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
|
||||
|
||||
@ -101,10 +162,9 @@ jobs:
|
||||
ls
|
||||
|
||||
- name: Upload binaries to release
|
||||
uses: svenstaro/upload-release-action@v2
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
file: ComfyUI_windows_portable_nvidia.7z
|
||||
tag: ${{ inputs.git_tag }}
|
||||
overwrite: true
|
||||
files: ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
|
||||
tag_name: ${{ inputs.git_tag }}
|
||||
draft: true
|
||||
overwrite_files: true
|
||||
|
||||
2
.github/workflows/test-build.yml
vendored
2
.github/workflows/test-build.yml
vendored
@ -18,7 +18,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
|
||||
21
.github/workflows/test-ci.yml
vendored
21
.github/workflows/test-ci.yml
vendored
@ -5,6 +5,7 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- release/**
|
||||
paths-ignore:
|
||||
- 'app/**'
|
||||
- 'input/**'
|
||||
@ -21,14 +22,15 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
# os: [macos, linux, windows]
|
||||
os: [macos, linux]
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
# os: [macos, linux]
|
||||
os: [linux]
|
||||
python_version: ["3.10", "3.11", "3.12"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["stable"]
|
||||
include:
|
||||
- os: macos
|
||||
runner_label: [self-hosted, macOS]
|
||||
flags: "--use-pytorch-cross-attention"
|
||||
# - os: macos
|
||||
# runner_label: [self-hosted, macOS]
|
||||
# flags: "--use-pytorch-cross-attention"
|
||||
- os: linux
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
@ -73,14 +75,15 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos, linux]
|
||||
# os: [macos, linux]
|
||||
os: [linux]
|
||||
python_version: ["3.11"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["nightly"]
|
||||
include:
|
||||
- os: macos
|
||||
runner_label: [self-hosted, macOS]
|
||||
flags: "--use-pytorch-cross-attention"
|
||||
# - os: macos
|
||||
# runner_label: [self-hosted, macOS]
|
||||
# flags: "--use-pytorch-cross-attention"
|
||||
- os: linux
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
|
||||
30
.github/workflows/test-execution.yml
vendored
Normal file
30
.github/workflows/test-execution.yml
vendored
Normal file
@ -0,0 +1,30 @@
|
||||
name: Execution Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master, release/** ]
|
||||
pull_request:
|
||||
branches: [ main, master, release/** ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
continue-on-error: true
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.12'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
pip install -r tests-unit/requirements.txt
|
||||
- name: Run Execution Tests
|
||||
run: |
|
||||
python -m pytest tests/execution -v --skip-timing-checks
|
||||
10
.github/workflows/test-launch.yml
vendored
10
.github/workflows/test-launch.yml
vendored
@ -2,9 +2,9 @@ name: Test server launches without errors
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
branches: [ main, master, release/** ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
branches: [ main, master, release/** ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
@ -13,7 +13,7 @@ jobs:
|
||||
- name: Checkout ComfyUI
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: "comfyanonymous/ComfyUI"
|
||||
repository: "Comfy-Org/ComfyUI"
|
||||
path: "ComfyUI"
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
@ -32,7 +32,9 @@ jobs:
|
||||
working-directory: ComfyUI
|
||||
- name: Check for unhandled exceptions in server log
|
||||
run: |
|
||||
if grep -qE "Exception|Error" console_output.log; then
|
||||
grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': True, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" console_output.log | grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': False, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" > console_output_filtered.log
|
||||
cat console_output_filtered.log
|
||||
if grep -qE "Exception|Error" console_output_filtered.log; then
|
||||
echo "Unhandled exception/error found in server log."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
6
.github/workflows/test-unit.yml
vendored
6
.github/workflows/test-unit.yml
vendored
@ -2,15 +2,15 @@ name: Unit Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
branches: [ main, master, release/** ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
branches: [ main, master, release/** ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||
os: [ubuntu-latest, windows-2022, macos-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
continue-on-error: true
|
||||
steps:
|
||||
|
||||
59
.github/workflows/update-ci-container.yml
vendored
Normal file
59
.github/workflows/update-ci-container.yml
vendored
Normal file
@ -0,0 +1,59 @@
|
||||
name: "CI: Update CI Container"
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: 'ComfyUI version (e.g., v0.7.0)'
|
||||
required: true
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
update-ci-container:
|
||||
runs-on: ubuntu-latest
|
||||
# Skip pre-releases unless manually triggered
|
||||
if: github.event_name == 'workflow_dispatch' || !github.event.release.prerelease
|
||||
steps:
|
||||
- name: Get version
|
||||
id: version
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" = "release" ]; then
|
||||
VERSION="${{ github.event.release.tag_name }}"
|
||||
else
|
||||
VERSION="${{ inputs.version }}"
|
||||
fi
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Checkout comfyui-ci-container
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: comfy-org/comfyui-ci-container
|
||||
token: ${{ secrets.CI_CONTAINER_PAT }}
|
||||
|
||||
- name: Check current version
|
||||
id: current
|
||||
run: |
|
||||
CURRENT=$(grep -oP 'ARG COMFYUI_VERSION=\K.*' Dockerfile || echo "unknown")
|
||||
echo "current_version=$CURRENT" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Update Dockerfile
|
||||
run: |
|
||||
VERSION="${{ steps.version.outputs.version }}"
|
||||
sed -i "s/^ARG COMFYUI_VERSION=.*/ARG COMFYUI_VERSION=${VERSION}/" Dockerfile
|
||||
|
||||
- name: Create Pull Request
|
||||
id: create-pr
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
with:
|
||||
token: ${{ secrets.CI_CONTAINER_PAT }}
|
||||
branch: automation/comfyui-${{ steps.version.outputs.version }}
|
||||
title: "chore: bump ComfyUI to ${{ steps.version.outputs.version }}"
|
||||
body: |
|
||||
Updates ComfyUI version from `${{ steps.current.outputs.current_version }}` to `${{ steps.version.outputs.version }}`
|
||||
|
||||
**Triggered by:** ${{ github.event_name == 'release' && format('[Release {0}]({1})', github.event.release.tag_name, github.event.release.html_url) || 'Manual workflow dispatch' }}
|
||||
|
||||
labels: automation
|
||||
commit-message: "chore: bump ComfyUI to ${{ steps.version.outputs.version }}"
|
||||
1
.github/workflows/update-version.yml
vendored
1
.github/workflows/update-version.yml
vendored
@ -6,6 +6,7 @@ on:
|
||||
- "pyproject.toml"
|
||||
branches:
|
||||
- master
|
||||
- release/**
|
||||
|
||||
jobs:
|
||||
update-version:
|
||||
|
||||
@ -17,7 +17,7 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "129"
|
||||
default: "130"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
@ -29,7 +29,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "6"
|
||||
default: "9"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@ -56,7 +56,8 @@ jobs:
|
||||
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
|
||||
pause" > update_comfyui_and_python_dependencies.bat
|
||||
|
||||
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir
|
||||
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
|
||||
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
|
||||
python -m pip install --no-cache-dir ./temp_wheel_dir/*
|
||||
echo installed basic
|
||||
ls -lah temp_wheel_dir
|
||||
|
||||
64
.github/workflows/windows_release_dependencies_manual.yml
vendored
Normal file
64
.github/workflows/windows_release_dependencies_manual.yml
vendored
Normal file
@ -0,0 +1,64 @@
|
||||
name: "Windows Release dependencies Manual"
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
torch_dependencies:
|
||||
description: 'torch dependencies'
|
||||
required: false
|
||||
type: string
|
||||
default: "torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128"
|
||||
cache_tag:
|
||||
description: 'Cached dependencies tag'
|
||||
required: true
|
||||
type: string
|
||||
default: "cu128"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "12"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "10"
|
||||
|
||||
jobs:
|
||||
build_dependencies:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
|
||||
|
||||
- shell: bash
|
||||
run: |
|
||||
echo "@echo off
|
||||
call update_comfyui.bat nopause
|
||||
echo -
|
||||
echo This will try to update pytorch and all python dependencies.
|
||||
echo -
|
||||
echo If you just want to update normally, close this and run update_comfyui.bat instead.
|
||||
echo -
|
||||
pause
|
||||
..\python_embeded\python.exe -s -m pip install --upgrade ${{ inputs.torch_dependencies }} -r ../ComfyUI/requirements.txt pygit2
|
||||
pause" > update_comfyui_and_python_dependencies.bat
|
||||
|
||||
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
|
||||
python -m pip wheel --no-cache-dir ${{ inputs.torch_dependencies }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
|
||||
python -m pip install --no-cache-dir ./temp_wheel_dir/*
|
||||
echo installed basic
|
||||
ls -lah temp_wheel_dir
|
||||
mv temp_wheel_dir ${{ inputs.cache_tag }}_python_deps
|
||||
tar cf ${{ inputs.cache_tag }}_python_deps.tar ${{ inputs.cache_tag }}_python_deps
|
||||
|
||||
- uses: actions/cache/save@v4
|
||||
with:
|
||||
path: |
|
||||
${{ inputs.cache_tag }}_python_deps.tar
|
||||
update_comfyui_and_python_dependencies.bat
|
||||
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}
|
||||
@ -68,7 +68,7 @@ jobs:
|
||||
|
||||
mkdir update
|
||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
|
||||
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
|
||||
|
||||
echo "call update_comfyui.bat nopause
|
||||
|
||||
@ -81,7 +81,7 @@ jobs:
|
||||
|
||||
mkdir update
|
||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
|
||||
cp ../update_comfyui_and_python_dependencies.bat ./update/
|
||||
|
||||
cd ..
|
||||
|
||||
25
CODEOWNERS
25
CODEOWNERS
@ -1,25 +1,2 @@
|
||||
# Admins
|
||||
* @comfyanonymous
|
||||
|
||||
# Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org.
|
||||
# Inlined the team members for now.
|
||||
|
||||
# Maintainers
|
||||
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
|
||||
# Python web server
|
||||
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
|
||||
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
|
||||
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
|
||||
|
||||
# Node developers
|
||||
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
|
||||
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
|
||||
/comfy_api_nodes/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
|
||||
* @comfyanonymous @kosinkadink @guill
|
||||
|
||||
168
QUANTIZATION.md
Normal file
168
QUANTIZATION.md
Normal file
@ -0,0 +1,168 @@
|
||||
# The Comfy guide to Quantization
|
||||
|
||||
|
||||
## How does quantization work?
|
||||
|
||||
Quantization aims to map a high-precision value x_f to a lower precision format with minimal loss in accuracy. These smaller formats then serve to reduce the models memory footprint and increase throughput by using specialized hardware.
|
||||
|
||||
When simply converting a value from FP16 to FP8 using the round-nearest method we might hit two issues:
|
||||
- The dynamic range of FP16 (-65,504, 65,504) far exceeds FP8 formats like E4M3 (-448, 448) or E5M2 (-57,344, 57,344), potentially resulting in clipped values
|
||||
- The original values are concentrated in a small range (e.g. -1,1) leaving many FP8-bits "unused"
|
||||
|
||||
By using a scaling factor, we aim to map these values into the quantized-dtype range, making use of the full spectrum. One of the easiest approaches, and common, is using per-tensor absolute-maximum scaling.
|
||||
|
||||
```
|
||||
absmax = max(abs(tensor))
|
||||
scale = amax / max_dynamic_range_low_precision
|
||||
|
||||
# Quantization
|
||||
tensor_q = (tensor / scale).to(low_precision_dtype)
|
||||
|
||||
# De-Quantization
|
||||
tensor_dq = tensor_q.to(fp16) * scale
|
||||
|
||||
tensor_dq ~ tensor
|
||||
```
|
||||
|
||||
Given that additional information (scaling factor) is needed to "interpret" the quantized values, we describe those as derived datatypes.
|
||||
|
||||
|
||||
## Quantization in Comfy
|
||||
|
||||
```
|
||||
QuantizedTensor (torch.Tensor subclass)
|
||||
↓ __torch_dispatch__
|
||||
Two-Level Registry (generic + layout handlers)
|
||||
↓
|
||||
MixedPrecisionOps + Metadata Detection
|
||||
```
|
||||
|
||||
### Representation
|
||||
|
||||
To represent these derived datatypes, ComfyUI uses a subclass of torch.Tensor to implements these using the `QuantizedTensor` class found in `comfy/quant_ops.py`
|
||||
|
||||
A `Layout` class defines how a specific quantization format behaves:
|
||||
- Required parameters
|
||||
- Quantize method
|
||||
- De-Quantize method
|
||||
|
||||
```python
|
||||
from comfy.quant_ops import QuantizedLayout
|
||||
|
||||
class MyLayout(QuantizedLayout):
|
||||
@classmethod
|
||||
def quantize(cls, tensor, **kwargs):
|
||||
# Convert to quantized format
|
||||
qdata = ...
|
||||
params = {'scale': ..., 'orig_dtype': tensor.dtype}
|
||||
return qdata, params
|
||||
|
||||
@staticmethod
|
||||
def dequantize(qdata, scale, orig_dtype, **kwargs):
|
||||
return qdata.to(orig_dtype) * scale
|
||||
```
|
||||
|
||||
To then run operations using these QuantizedTensors we use two registry systems to define supported operations.
|
||||
The first is a **generic registry** that handles operations common to all quantized formats (e.g., `.to()`, `.clone()`, `.reshape()`).
|
||||
|
||||
The second registry is layout-specific and allows to implement fast-paths like nn.Linear.
|
||||
```python
|
||||
from comfy.quant_ops import register_layout_op
|
||||
|
||||
@register_layout_op(torch.ops.aten.linear.default, MyLayout)
|
||||
def my_linear(func, args, kwargs):
|
||||
# Extract tensors, call optimized kernel
|
||||
...
|
||||
```
|
||||
When `torch.nn.functional.linear()` is called with QuantizedTensor arguments, `__torch_dispatch__` automatically routes to the registered implementation.
|
||||
For any unsupported operation, QuantizedTensor will fallback to call `dequantize` and dispatch using the high-precision implementation.
|
||||
|
||||
|
||||
### Mixed Precision
|
||||
|
||||
The `MixedPrecisionOps` class (lines 542-648 in `comfy/ops.py`) enables per-layer quantization decisions, allowing different layers in a model to use different precisions. This is activated when a model config contains a `layer_quant_config` dictionary that specifies which layers should be quantized and how.
|
||||
|
||||
**Architecture:**
|
||||
|
||||
```python
|
||||
class MixedPrecisionOps(disable_weight_init):
|
||||
_layer_quant_config = {} # Maps layer names to quantization configs
|
||||
_compute_dtype = torch.bfloat16 # Default compute / dequantize precision
|
||||
```
|
||||
|
||||
**Key mechanism:**
|
||||
|
||||
The custom `Linear._load_from_state_dict()` method inspects each layer during model loading:
|
||||
- If the layer name is **not** in `_layer_quant_config`: load weight as regular tensor in `_compute_dtype`
|
||||
- If the layer name **is** in `_layer_quant_config`:
|
||||
- Load weight as `QuantizedTensor` with the specified layout (e.g., `TensorCoreFP8Layout`)
|
||||
- Load associated quantization parameters (scales, block_size, etc.)
|
||||
|
||||
**Why it's needed:**
|
||||
|
||||
Not all layers tolerate quantization equally. Sensitive operations like final projections can be kept in higher precision, while compute-heavy matmuls are quantized. This provides most of the performance benefits while maintaining quality.
|
||||
|
||||
The system is selected in `pick_operations()` when `model_config.layer_quant_config` is present, making it the highest-priority operation mode.
|
||||
|
||||
|
||||
## Checkpoint Format
|
||||
|
||||
Quantized checkpoints are stored as standard safetensors files with quantized weight tensors and associated scaling parameters, plus a `_quantization_metadata` JSON entry describing the quantization scheme.
|
||||
|
||||
The quantized checkpoint will contain the same layers as the original checkpoint but:
|
||||
- The weights are stored as quantized values, sometimes using a different storage datatype. E.g. uint8 container for fp8.
|
||||
- For each quantized weight a number of additional scaling parameters are stored alongside depending on the recipe.
|
||||
- We store a metadata.json in the metadata of the final safetensor containing the `_quantization_metadata` describing which layers are quantized and what layout has been used.
|
||||
|
||||
### Scaling Parameters details
|
||||
We define 4 possible scaling parameters that should cover most recipes in the near-future:
|
||||
- **weight_scale**: quantization scalers for the weights
|
||||
- **weight_scale_2**: global scalers in the context of double scaling
|
||||
- **pre_quant_scale**: scalers used for smoothing salient weights
|
||||
- **input_scale**: quantization scalers for the activations
|
||||
|
||||
| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
|
||||
|--------|---------------|--------------|----------------|-----------------|-------------|
|
||||
| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
|
||||
|
||||
You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
|
||||
|
||||
### Quantization Metadata
|
||||
|
||||
The metadata stored alongside the checkpoint contains:
|
||||
- **format_version**: String to define a version of the standard
|
||||
- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.
|
||||
|
||||
Example:
|
||||
```json
|
||||
{
|
||||
"_quantization_metadata": {
|
||||
"format_version": "1.0",
|
||||
"layers": {
|
||||
"model.layers.0.mlp.up_proj": "float8_e4m3fn",
|
||||
"model.layers.0.mlp.down_proj": "float8_e4m3fn",
|
||||
"model.layers.1.mlp.up_proj": "float8_e4m3fn"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## Creating Quantized Checkpoints
|
||||
|
||||
To create compatible checkpoints, use any quantization tool provided the output follows the checkpoint format described above and uses a layout defined in `QUANT_ALGOS`.
|
||||
|
||||
### Weight Quantization
|
||||
|
||||
Weight quantization is straightforward - compute the scaling factor directly from the weight tensor using the absolute maximum method described earlier. Each layer's weights are quantized independently and stored with their corresponding `weight_scale` parameter.
|
||||
|
||||
### Calibration (for Activation Quantization)
|
||||
|
||||
Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters that cannot be determined from static weights alone. Since activation values depend on actual inputs, we use **post-training calibration (PTQ)**:
|
||||
|
||||
1. **Collect statistics**: Run inference on N representative samples
|
||||
2. **Track activations**: Record the absolute maximum (`amax`) of inputs to each quantized layer
|
||||
3. **Compute scales**: Derive `input_scale` from collected statistics
|
||||
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
|
||||
|
||||
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.
|
||||
101
README.md
101
README.md
@ -65,20 +65,23 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
|
||||
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
|
||||
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
|
||||
- [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
|
||||
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
|
||||
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
|
||||
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
|
||||
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
|
||||
- Image Editing Models
|
||||
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
|
||||
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
|
||||
- [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11)
|
||||
- [Qwen Image Edit](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/#edit-model)
|
||||
- Video Models
|
||||
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
|
||||
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
||||
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
|
||||
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
|
||||
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
|
||||
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
|
||||
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
|
||||
- [Hunyuan Video 1.5](https://docs.comfy.org/tutorials/video/hunyuan/hunyuan-video-1-5)
|
||||
- Audio Models
|
||||
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
@ -112,10 +115,14 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
|
||||
## Release Process
|
||||
|
||||
ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
|
||||
ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
|
||||
|
||||
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
||||
- Releases a new stable version (e.g., v0.7.0)
|
||||
- Releases a new stable version (e.g., v0.7.0) roughly every week.
|
||||
- Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release.
|
||||
- Minor versions will be used for releases off the master branch.
|
||||
- Patch versions may still be used for releases on the master branch in cases where a backport would not make sense.
|
||||
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
|
||||
- Serves as the foundation for the desktop release
|
||||
|
||||
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
|
||||
@ -172,10 +179,20 @@ There is a portable standalone build for Windows that should work for running on
|
||||
|
||||
### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z)
|
||||
|
||||
Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
|
||||
Simply download, extract with [7-Zip](https://7-zip.org) or with the windows explorer on recent windows versions and run. For smaller models you normally only need to put the checkpoints (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints but many of the larger models have multiple files. Make sure to follow the instructions to know which subfolder to put them in ComfyUI\models\
|
||||
|
||||
If you have trouble extracting it, right click the file -> properties -> unblock
|
||||
|
||||
The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start.
|
||||
|
||||
#### Alternative Downloads:
|
||||
|
||||
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
|
||||
|
||||
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).
|
||||
|
||||
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
|
||||
|
||||
#### How do I share models between another UI and ComfyUI?
|
||||
|
||||
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
|
||||
@ -191,7 +208,13 @@ comfy install
|
||||
|
||||
## Manual Install (Windows, Linux)
|
||||
|
||||
python 3.13 is supported but using 3.12 is recommended because some custom nodes and their dependencies might not support it yet.
|
||||
Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies.
|
||||
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
||||
|
||||
torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
|
||||
|
||||
### Instructions:
|
||||
|
||||
Git clone this repo.
|
||||
|
||||
@ -200,18 +223,36 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
|
||||
Put your VAE in: models/vae
|
||||
|
||||
|
||||
### AMD GPUs (Linux only)
|
||||
### AMD GPUs (Linux)
|
||||
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
|
||||
|
||||
This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:
|
||||
This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
|
||||
|
||||
|
||||
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
|
||||
|
||||
These have less hardware support than the builds above but they work on windows. You also need to install the pytorch version specific to your hardware.
|
||||
|
||||
RDNA 3 (RX 7000 series):
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-all/```
|
||||
|
||||
RDNA 3.5 (Strix halo/Ryzen AI Max+ 365):
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx1151/```
|
||||
|
||||
RDNA 4 (RX 9000 series):
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/```
|
||||
|
||||
### Intel GPUs (Windows and Linux)
|
||||
|
||||
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
|
||||
Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
|
||||
|
||||
1. To install PyTorch xpu, use the following command:
|
||||
|
||||
@ -221,19 +262,15 @@ This is the command to install the Pytorch xpu nightly which might have some per
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
|
||||
|
||||
(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
|
||||
|
||||
1. visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
|
||||
|
||||
### NVIDIA
|
||||
|
||||
Nvidia users should install stable pytorch using this command:
|
||||
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu129```
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu130```
|
||||
|
||||
This is the command to install pytorch nightly instead which might have performance improvements.
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu130```
|
||||
|
||||
#### Troubleshooting
|
||||
|
||||
@ -264,12 +301,6 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
|
||||
|
||||
> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).
|
||||
|
||||
#### DirectML (AMD Cards on Windows)
|
||||
|
||||
This is very badly supported and is not recommended. There are some unofficial builds of pytorch ROCm on windows that exist that will give you a much better experience than this. This readme will be updated once official pytorch ROCm builds for windows come out.
|
||||
|
||||
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
|
||||
|
||||
#### Ascend NPUs
|
||||
|
||||
For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method:
|
||||
@ -294,6 +325,32 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
|
||||
1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
|
||||
2. Launch ComfyUI by running `python main.py`
|
||||
|
||||
|
||||
## [ComfyUI-Manager](https://github.com/Comfy-Org/ComfyUI-Manager/tree/manager-v4)
|
||||
|
||||
**ComfyUI-Manager** is an extension that allows you to easily install, update, and manage custom nodes for ComfyUI.
|
||||
|
||||
### Setup
|
||||
|
||||
1. Install the manager dependencies:
|
||||
```bash
|
||||
pip install -r manager_requirements.txt
|
||||
```
|
||||
|
||||
2. Enable the manager with the `--enable-manager` flag when running ComfyUI:
|
||||
```bash
|
||||
python main.py --enable-manager
|
||||
```
|
||||
|
||||
### Command Line Options
|
||||
|
||||
| Flag | Description |
|
||||
|------|-------------|
|
||||
| `--enable-manager` | Enable ComfyUI-Manager |
|
||||
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
|
||||
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
|
||||
|
||||
|
||||
# Running
|
||||
|
||||
```python main.py```
|
||||
|
||||
174
alembic_db/versions/0001_assets.py
Normal file
174
alembic_db/versions/0001_assets.py
Normal file
@ -0,0 +1,174 @@
|
||||
"""
|
||||
Initial assets schema
|
||||
Revision ID: 0001_assets
|
||||
Revises: None
|
||||
Create Date: 2025-12-10 00:00:00
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "0001_assets"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ASSETS: content identity
|
||||
op.create_table(
|
||||
"assets",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("hash", sa.String(length=256), nullable=True),
|
||||
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.Column("mime_type", sa.String(length=255), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
|
||||
)
|
||||
op.create_index("uq_assets_hash", "assets", ["hash"], unique=True)
|
||||
op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
|
||||
|
||||
# ASSETS_INFO: user-visible references
|
||||
op.create_table(
|
||||
"assets_info",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
|
||||
sa.Column("name", sa.String(length=512), nullable=False),
|
||||
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
|
||||
sa.Column("user_metadata", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
|
||||
)
|
||||
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
|
||||
op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
|
||||
op.create_index("ix_assets_info_name", "assets_info", ["name"])
|
||||
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
|
||||
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
|
||||
op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])
|
||||
|
||||
# TAGS: normalized tag vocabulary
|
||||
op.create_table(
|
||||
"tags",
|
||||
sa.Column("name", sa.String(length=512), primary_key=True),
|
||||
sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
|
||||
sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
|
||||
)
|
||||
op.create_index("ix_tags_tag_type", "tags", ["tag_type"])
|
||||
|
||||
# ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
|
||||
op.create_table(
|
||||
"asset_info_tags",
|
||||
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
|
||||
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
|
||||
)
|
||||
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
|
||||
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
|
||||
|
||||
# ASSET_CACHE_STATE: N:1 local cache rows per Asset
|
||||
op.create_table(
|
||||
"asset_cache_state",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file
|
||||
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
|
||||
sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
|
||||
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
||||
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
||||
)
|
||||
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
|
||||
op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])
|
||||
|
||||
# ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
|
||||
op.create_table(
|
||||
"asset_info_meta",
|
||||
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("key", sa.String(length=256), nullable=False),
|
||||
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("val_str", sa.String(length=2048), nullable=True),
|
||||
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
|
||||
sa.Column("val_bool", sa.Boolean(), nullable=True),
|
||||
sa.Column("val_json", sa.JSON(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
|
||||
)
|
||||
op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
|
||||
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
|
||||
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
|
||||
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])
|
||||
|
||||
# Tags vocabulary
|
||||
tags_table = sa.table(
|
||||
"tags",
|
||||
sa.column("name", sa.String(length=512)),
|
||||
sa.column("tag_type", sa.String()),
|
||||
)
|
||||
op.bulk_insert(
|
||||
tags_table,
|
||||
[
|
||||
{"name": "models", "tag_type": "system"},
|
||||
{"name": "input", "tag_type": "system"},
|
||||
{"name": "output", "tag_type": "system"},
|
||||
|
||||
{"name": "configs", "tag_type": "system"},
|
||||
{"name": "checkpoints", "tag_type": "system"},
|
||||
{"name": "loras", "tag_type": "system"},
|
||||
{"name": "vae", "tag_type": "system"},
|
||||
{"name": "text_encoders", "tag_type": "system"},
|
||||
{"name": "diffusion_models", "tag_type": "system"},
|
||||
{"name": "clip_vision", "tag_type": "system"},
|
||||
{"name": "style_models", "tag_type": "system"},
|
||||
{"name": "embeddings", "tag_type": "system"},
|
||||
{"name": "diffusers", "tag_type": "system"},
|
||||
{"name": "vae_approx", "tag_type": "system"},
|
||||
{"name": "controlnet", "tag_type": "system"},
|
||||
{"name": "gligen", "tag_type": "system"},
|
||||
{"name": "upscale_models", "tag_type": "system"},
|
||||
{"name": "hypernetworks", "tag_type": "system"},
|
||||
{"name": "photomaker", "tag_type": "system"},
|
||||
{"name": "classifiers", "tag_type": "system"},
|
||||
|
||||
{"name": "encoder", "tag_type": "system"},
|
||||
{"name": "decoder", "tag_type": "system"},
|
||||
|
||||
{"name": "missing", "tag_type": "system"},
|
||||
{"name": "rescan", "tag_type": "system"},
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
|
||||
op.drop_table("asset_info_meta")
|
||||
|
||||
op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
|
||||
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
|
||||
op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state")
|
||||
op.drop_table("asset_cache_state")
|
||||
|
||||
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
|
||||
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
|
||||
op.drop_table("asset_info_tags")
|
||||
|
||||
op.drop_index("ix_tags_tag_type", table_name="tags")
|
||||
op.drop_table("tags")
|
||||
|
||||
op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
|
||||
op.drop_table("assets_info")
|
||||
|
||||
op.drop_index("uq_assets_hash", table_name="assets")
|
||||
op.drop_index("ix_assets_mime_type", table_name="assets")
|
||||
op.drop_table("assets")
|
||||
@ -58,8 +58,13 @@ class InternalRoutes:
|
||||
return web.json_response({"error": "Invalid directory type"}, status=400)
|
||||
|
||||
directory = get_directory_by_type(directory_type)
|
||||
|
||||
def is_visible_file(entry: os.DirEntry) -> bool:
|
||||
"""Filter out hidden files (e.g., .DS_Store on macOS)."""
|
||||
return entry.is_file() and not entry.name.startswith('.')
|
||||
|
||||
sorted_files = sorted(
|
||||
(entry for entry in os.scandir(directory) if entry.is_file()),
|
||||
(entry for entry in os.scandir(directory) if is_visible_file(entry)),
|
||||
key=lambda entry: -entry.stat().st_mtime
|
||||
)
|
||||
return web.json_response([entry.name for entry in sorted_files], status=200)
|
||||
|
||||
102
app/assets/api/routes.py
Normal file
102
app/assets/api/routes.py
Normal file
@ -0,0 +1,102 @@
|
||||
import logging
|
||||
import uuid
|
||||
from aiohttp import web
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
import app.assets.manager as manager
|
||||
from app import user_manager
|
||||
from app.assets.api import schemas_in
|
||||
from app.assets.helpers import get_query_dict
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
USER_MANAGER: user_manager.UserManager | None = None
|
||||
|
||||
# UUID regex (canonical hyphenated form, case-insensitive)
|
||||
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
||||
|
||||
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
|
||||
global USER_MANAGER
|
||||
USER_MANAGER = user_manager_instance
|
||||
app.add_routes(ROUTES)
|
||||
|
||||
def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
|
||||
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
|
||||
|
||||
|
||||
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
|
||||
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets")
|
||||
async def list_assets(request: web.Request) -> web.Response:
|
||||
"""
|
||||
GET request to list assets.
|
||||
"""
|
||||
query_dict = get_query_dict(request)
|
||||
try:
|
||||
q = schemas_in.ListAssetsQuery.model_validate(query_dict)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_QUERY", ve)
|
||||
|
||||
payload = manager.list_assets(
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
offset=q.offset,
|
||||
sort=q.sort,
|
||||
order=q.order,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json"))
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def get_asset(request: web.Request) -> web.Response:
|
||||
"""
|
||||
GET request to get an asset's info as JSON.
|
||||
"""
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
result = manager.get_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as e:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"get_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.get("/api/tags")
|
||||
async def get_tags(request: web.Request) -> web.Response:
|
||||
"""
|
||||
GET request to list all tags based on query parameters.
|
||||
"""
|
||||
query_map = dict(request.rel_url.query)
|
||||
|
||||
try:
|
||||
query = schemas_in.TagsListQuery.model_validate(query_map)
|
||||
except ValidationError as e:
|
||||
return web.json_response(
|
||||
{"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": e.errors()}},
|
||||
status=400,
|
||||
)
|
||||
|
||||
result = manager.list_tags(
|
||||
prefix=query.prefix,
|
||||
limit=query.limit,
|
||||
offset=query.offset,
|
||||
order=query.order,
|
||||
include_zero=query.include_zero,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(result.model_dump(mode="json"))
|
||||
94
app/assets/api/schemas_in.py
Normal file
94
app/assets/api/schemas_in.py
Normal file
@ -0,0 +1,94 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
conint,
|
||||
field_validator,
|
||||
)
|
||||
|
||||
|
||||
class ListAssetsQuery(BaseModel):
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
exclude_tags: list[str] = Field(default_factory=list)
|
||||
name_contains: str | None = None
|
||||
|
||||
# Accept either a JSON string (query param) or a dict
|
||||
metadata_filter: dict[str, Any] | None = None
|
||||
|
||||
limit: conint(ge=1, le=500) = 20
|
||||
offset: conint(ge=0) = 0
|
||||
|
||||
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
|
||||
order: Literal["asc", "desc"] = "desc"
|
||||
|
||||
@field_validator("include_tags", "exclude_tags", mode="before")
|
||||
@classmethod
|
||||
def _split_csv_tags(cls, v):
|
||||
# Accept "a,b,c" or ["a","b"] (we are liberal in what we accept)
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
return [t.strip() for t in v.split(",") if t.strip()]
|
||||
if isinstance(v, list):
|
||||
out: list[str] = []
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
out.extend([t.strip() for t in item.split(",") if t.strip()])
|
||||
return out
|
||||
return v
|
||||
|
||||
@field_validator("metadata_filter", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
if v is None or isinstance(v, dict):
|
||||
return v
|
||||
if isinstance(v, str) and v.strip():
|
||||
try:
|
||||
parsed = json.loads(v)
|
||||
except Exception as e:
|
||||
raise ValueError(f"metadata_filter must be JSON: {e}") from e
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("metadata_filter must be a JSON object")
|
||||
return parsed
|
||||
return None
|
||||
|
||||
|
||||
class TagsListQuery(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
prefix: str | None = Field(None, min_length=1, max_length=256)
|
||||
limit: int = Field(100, ge=1, le=1000)
|
||||
offset: int = Field(0, ge=0, le=10_000_000)
|
||||
order: Literal["count_desc", "name_asc"] = "count_desc"
|
||||
include_zero: bool = True
|
||||
|
||||
@field_validator("prefix")
|
||||
@classmethod
|
||||
def normalize_prefix(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return v
|
||||
v = v.strip()
|
||||
return v.lower() or None
|
||||
|
||||
|
||||
class SetPreviewBody(BaseModel):
|
||||
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
|
||||
preview_id: str | None = None
|
||||
|
||||
@field_validator("preview_id", mode="before")
|
||||
@classmethod
|
||||
def _norm_uuid(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
s = str(v).strip()
|
||||
if not s:
|
||||
return None
|
||||
try:
|
||||
uuid.UUID(s)
|
||||
except Exception:
|
||||
raise ValueError("preview_id must be a UUID")
|
||||
return s
|
||||
60
app/assets/api/schemas_out.py
Normal file
60
app/assets/api/schemas_out.py
Normal file
@ -0,0 +1,60 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||
|
||||
|
||||
class AssetSummary(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
size: int | None = None
|
||||
mime_type: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
preview_url: str | None = None
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
last_access_time: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", "updated_at", "last_access_time")
|
||||
def _ser_dt(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetsList(BaseModel):
|
||||
assets: list[AssetSummary]
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class AssetDetail(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
size: int | None = None
|
||||
mime_type: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
preview_id: str | None = None
|
||||
created_at: datetime | None = None
|
||||
last_access_time: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", "last_access_time")
|
||||
def _ser_dt(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
type: str
|
||||
|
||||
|
||||
class TagsList(BaseModel):
|
||||
tags: list[TagUsage] = Field(default_factory=list)
|
||||
total: int
|
||||
has_more: bool
|
||||
204
app/assets/database/bulk_ops.py
Normal file
204
app/assets/database/bulk_ops.py
Normal file
@ -0,0 +1,204 @@
|
||||
import os
|
||||
import uuid
|
||||
import sqlalchemy
|
||||
from typing import Iterable
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.dialects import sqlite
|
||||
|
||||
from app.assets.helpers import utcnow
|
||||
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta
|
||||
|
||||
MAX_BIND_PARAMS = 800
|
||||
|
||||
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
|
||||
if not rows:
|
||||
return []
|
||||
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
|
||||
for i in range(0, len(rows), rows_per_stmt):
|
||||
yield rows[i:i + rows_per_stmt]
|
||||
|
||||
def _iter_chunks(seq, n: int):
|
||||
for i in range(0, len(seq), n):
|
||||
yield seq[i:i + n]
|
||||
|
||||
def _rows_per_stmt(cols: int) -> int:
|
||||
return max(1, MAX_BIND_PARAMS // max(1, cols))
|
||||
|
||||
|
||||
def seed_from_paths_batch(
|
||||
session: Session,
|
||||
*,
|
||||
specs: list[dict],
|
||||
owner_id: str = "",
|
||||
) -> dict:
|
||||
"""Each spec is a dict with keys:
|
||||
- abs_path: str
|
||||
- size_bytes: int
|
||||
- mtime_ns: int
|
||||
- info_name: str
|
||||
- tags: list[str]
|
||||
- fname: Optional[str]
|
||||
"""
|
||||
if not specs:
|
||||
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
|
||||
|
||||
now = utcnow()
|
||||
asset_rows: list[dict] = []
|
||||
state_rows: list[dict] = []
|
||||
path_to_asset: dict[str, str] = {}
|
||||
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
|
||||
path_list: list[str] = []
|
||||
|
||||
for sp in specs:
|
||||
ap = os.path.abspath(sp["abs_path"])
|
||||
aid = str(uuid.uuid4())
|
||||
iid = str(uuid.uuid4())
|
||||
path_list.append(ap)
|
||||
path_to_asset[ap] = aid
|
||||
|
||||
asset_rows.append(
|
||||
{
|
||||
"id": aid,
|
||||
"hash": None,
|
||||
"size_bytes": sp["size_bytes"],
|
||||
"mime_type": None,
|
||||
"created_at": now,
|
||||
}
|
||||
)
|
||||
state_rows.append(
|
||||
{
|
||||
"asset_id": aid,
|
||||
"file_path": ap,
|
||||
"mtime_ns": sp["mtime_ns"],
|
||||
}
|
||||
)
|
||||
asset_to_info[aid] = {
|
||||
"id": iid,
|
||||
"owner_id": owner_id,
|
||||
"name": sp["info_name"],
|
||||
"asset_id": aid,
|
||||
"preview_id": None,
|
||||
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_access_time": now,
|
||||
"_tags": sp["tags"],
|
||||
"_filename": sp["fname"],
|
||||
}
|
||||
|
||||
# insert all seed Assets (hash=NULL)
|
||||
ins_asset = sqlite.insert(Asset)
|
||||
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
|
||||
session.execute(ins_asset, chunk)
|
||||
|
||||
# try to claim AssetCacheState (file_path)
|
||||
# Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
|
||||
ins_state = (
|
||||
sqlite.insert(AssetCacheState)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
)
|
||||
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
|
||||
session.execute(ins_state, chunk)
|
||||
|
||||
# Query to find which of our paths won (were actually inserted)
|
||||
winners_by_path: set[str] = set()
|
||||
for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
|
||||
result = session.execute(
|
||||
sqlalchemy.select(AssetCacheState.file_path)
|
||||
.where(AssetCacheState.file_path.in_(chunk))
|
||||
.where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
|
||||
)
|
||||
winners_by_path.update(result.scalars().all())
|
||||
|
||||
all_paths_set = set(path_list)
|
||||
losers_by_path = all_paths_set - winners_by_path
|
||||
lost_assets = [path_to_asset[p] for p in losers_by_path]
|
||||
if lost_assets: # losers get their Asset removed
|
||||
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
|
||||
session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk)))
|
||||
|
||||
if not winners_by_path:
|
||||
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
|
||||
|
||||
# insert AssetInfo only for winners
|
||||
# Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
|
||||
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
|
||||
ins_info = (
|
||||
sqlite.insert(AssetInfo)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
|
||||
)
|
||||
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
|
||||
session.execute(ins_info, chunk)
|
||||
|
||||
# Query to find which info rows were actually inserted (by matching our generated IDs)
|
||||
all_info_ids = [row["id"] for row in winner_info_rows]
|
||||
inserted_info_ids: set[str] = set()
|
||||
for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
|
||||
result = session.execute(
|
||||
sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
|
||||
)
|
||||
inserted_info_ids.update(result.scalars().all())
|
||||
|
||||
# build and insert tag + meta rows for the AssetInfo
|
||||
tag_rows: list[dict] = []
|
||||
meta_rows: list[dict] = []
|
||||
if inserted_info_ids:
|
||||
for row in winner_info_rows:
|
||||
iid = row["id"]
|
||||
if iid not in inserted_info_ids:
|
||||
continue
|
||||
for t in row["_tags"]:
|
||||
tag_rows.append({
|
||||
"asset_info_id": iid,
|
||||
"tag_name": t,
|
||||
"origin": "automatic",
|
||||
"added_at": now,
|
||||
})
|
||||
if row["_filename"]:
|
||||
meta_rows.append(
|
||||
{
|
||||
"asset_info_id": iid,
|
||||
"key": "filename",
|
||||
"ordinal": 0,
|
||||
"val_str": row["_filename"],
|
||||
"val_num": None,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
}
|
||||
)
|
||||
|
||||
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
|
||||
return {
|
||||
"inserted_infos": len(inserted_info_ids),
|
||||
"won_states": len(winners_by_path),
|
||||
"lost_states": len(losers_by_path),
|
||||
}
|
||||
|
||||
|
||||
def bulk_insert_tags_and_meta(
|
||||
session: Session,
|
||||
*,
|
||||
tag_rows: list[dict],
|
||||
meta_rows: list[dict],
|
||||
max_bind_params: int,
|
||||
) -> None:
|
||||
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
|
||||
- tag_rows keys: asset_info_id, tag_name, origin, added_at
|
||||
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
|
||||
"""
|
||||
if tag_rows:
|
||||
ins_links = (
|
||||
sqlite.insert(AssetInfoTag)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
|
||||
session.execute(ins_links, chunk)
|
||||
if meta_rows:
|
||||
ins_meta = (
|
||||
sqlite.insert(AssetInfoMeta)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
|
||||
)
|
||||
)
|
||||
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
|
||||
session.execute(ins_meta, chunk)
|
||||
233
app/assets/database/models.py
Normal file
233
app/assets/database/models.py
Normal file
@ -0,0 +1,233 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from typing import Any
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
Boolean,
|
||||
CheckConstraint,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship
|
||||
|
||||
from app.assets.helpers import utcnow
|
||||
from app.database.models import to_dict, Base
|
||||
|
||||
|
||||
class Asset(Base):
|
||||
__tablename__ = "assets"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
hash: Mapped[str | None] = mapped_column(String(256), nullable=True)
|
||||
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
mime_type: Mapped[str | None] = mapped_column(String(255))
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
|
||||
infos: Mapped[list[AssetInfo]] = relationship(
|
||||
"AssetInfo",
|
||||
back_populates="asset",
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
|
||||
foreign_keys=lambda: [AssetInfo.asset_id],
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
preview_of: Mapped[list[AssetInfo]] = relationship(
|
||||
"AssetInfo",
|
||||
back_populates="preview_asset",
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
|
||||
foreign_keys=lambda: [AssetInfo.preview_id],
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
cache_states: Mapped[list[AssetCacheState]] = relationship(
|
||||
back_populates="asset",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("uq_assets_hash", "hash", unique=True),
|
||||
Index("ix_assets_mime_type", "mime_type"),
|
||||
CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
|
||||
|
||||
|
||||
class AssetCacheState(Base):
|
||||
__tablename__ = "asset_cache_state"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
|
||||
file_path: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
asset: Mapped[Asset] = relationship(back_populates="cache_states")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_cache_state_file_path", "file_path"),
|
||||
Index("ix_asset_cache_state_asset_id", "asset_id"),
|
||||
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
||||
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
|
||||
|
||||
|
||||
class AssetInfo(Base):
|
||||
__tablename__ = "assets_info"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||
name: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
|
||||
preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
|
||||
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
|
||||
asset: Mapped[Asset] = relationship(
|
||||
"Asset",
|
||||
back_populates="infos",
|
||||
foreign_keys=[asset_id],
|
||||
lazy="selectin",
|
||||
)
|
||||
preview_asset: Mapped[Asset | None] = relationship(
|
||||
"Asset",
|
||||
back_populates="preview_of",
|
||||
foreign_keys=[preview_id],
|
||||
)
|
||||
|
||||
metadata_entries: Mapped[list[AssetInfoMeta]] = relationship(
|
||||
back_populates="asset_info",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
tag_links: Mapped[list[AssetInfoTag]] = relationship(
|
||||
back_populates="asset_info",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
overlaps="tags,asset_infos",
|
||||
)
|
||||
|
||||
tags: Mapped[list[Tag]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="asset_infos",
|
||||
lazy="selectin",
|
||||
viewonly=True,
|
||||
overlaps="tag_links,asset_info_links,asset_infos,tag",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
|
||||
Index("ix_assets_info_owner_name", "owner_id", "name"),
|
||||
Index("ix_assets_info_owner_id", "owner_id"),
|
||||
Index("ix_assets_info_asset_id", "asset_id"),
|
||||
Index("ix_assets_info_name", "name"),
|
||||
Index("ix_assets_info_created_at", "created_at"),
|
||||
Index("ix_assets_info_last_access_time", "last_access_time"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
data = to_dict(self, include_none=include_none)
|
||||
data["tags"] = [t.name for t in self.tags]
|
||||
return data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
|
||||
|
||||
|
||||
class AssetInfoMeta(Base):
|
||||
__tablename__ = "asset_info_meta"
|
||||
|
||||
asset_info_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
key: Mapped[str] = mapped_column(String(256), primary_key=True)
|
||||
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
|
||||
|
||||
val_str: Mapped[str | None] = mapped_column(String(2048), nullable=True)
|
||||
val_num: Mapped[float | None] = mapped_column(Numeric(38, 10), nullable=True)
|
||||
val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True)
|
||||
|
||||
asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_meta_key", "key"),
|
||||
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
|
||||
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
|
||||
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
|
||||
)
|
||||
|
||||
|
||||
class AssetInfoTag(Base):
|
||||
__tablename__ = "asset_info_tags"
|
||||
|
||||
asset_info_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
tag_name: Mapped[str] = mapped_column(
|
||||
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
|
||||
)
|
||||
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
|
||||
added_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
|
||||
asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links")
|
||||
tag: Mapped[Tag] = relationship(back_populates="asset_info_links")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_tags_tag_name", "tag_name"),
|
||||
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
|
||||
)
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
__tablename__ = "tags"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(512), primary_key=True)
|
||||
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
|
||||
|
||||
asset_info_links: Mapped[list[AssetInfoTag]] = relationship(
|
||||
back_populates="tag",
|
||||
overlaps="asset_infos,tags",
|
||||
)
|
||||
asset_infos: Mapped[list[AssetInfo]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="tags",
|
||||
viewonly=True,
|
||||
overlaps="asset_info_links,tag_links,tags,asset_info",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_tags_tag_type", "tag_type"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Tag {self.name}>"
|
||||
267
app/assets/database/queries.py
Normal file
267
app/assets/database/queries.py
Normal file
@ -0,0 +1,267 @@
|
||||
import sqlalchemy as sa
|
||||
from collections import defaultdict
|
||||
from sqlalchemy import select, exists, func
|
||||
from sqlalchemy.orm import Session, contains_eager, noload
|
||||
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from app.assets.helpers import escape_like_prefix, normalize_tags
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
|
||||
owner_id = (owner_id or "").strip()
|
||||
if owner_id == "":
|
||||
return AssetInfo.owner_id == ""
|
||||
return AssetInfo.owner_id.in_(["", owner_id])
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = normalize_tags(include_tags)
|
||||
exclude_tags = normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""Apply filters using asset_info_meta projection table."""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
return sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
if value is None:
|
||||
no_row_for_key = sa.not_(
|
||||
sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
)
|
||||
)
|
||||
null_row = _exists_for_pred(
|
||||
key,
|
||||
AssetInfoMeta.val_json.is_(None),
|
||||
AssetInfoMeta.val_str.is_(None),
|
||||
AssetInfoMeta.val_num.is_(None),
|
||||
AssetInfoMeta.val_bool.is_(None),
|
||||
)
|
||||
return sa.or_(no_row_for_key, null_row)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float)):
|
||||
from decimal import Decimal
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
|
||||
|
||||
def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
row = (
|
||||
session.execute(
|
||||
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).first()
|
||||
return row is not None
|
||||
|
||||
def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None:
|
||||
return session.get(AssetInfo, asset_info_id)
|
||||
|
||||
def list_asset_infos_page(
|
||||
session: Session,
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
|
||||
base = (
|
||||
select(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
|
||||
if name_contains:
|
||||
escaped, esc = escape_like_prefix(name_contains)
|
||||
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
|
||||
base = apply_tag_filters(base, include_tags, exclude_tags)
|
||||
base = apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
sort = (sort or "created_at").lower()
|
||||
order = (order or "desc").lower()
|
||||
sort_map = {
|
||||
"name": AssetInfo.name,
|
||||
"created_at": AssetInfo.created_at,
|
||||
"updated_at": AssetInfo.updated_at,
|
||||
"last_access_time": AssetInfo.last_access_time,
|
||||
"size": Asset.size_bytes,
|
||||
}
|
||||
sort_col = sort_map.get(sort, AssetInfo.created_at)
|
||||
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
|
||||
|
||||
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||
|
||||
count_stmt = (
|
||||
select(sa.func.count())
|
||||
.select_from(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
if name_contains:
|
||||
escaped, esc = escape_like_prefix(name_contains)
|
||||
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
||||
|
||||
total = int((session.execute(count_stmt)).scalar_one() or 0)
|
||||
|
||||
infos = (session.execute(base)).unique().scalars().all()
|
||||
|
||||
id_list: list[str] = [i.id for i in infos]
|
||||
tag_map: dict[str, list[str]] = defaultdict(list)
|
||||
if id_list:
|
||||
rows = session.execute(
|
||||
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||
)
|
||||
for aid, tag_name in rows.all():
|
||||
tag_map[aid].append(tag_name)
|
||||
|
||||
return infos, tag_map, total
|
||||
|
||||
def fetch_asset_info_asset_and_tags(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[AssetInfo, Asset, list[str]] | None:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset, Tag.name)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.order_by(Tag.name.asc())
|
||||
)
|
||||
|
||||
rows = (session.execute(stmt)).all()
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
first_info, first_asset, _ = rows[0]
|
||||
tags: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for _info, _asset, tag_name in rows:
|
||||
if tag_name and tag_name not in seen:
|
||||
seen.add(tag_name)
|
||||
tags.append(tag_name)
|
||||
return first_info, first_asset, tags
|
||||
|
||||
def list_tags_with_usage(
|
||||
session: Session,
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
include_zero: bool = True,
|
||||
order: str = "count_desc",
|
||||
owner_id: str = "",
|
||||
) -> tuple[list[tuple[str, str, int]], int]:
|
||||
counts_sq = (
|
||||
select(
|
||||
AssetInfoTag.tag_name.label("tag_name"),
|
||||
func.count(AssetInfoTag.asset_info_id).label("cnt"),
|
||||
)
|
||||
.select_from(AssetInfoTag)
|
||||
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
.group_by(AssetInfoTag.tag_name)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
q = (
|
||||
select(
|
||||
Tag.name,
|
||||
Tag.tag_type,
|
||||
func.coalesce(counts_sq.c.cnt, 0).label("count"),
|
||||
)
|
||||
.select_from(Tag)
|
||||
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
|
||||
)
|
||||
|
||||
if prefix:
|
||||
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||
q = q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
|
||||
if not include_zero:
|
||||
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
||||
|
||||
if order == "name_asc":
|
||||
q = q.order_by(Tag.name.asc())
|
||||
else:
|
||||
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
|
||||
|
||||
total_q = select(func.count()).select_from(Tag)
|
||||
if prefix:
|
||||
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
if not include_zero:
|
||||
total_q = total_q.where(
|
||||
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
|
||||
)
|
||||
|
||||
rows = (session.execute(q.limit(limit).offset(offset))).all()
|
||||
total = (session.execute(total_q)).scalar_one()
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
62
app/assets/database/tags.py
Normal file
62
app/assets/database/tags.py
Normal file
@ -0,0 +1,62 @@
|
||||
from typing import Iterable
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.dialects import sqlite
|
||||
|
||||
from app.assets.helpers import normalize_tags, utcnow
|
||||
from app.assets.database.models import Tag, AssetInfoTag, AssetInfo
|
||||
|
||||
|
||||
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
return session.execute(ins)
|
||||
|
||||
def add_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
origin: str = "automatic",
|
||||
) -> None:
|
||||
select_rows = (
|
||||
sqlalchemy.select(
|
||||
AssetInfo.id.label("asset_info_id"),
|
||||
sqlalchemy.literal("missing").label("tag_name"),
|
||||
sqlalchemy.literal(origin).label("origin"),
|
||||
sqlalchemy.literal(utcnow()).label("added_at"),
|
||||
)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.where(
|
||||
sqlalchemy.not_(
|
||||
sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
|
||||
)
|
||||
)
|
||||
)
|
||||
session.execute(
|
||||
sqlite.insert(AssetInfoTag)
|
||||
.from_select(
|
||||
["asset_info_id", "tag_name", "origin", "added_at"],
|
||||
select_rows,
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
|
||||
def remove_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
session.execute(
|
||||
sqlalchemy.delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
75
app/assets/hashing.py
Normal file
75
app/assets/hashing.py
Normal file
@ -0,0 +1,75 @@
|
||||
from blake3 import blake3
|
||||
from typing import IO
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
|
||||
DEFAULT_CHUNK = 8 * 1024 *1024 # 8MB
|
||||
|
||||
# NOTE: this allows hashing different representations of a file-like object
|
||||
def blake3_hash(
|
||||
fp: str | IO[bytes],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""
|
||||
Returns a BLAKE3 hex digest for ``fp``, which may be:
|
||||
- a filename (str/bytes) or PathLike
|
||||
- an open binary file object
|
||||
If ``fp`` is a file object, it must be opened in **binary** mode and support
|
||||
``read``, ``seek``, and ``tell``. The function will seek to the start before
|
||||
reading and will attempt to restore the original position afterward.
|
||||
"""
|
||||
# duck typing to check if input is a file-like object
|
||||
if hasattr(fp, "read"):
|
||||
return _hash_file_obj(fp, chunk_size)
|
||||
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj(f, chunk_size)
|
||||
|
||||
|
||||
async def blake3_hash_async(
|
||||
fp: str | IO[bytes],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""Async wrapper for ``blake3_hash_sync``.
|
||||
Uses a worker thread so the event loop remains responsive.
|
||||
"""
|
||||
# If it is a path, open inside the worker thread to keep I/O off the loop.
|
||||
if hasattr(fp, "read"):
|
||||
return await asyncio.to_thread(blake3_hash, fp, chunk_size)
|
||||
|
||||
def _worker() -> str:
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj(f, chunk_size)
|
||||
|
||||
return await asyncio.to_thread(_worker)
|
||||
|
||||
|
||||
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
|
||||
"""
|
||||
Hash an already-open binary file object by streaming in chunks.
|
||||
- Seeks to the beginning before reading (if supported).
|
||||
- Restores the original position afterward (if tell/seek are supported).
|
||||
"""
|
||||
if chunk_size <= 0:
|
||||
chunk_size = DEFAULT_CHUNK
|
||||
|
||||
# in case file object is already open and not at the beginning, track so can be restored after hashing
|
||||
orig_pos = file_obj.tell()
|
||||
|
||||
try:
|
||||
# seek to the beginning before reading
|
||||
if orig_pos != 0:
|
||||
file_obj.seek(0)
|
||||
|
||||
h = blake3()
|
||||
while True:
|
||||
chunk = file_obj.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
finally:
|
||||
# restore original position in file object, if needed
|
||||
if orig_pos != 0:
|
||||
file_obj.seek(orig_pos)
|
||||
217
app/assets/helpers.py
Normal file
217
app/assets/helpers.py
Normal file
@ -0,0 +1,217 @@
|
||||
import contextlib
|
||||
import os
|
||||
from aiohttp import web
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Literal, Any
|
||||
|
||||
import folder_paths
|
||||
|
||||
|
||||
RootType = Literal["models", "input", "output"]
|
||||
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
|
||||
|
||||
def get_query_dict(request: web.Request) -> dict[str, Any]:
|
||||
"""
|
||||
Gets a dictionary of query parameters from the request.
|
||||
|
||||
'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic.
|
||||
"""
|
||||
query_dict = {
|
||||
key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key)
|
||||
for key in request.query.keys()
|
||||
}
|
||||
return query_dict
|
||||
|
||||
def list_tree(base_dir: str) -> list[str]:
|
||||
out: list[str] = []
|
||||
base_abs = os.path.abspath(base_dir)
|
||||
if not os.path.isdir(base_abs):
|
||||
return out
|
||||
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
|
||||
for name in filenames:
|
||||
out.append(os.path.abspath(os.path.join(dirpath, name)))
|
||||
return out
|
||||
|
||||
def prefixes_for_root(root: RootType) -> list[str]:
|
||||
if root == "models":
|
||||
bases: list[str] = []
|
||||
for _bucket, paths in get_comfy_models_folders():
|
||||
bases.extend(paths)
|
||||
return [os.path.abspath(p) for p in bases]
|
||||
if root == "input":
|
||||
return [os.path.abspath(folder_paths.get_input_directory())]
|
||||
if root == "output":
|
||||
return [os.path.abspath(folder_paths.get_output_directory())]
|
||||
return []
|
||||
|
||||
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
|
||||
"""Escapes %, _ and the escape char itself in a LIKE prefix.
|
||||
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
|
||||
"""
|
||||
s = s.replace(escape, escape + escape) # escape the escape char first
|
||||
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
|
||||
return s, escape
|
||||
|
||||
def fast_asset_file_check(
|
||||
*,
|
||||
mtime_db: int | None,
|
||||
size_db: int | None,
|
||||
stat_result: os.stat_result,
|
||||
) -> bool:
|
||||
if mtime_db is None:
|
||||
return False
|
||||
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
|
||||
if int(mtime_db) != int(actual_mtime_ns):
|
||||
return False
|
||||
sz = int(size_db or 0)
|
||||
if sz > 0:
|
||||
return int(stat_result.st_size) == sz
|
||||
return True
|
||||
|
||||
def utcnow() -> datetime:
|
||||
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
|
||||
return datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
|
||||
|
||||
We trust `folder_paths.folder_names_and_paths` and include a category if
|
||||
*any* of its base paths lies under the Comfy `models_dir`.
|
||||
"""
|
||||
targets: list[tuple[str, list[str]]] = []
|
||||
models_root = os.path.abspath(folder_paths.models_dir)
|
||||
for name, values in folder_paths.folder_names_and_paths.items():
|
||||
paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
|
||||
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
|
||||
targets.append((name, paths))
|
||||
return targets
|
||||
|
||||
def compute_relative_filename(file_path: str) -> str | None:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
using forward slashes, eg:
|
||||
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
|
||||
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
|
||||
|
||||
For non-model paths, returns None.
|
||||
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
|
||||
"""
|
||||
try:
|
||||
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
p = Path(rel_path)
|
||||
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
|
||||
if not parts:
|
||||
return None
|
||||
|
||||
if root_category == "models":
|
||||
# parts[0] is the category ("checkpoints", "vae", etc) – drop it
|
||||
inside = parts[1:] if len(parts) > 1 else [parts[0]]
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
|
||||
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
||||
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
||||
- 'output' if the file resides under `folder_paths.get_output_directory()`
|
||||
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
|
||||
|
||||
Returns:
|
||||
(root_category, relative_path_inside_that_root)
|
||||
For 'models', the relative path is prefixed with the category name:
|
||||
e.g. ('models', 'vae/test/sub/ae.safetensors')
|
||||
|
||||
Raises:
|
||||
ValueError: if the path does not belong to input, output, or configured model bases.
|
||||
"""
|
||||
fp_abs = os.path.abspath(file_path)
|
||||
|
||||
def _is_within(child: str, parent: str) -> bool:
|
||||
try:
|
||||
return os.path.commonpath([child, parent]) == parent
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _rel(child: str, parent: str) -> str:
|
||||
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
|
||||
|
||||
# 1) input
|
||||
input_base = os.path.abspath(folder_paths.get_input_directory())
|
||||
if _is_within(fp_abs, input_base):
|
||||
return "input", _rel(fp_abs, input_base)
|
||||
|
||||
# 2) output
|
||||
output_base = os.path.abspath(folder_paths.get_output_directory())
|
||||
if _is_within(fp_abs, output_base):
|
||||
return "output", _rel(fp_abs, output_base)
|
||||
|
||||
# 3) models (check deepest matching base to avoid ambiguity)
|
||||
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
|
||||
for bucket, bases in get_comfy_models_folders():
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
if not _is_within(fp_abs, base_abs):
|
||||
continue
|
||||
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
|
||||
if best is None or cand[0] > best[0]:
|
||||
best = cand
|
||||
|
||||
if best is not None:
|
||||
_, bucket, rel_inside = best
|
||||
combined = os.path.join(bucket, rel_inside)
|
||||
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
||||
|
||||
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
|
||||
|
||||
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
"""Return a tuple (name, tags) derived from a filesystem path.
|
||||
|
||||
Semantics:
|
||||
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
|
||||
- The returned `name` is the base filename with extension from the relative path.
|
||||
- The returned `tags` are:
|
||||
[root_category] + parent folders of the relative path (in order)
|
||||
For 'models', this means:
|
||||
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
|
||||
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
|
||||
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
|
||||
|
||||
Raises:
|
||||
ValueError: if the path does not belong to input, output, or configured model bases.
|
||||
"""
|
||||
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
|
||||
p = Path(some_path)
|
||||
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
|
||||
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
|
||||
|
||||
def normalize_tags(tags: list[str] | None) -> list[str]:
|
||||
"""
|
||||
Normalize a list of tags by:
|
||||
- Stripping whitespace and converting to lowercase.
|
||||
- Removing duplicates.
|
||||
"""
|
||||
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
|
||||
def collect_models_files() -> list[str]:
|
||||
out: list[str] = []
|
||||
for folder_name, bases in get_comfy_models_folders():
|
||||
rel_files = folder_paths.get_filename_list(folder_name) or []
|
||||
for rel_path in rel_files:
|
||||
abs_path = folder_paths.get_full_path(folder_name, rel_path)
|
||||
if not abs_path:
|
||||
continue
|
||||
abs_path = os.path.abspath(abs_path)
|
||||
allowed = False
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
with contextlib.suppress(Exception):
|
||||
if os.path.commonpath([abs_path, base_abs]) == base_abs:
|
||||
allowed = True
|
||||
break
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
|
||||
123
app/assets/manager.py
Normal file
123
app/assets/manager.py
Normal file
@ -0,0 +1,123 @@
|
||||
from typing import Sequence
|
||||
|
||||
from app.database.db import create_session
|
||||
from app.assets.api import schemas_out
|
||||
from app.assets.database.queries import (
|
||||
asset_exists_by_hash,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
list_asset_infos_page,
|
||||
list_tags_with_usage,
|
||||
)
|
||||
|
||||
|
||||
def _safe_sort_field(requested: str | None) -> str:
|
||||
if not requested:
|
||||
return "created_at"
|
||||
v = requested.lower()
|
||||
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
|
||||
return v
|
||||
return "created_at"
|
||||
|
||||
|
||||
def asset_exists(asset_hash: str) -> bool:
|
||||
with create_session() as session:
|
||||
return asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
def list_assets(
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetsList:
|
||||
sort = _safe_sort_field(sort)
|
||||
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
|
||||
|
||||
with create_session() as session:
|
||||
infos, tag_map, total = list_asset_infos_page(
|
||||
session,
|
||||
owner_id=owner_id,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
|
||||
summaries: list[schemas_out.AssetSummary] = []
|
||||
for info in infos:
|
||||
asset = info.asset
|
||||
tags = tag_map.get(info.id, [])
|
||||
summaries.append(
|
||||
schemas_out.AssetSummary(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
preview_url=f"/api/assets/{info.id}/content",
|
||||
created_at=info.created_at,
|
||||
updated_at=info.updated_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
)
|
||||
|
||||
return schemas_out.AssetsList(
|
||||
assets=summaries,
|
||||
total=total,
|
||||
has_more=(offset + len(summaries)) < total,
|
||||
)
|
||||
|
||||
def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
info, asset, tag_names = res
|
||||
preview_id = info.preview_id
|
||||
|
||||
return schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
def list_tags(
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
order: str = "count_desc",
|
||||
include_zero: bool = True,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsList:
|
||||
limit = max(1, min(1000, limit))
|
||||
offset = max(0, offset)
|
||||
|
||||
with create_session() as session:
|
||||
rows, total = list_tags_with_usage(
|
||||
session,
|
||||
prefix=prefix,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
include_zero=include_zero,
|
||||
order=order,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
|
||||
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)
|
||||
229
app/assets/scanner.py
Normal file
229
app/assets/scanner.py
Normal file
@ -0,0 +1,229 @@
|
||||
import contextlib
|
||||
import time
|
||||
import logging
|
||||
import os
|
||||
import sqlalchemy
|
||||
|
||||
import folder_paths
|
||||
from app.database.db import create_session, dependencies_available
|
||||
from app.assets.helpers import (
|
||||
collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path,
|
||||
list_tree,prefixes_for_root, escape_like_prefix,
|
||||
RootType
|
||||
)
|
||||
from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id
|
||||
from app.assets.database.bulk_ops import seed_from_paths_batch
|
||||
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
|
||||
|
||||
|
||||
def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None:
|
||||
"""
|
||||
Scan the given roots and seed the assets into the database.
|
||||
"""
|
||||
if not dependencies_available():
|
||||
if enable_logging:
|
||||
logging.warning("Database dependencies not available, skipping assets scan")
|
||||
return
|
||||
t_start = time.perf_counter()
|
||||
created = 0
|
||||
skipped_existing = 0
|
||||
paths: list[str] = []
|
||||
try:
|
||||
existing_paths: set[str] = set()
|
||||
for r in roots:
|
||||
try:
|
||||
survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
|
||||
if survivors:
|
||||
existing_paths.update(survivors)
|
||||
except Exception as e:
|
||||
logging.exception("fast DB scan failed for %s: %s", r, e)
|
||||
|
||||
if "models" in roots:
|
||||
paths.extend(collect_models_files())
|
||||
if "input" in roots:
|
||||
paths.extend(list_tree(folder_paths.get_input_directory()))
|
||||
if "output" in roots:
|
||||
paths.extend(list_tree(folder_paths.get_output_directory()))
|
||||
|
||||
specs: list[dict] = []
|
||||
tag_pool: set[str] = set()
|
||||
for p in paths:
|
||||
abs_p = os.path.abspath(p)
|
||||
if abs_p in existing_paths:
|
||||
skipped_existing += 1
|
||||
continue
|
||||
try:
|
||||
stat_p = os.stat(abs_p, follow_symlinks=False)
|
||||
except OSError:
|
||||
continue
|
||||
# skip empty files
|
||||
if not stat_p.st_size:
|
||||
continue
|
||||
name, tags = get_name_and_tags_from_asset_path(abs_p)
|
||||
specs.append(
|
||||
{
|
||||
"abs_path": abs_p,
|
||||
"size_bytes": stat_p.st_size,
|
||||
"mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)),
|
||||
"info_name": name,
|
||||
"tags": tags,
|
||||
"fname": compute_relative_filename(abs_p),
|
||||
}
|
||||
)
|
||||
for t in tags:
|
||||
tag_pool.add(t)
|
||||
# if no file specs, nothing to do
|
||||
if not specs:
|
||||
return
|
||||
with create_session() as sess:
|
||||
if tag_pool:
|
||||
ensure_tags_exist(sess, tag_pool, tag_type="user")
|
||||
|
||||
result = seed_from_paths_batch(sess, specs=specs, owner_id="")
|
||||
created += result["inserted_infos"]
|
||||
sess.commit()
|
||||
finally:
|
||||
if enable_logging:
|
||||
logging.info(
|
||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
|
||||
roots,
|
||||
time.perf_counter() - t_start,
|
||||
created,
|
||||
skipped_existing,
|
||||
len(paths),
|
||||
)
|
||||
|
||||
|
||||
def _fast_db_consistency_pass(
|
||||
root: RootType,
|
||||
*,
|
||||
collect_existing_paths: bool = False,
|
||||
update_missing_tags: bool = False,
|
||||
) -> set[str] | None:
|
||||
"""Fast DB+FS pass for a root:
|
||||
- Toggle needs_verify per state using fast check
|
||||
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
|
||||
- For seed assets with all states missing: delete Asset and its AssetInfos
|
||||
- Optionally add/remove 'missing' tags based on fast-ok in this root
|
||||
- Optionally return surviving absolute paths
|
||||
"""
|
||||
prefixes = prefixes_for_root(root)
|
||||
if not prefixes:
|
||||
return set() if collect_existing_paths else None
|
||||
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
|
||||
|
||||
with create_session() as sess:
|
||||
rows = (
|
||||
sess.execute(
|
||||
sqlalchemy.select(
|
||||
AssetCacheState.id,
|
||||
AssetCacheState.file_path,
|
||||
AssetCacheState.mtime_ns,
|
||||
AssetCacheState.needs_verify,
|
||||
AssetCacheState.asset_id,
|
||||
Asset.hash,
|
||||
Asset.size_bytes,
|
||||
)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(sqlalchemy.or_(*conds))
|
||||
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
|
||||
)
|
||||
).all()
|
||||
|
||||
by_asset: dict[str, dict] = {}
|
||||
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
|
||||
acc = by_asset.get(aid)
|
||||
if acc is None:
|
||||
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
|
||||
by_asset[aid] = acc
|
||||
|
||||
fast_ok = False
|
||||
try:
|
||||
exists = True
|
||||
fast_ok = fast_asset_file_check(
|
||||
mtime_db=mtime_db,
|
||||
size_db=acc["size_db"],
|
||||
stat_result=os.stat(fp, follow_symlinks=True),
|
||||
)
|
||||
except FileNotFoundError:
|
||||
exists = False
|
||||
except OSError:
|
||||
exists = False
|
||||
|
||||
acc["states"].append({
|
||||
"sid": sid,
|
||||
"fp": fp,
|
||||
"exists": exists,
|
||||
"fast_ok": fast_ok,
|
||||
"needs_verify": bool(needs_verify),
|
||||
})
|
||||
|
||||
to_set_verify: list[int] = []
|
||||
to_clear_verify: list[int] = []
|
||||
stale_state_ids: list[int] = []
|
||||
survivors: set[str] = set()
|
||||
|
||||
for aid, acc in by_asset.items():
|
||||
a_hash = acc["hash"]
|
||||
states = acc["states"]
|
||||
any_fast_ok = any(s["fast_ok"] for s in states)
|
||||
all_missing = all(not s["exists"] for s in states)
|
||||
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
continue
|
||||
if s["fast_ok"] and s["needs_verify"]:
|
||||
to_clear_verify.append(s["sid"])
|
||||
if not s["fast_ok"] and not s["needs_verify"]:
|
||||
to_set_verify.append(s["sid"])
|
||||
|
||||
if a_hash is None:
|
||||
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
|
||||
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid))
|
||||
asset = sess.get(Asset, aid)
|
||||
if asset:
|
||||
sess.delete(asset)
|
||||
else:
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
continue
|
||||
|
||||
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
stale_state_ids.append(s["sid"])
|
||||
if update_missing_tags:
|
||||
with contextlib.suppress(Exception):
|
||||
remove_missing_tag_for_asset_id(sess, asset_id=aid)
|
||||
elif update_missing_tags:
|
||||
with contextlib.suppress(Exception):
|
||||
add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
|
||||
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
|
||||
if stale_state_ids:
|
||||
sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
|
||||
if to_set_verify:
|
||||
sess.execute(
|
||||
sqlalchemy.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_set_verify))
|
||||
.values(needs_verify=True)
|
||||
)
|
||||
if to_clear_verify:
|
||||
sess.execute(
|
||||
sqlalchemy.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_clear_verify))
|
||||
.values(needs_verify=False)
|
||||
)
|
||||
sess.commit()
|
||||
return survivors if collect_existing_paths else None
|
||||
@ -1,14 +1,21 @@
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from typing import Any
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
Base = declarative_base()
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
def to_dict(obj):
|
||||
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
|
||||
fields = obj.__table__.columns.keys()
|
||||
return {
|
||||
field: (val.to_dict() if hasattr(val, "to_dict") else val)
|
||||
for field in fields
|
||||
if (val := getattr(obj, field))
|
||||
}
|
||||
out: dict[str, Any] = {}
|
||||
for field in fields:
|
||||
val = getattr(obj, field)
|
||||
if val is None and not include_none:
|
||||
continue
|
||||
if isinstance(val, datetime):
|
||||
out[field] = val.isoformat()
|
||||
else:
|
||||
out[field] = val
|
||||
return out
|
||||
|
||||
# TODO: Define models here
|
||||
|
||||
@ -10,7 +10,8 @@ import importlib
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import TypedDict, Optional
|
||||
from typing import Dict, TypedDict, Optional
|
||||
from aiohttp import web
|
||||
from importlib.metadata import version
|
||||
|
||||
import requests
|
||||
@ -42,6 +43,7 @@ def get_installed_frontend_version():
|
||||
frontend_version_str = version("comfyui-frontend-package")
|
||||
return frontend_version_str
|
||||
|
||||
|
||||
def get_required_frontend_version():
|
||||
"""Get the required frontend version from requirements.txt."""
|
||||
try:
|
||||
@ -63,6 +65,7 @@ def get_required_frontend_version():
|
||||
logging.error(f"Error reading requirements.txt: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def check_frontend_version():
|
||||
"""Check if the frontend version is up to date."""
|
||||
|
||||
@ -203,6 +206,37 @@ class FrontendManager:
|
||||
"""Get the required frontend package version."""
|
||||
return get_required_frontend_version()
|
||||
|
||||
@classmethod
|
||||
def get_installed_templates_version(cls) -> str:
|
||||
"""Get the currently installed workflow templates package version."""
|
||||
try:
|
||||
templates_version_str = version("comfyui-workflow-templates")
|
||||
return templates_version_str
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_required_templates_version(cls) -> str:
|
||||
"""Get the required workflow templates version from requirements.txt."""
|
||||
try:
|
||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith("comfyui-workflow-templates=="):
|
||||
version_str = line.split("==")[-1]
|
||||
if not is_valid_version(version_str):
|
||||
logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
|
||||
return None
|
||||
return version_str
|
||||
logging.error("comfyui-workflow-templates not found in requirements.txt")
|
||||
return None
|
||||
except FileNotFoundError:
|
||||
logging.error("requirements.txt not found. Cannot determine required templates version.")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error reading requirements.txt: {e}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def default_frontend_path(cls) -> str:
|
||||
try:
|
||||
@ -224,7 +258,54 @@ comfyui-frontend-package is not installed.
|
||||
sys.exit(-1)
|
||||
|
||||
@classmethod
|
||||
def templates_path(cls) -> str:
|
||||
def template_asset_map(cls) -> Optional[Dict[str, str]]:
|
||||
"""Return a mapping of template asset names to their absolute paths."""
|
||||
try:
|
||||
from comfyui_workflow_templates import (
|
||||
get_asset_path,
|
||||
iter_templates,
|
||||
)
|
||||
except ImportError:
|
||||
logging.error(
|
||||
f"""
|
||||
********** ERROR ***********
|
||||
|
||||
comfyui-workflow-templates is not installed.
|
||||
|
||||
{frontend_install_warning_message()}
|
||||
|
||||
********** ERROR ***********
|
||||
""".strip()
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
template_entries = list(iter_templates())
|
||||
except Exception as exc:
|
||||
logging.error(f"Failed to enumerate workflow templates: {exc}")
|
||||
return None
|
||||
|
||||
asset_map: Dict[str, str] = {}
|
||||
try:
|
||||
for entry in template_entries:
|
||||
for asset in entry.assets:
|
||||
asset_map[asset.filename] = get_asset_path(
|
||||
entry.template_id, asset.filename
|
||||
)
|
||||
except Exception as exc:
|
||||
logging.error(f"Failed to resolve template asset paths: {exc}")
|
||||
return None
|
||||
|
||||
if not asset_map:
|
||||
logging.error("No workflow template assets found. Did the packages install correctly?")
|
||||
return None
|
||||
|
||||
return asset_map
|
||||
|
||||
|
||||
@classmethod
|
||||
def legacy_templates_path(cls) -> Optional[str]:
|
||||
"""Return the legacy templates directory shipped inside the meta package."""
|
||||
try:
|
||||
import comfyui_workflow_templates
|
||||
|
||||
@ -243,6 +324,7 @@ comfyui-workflow-templates is not installed.
|
||||
********** ERROR ***********
|
||||
""".strip()
|
||||
)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def embedded_docs_path(cls) -> str:
|
||||
@ -359,3 +441,17 @@ comfyui-workflow-templates is not installed.
|
||||
logging.info("Falling back to the default frontend.")
|
||||
check_frontend_version()
|
||||
return cls.default_frontend_path()
|
||||
@classmethod
|
||||
def template_asset_handler(cls):
|
||||
assets = cls.template_asset_map()
|
||||
if not assets:
|
||||
return None
|
||||
|
||||
async def serve_template(request: web.Request) -> web.StreamResponse:
|
||||
rel_path = request.match_info.get("path", "")
|
||||
target = assets.get(rel_path)
|
||||
if target is None:
|
||||
raise web.HTTPNotFound()
|
||||
return web.FileResponse(target)
|
||||
|
||||
return serve_template
|
||||
|
||||
@ -44,7 +44,7 @@ class ModelFileManager:
|
||||
@routes.get("/experiment/models/{folder}")
|
||||
async def get_all_models(request):
|
||||
folder = request.match_info.get("folder", None)
|
||||
if not folder in folder_paths.folder_names_and_paths:
|
||||
if folder not in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
files = self.get_model_file_list(folder)
|
||||
return web.json_response(files)
|
||||
@ -55,7 +55,7 @@ class ModelFileManager:
|
||||
path_index = int(request.match_info.get("path_index", None))
|
||||
filename = request.match_info.get("filename", None)
|
||||
|
||||
if not folder_name in folder_paths.folder_names_and_paths:
|
||||
if folder_name not in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
|
||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||
|
||||
132
app/subgraph_manager.py
Normal file
132
app/subgraph_manager.py
Normal file
@ -0,0 +1,132 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict
|
||||
import os
|
||||
import folder_paths
|
||||
import glob
|
||||
from aiohttp import web
|
||||
import hashlib
|
||||
|
||||
|
||||
class Source:
|
||||
custom_node = "custom_node"
|
||||
templates = "templates"
|
||||
|
||||
class SubgraphEntry(TypedDict):
|
||||
source: str
|
||||
"""
|
||||
Source of subgraph - custom_nodes vs templates.
|
||||
"""
|
||||
path: str
|
||||
"""
|
||||
Relative path of the subgraph file.
|
||||
For custom nodes, will be the relative directory like <custom_node_dir>/subgraphs/<name>.json
|
||||
"""
|
||||
name: str
|
||||
"""
|
||||
Name of subgraph file.
|
||||
"""
|
||||
info: CustomNodeSubgraphEntryInfo
|
||||
"""
|
||||
Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
|
||||
"""
|
||||
data: str
|
||||
|
||||
class CustomNodeSubgraphEntryInfo(TypedDict):
|
||||
node_pack: str
|
||||
"""Node pack name."""
|
||||
|
||||
class SubgraphManager:
|
||||
def __init__(self):
|
||||
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
|
||||
self.cached_blueprint_subgraphs: dict[SubgraphEntry] | None = None
|
||||
|
||||
def _create_entry(self, file: str, source: str, node_pack: str) -> tuple[str, SubgraphEntry]:
|
||||
"""Create a subgraph entry from a file path. Expects normalized path (forward slashes)."""
|
||||
entry_id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
|
||||
entry: SubgraphEntry = {
|
||||
"source": source,
|
||||
"name": os.path.splitext(os.path.basename(file))[0],
|
||||
"path": file,
|
||||
"info": {"node_pack": node_pack},
|
||||
}
|
||||
return entry_id, entry
|
||||
|
||||
async def load_entry_data(self, entry: SubgraphEntry):
|
||||
with open(entry['path'], 'r') as f:
|
||||
entry['data'] = f.read()
|
||||
return entry
|
||||
|
||||
async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
|
||||
if entry is None:
|
||||
return None
|
||||
entry = entry.copy()
|
||||
entry.pop('path', None)
|
||||
if remove_data:
|
||||
entry.pop('data', None)
|
||||
return entry
|
||||
|
||||
async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
|
||||
entries = entries.copy()
|
||||
for key in list(entries.keys()):
|
||||
entries[key] = await self.sanitize_entry(entries[key], remove_data)
|
||||
return entries
|
||||
|
||||
async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
|
||||
"""Load subgraphs from custom nodes."""
|
||||
if not force_reload and self.cached_custom_node_subgraphs is not None:
|
||||
return self.cached_custom_node_subgraphs
|
||||
|
||||
subgraphs_dict: dict[SubgraphEntry] = {}
|
||||
for folder in folder_paths.get_folder_paths("custom_nodes"):
|
||||
pattern = os.path.join(folder, "*/subgraphs/*.json")
|
||||
for file in glob.glob(pattern):
|
||||
file = file.replace('\\', '/')
|
||||
node_pack = "custom_nodes." + file.split('/')[-3]
|
||||
entry_id, entry = self._create_entry(file, Source.custom_node, node_pack)
|
||||
subgraphs_dict[entry_id] = entry
|
||||
|
||||
self.cached_custom_node_subgraphs = subgraphs_dict
|
||||
return subgraphs_dict
|
||||
|
||||
async def get_blueprint_subgraphs(self, force_reload=False):
|
||||
"""Load subgraphs from the blueprints directory."""
|
||||
if not force_reload and self.cached_blueprint_subgraphs is not None:
|
||||
return self.cached_blueprint_subgraphs
|
||||
|
||||
subgraphs_dict: dict[SubgraphEntry] = {}
|
||||
blueprints_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'blueprints')
|
||||
|
||||
if os.path.exists(blueprints_dir):
|
||||
for file in glob.glob(os.path.join(blueprints_dir, "*.json")):
|
||||
file = file.replace('\\', '/')
|
||||
entry_id, entry = self._create_entry(file, Source.templates, "comfyui")
|
||||
subgraphs_dict[entry_id] = entry
|
||||
|
||||
self.cached_blueprint_subgraphs = subgraphs_dict
|
||||
return subgraphs_dict
|
||||
|
||||
async def get_all_subgraphs(self, loadedModules, force_reload=False):
|
||||
"""Get all subgraphs from all sources (custom nodes and blueprints)."""
|
||||
custom_node_subgraphs = await self.get_custom_node_subgraphs(loadedModules, force_reload)
|
||||
blueprint_subgraphs = await self.get_blueprint_subgraphs(force_reload)
|
||||
return {**custom_node_subgraphs, **blueprint_subgraphs}
|
||||
|
||||
async def get_subgraph(self, id: str, loadedModules):
|
||||
"""Get a specific subgraph by ID from any source."""
|
||||
entry = (await self.get_all_subgraphs(loadedModules)).get(id)
|
||||
if entry is not None and entry.get('data') is None:
|
||||
await self.load_entry_data(entry)
|
||||
return entry
|
||||
|
||||
def add_routes(self, routes, loadedModules):
|
||||
@routes.get("/global_subgraphs")
|
||||
async def get_global_subgraphs(request):
|
||||
subgraphs_dict = await self.get_all_subgraphs(loadedModules)
|
||||
return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
|
||||
|
||||
@routes.get("/global_subgraphs/{id}")
|
||||
async def get_global_subgraph(request):
|
||||
id = request.match_info.get("id", None)
|
||||
subgraph = await self.get_subgraph(id, loadedModules)
|
||||
return web.json_response(await self.sanitize_entry(subgraph))
|
||||
@ -59,6 +59,9 @@ class UserManager():
|
||||
user = "default"
|
||||
if args.multi_user and "comfy-user" in request.headers:
|
||||
user = request.headers["comfy-user"]
|
||||
# Block System Users (use same error message to prevent probing)
|
||||
if user.startswith(folder_paths.SYSTEM_USER_PREFIX):
|
||||
raise KeyError("Unknown user: " + user)
|
||||
|
||||
if user not in self.users:
|
||||
raise KeyError("Unknown user: " + user)
|
||||
@ -66,15 +69,16 @@ class UserManager():
|
||||
return user
|
||||
|
||||
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
|
||||
user_directory = folder_paths.get_user_directory()
|
||||
|
||||
if type == "userdata":
|
||||
root_dir = user_directory
|
||||
root_dir = folder_paths.get_user_directory()
|
||||
else:
|
||||
raise KeyError("Unknown filepath type:" + type)
|
||||
|
||||
user = self.get_request_user_id(request)
|
||||
path = user_root = os.path.abspath(os.path.join(root_dir, user))
|
||||
user_root = folder_paths.get_public_user_directory(user)
|
||||
if user_root is None:
|
||||
return None
|
||||
path = user_root
|
||||
|
||||
# prevent leaving /{type}
|
||||
if os.path.commonpath((root_dir, user_root)) != root_dir:
|
||||
@ -101,7 +105,11 @@ class UserManager():
|
||||
name = name.strip()
|
||||
if not name:
|
||||
raise ValueError("username not provided")
|
||||
if name.startswith(folder_paths.SYSTEM_USER_PREFIX):
|
||||
raise ValueError("System User prefix not allowed")
|
||||
user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
|
||||
if user_id.startswith(folder_paths.SYSTEM_USER_PREFIX):
|
||||
raise ValueError("System User prefix not allowed")
|
||||
user_id = user_id + "_" + str(uuid.uuid4())
|
||||
|
||||
self.users[user_id] = name
|
||||
@ -132,7 +140,10 @@ class UserManager():
|
||||
if username in self.users.values():
|
||||
return web.json_response({"error": "Duplicate username."}, status=400)
|
||||
|
||||
user_id = self.add_user(username)
|
||||
try:
|
||||
user_id = self.add_user(username)
|
||||
except ValueError as e:
|
||||
return web.json_response({"error": str(e)}, status=400)
|
||||
return web.json_response(user_id)
|
||||
|
||||
@routes.get("/userdata")
|
||||
@ -363,10 +374,17 @@ class UserManager():
|
||||
if not overwrite and os.path.exists(path):
|
||||
return web.Response(status=409, text="File already exists")
|
||||
|
||||
body = await request.read()
|
||||
try:
|
||||
body = await request.read()
|
||||
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
except OSError as e:
|
||||
logging.warning(f"Error saving file '{path}': {e}")
|
||||
return web.Response(
|
||||
status=400,
|
||||
reason="Invalid filename. Please avoid special characters like :\\/*?\"<>|"
|
||||
)
|
||||
|
||||
user_path = self.get_request_user_filepath(request, None)
|
||||
if full_info:
|
||||
@ -417,7 +435,7 @@ class UserManager():
|
||||
return source
|
||||
|
||||
dest = get_user_data_path(request, check_exists=False, param="dest")
|
||||
if not isinstance(source, str):
|
||||
if not isinstance(dest, str):
|
||||
return dest
|
||||
|
||||
overwrite = request.query.get("overwrite", 'true') != "false"
|
||||
|
||||
0
blueprints/put_blueprints_here
Normal file
0
blueprints/put_blueprints_here
Normal file
91
comfy/audio_encoders/audio_encoders.py
Normal file
91
comfy/audio_encoders/audio_encoders.py
Normal file
@ -0,0 +1,91 @@
|
||||
from .wav2vec2 import Wav2Vec2Model
|
||||
from .whisper import WhisperLargeV3
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
import comfy.utils
|
||||
import logging
|
||||
import torchaudio
|
||||
|
||||
|
||||
class AudioEncoderModel():
|
||||
def __init__(self, config):
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
model_type = config.pop("model_type")
|
||||
model_config = dict(config)
|
||||
model_config.update({
|
||||
"dtype": self.dtype,
|
||||
"device": offload_device,
|
||||
"operations": comfy.ops.manual_cast
|
||||
})
|
||||
|
||||
if model_type == "wav2vec2":
|
||||
self.model = Wav2Vec2Model(**model_config)
|
||||
elif model_type == "whisper3":
|
||||
self.model = WhisperLargeV3(**model_config)
|
||||
self.model.eval()
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.model_sample_rate = 16000
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=False)
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
def encode_audio(self, audio, sample_rate):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
|
||||
out, all_layers = self.model(audio.to(self.load_device))
|
||||
outputs = {}
|
||||
outputs["encoded_audio"] = out
|
||||
outputs["encoded_audio_all_layers"] = all_layers
|
||||
outputs["audio_samples"] = audio.shape[2]
|
||||
return outputs
|
||||
|
||||
|
||||
def load_audio_encoder_from_sd(sd, prefix=""):
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
|
||||
if "encoder.layer_norm.bias" in sd: #wav2vec2
|
||||
embed_dim = sd["encoder.layer_norm.bias"].shape[0]
|
||||
if embed_dim == 1024:# large
|
||||
config = {
|
||||
"model_type": "wav2vec2",
|
||||
"embed_dim": 1024,
|
||||
"num_heads": 16,
|
||||
"num_layers": 24,
|
||||
"conv_norm": True,
|
||||
"conv_bias": True,
|
||||
"do_normalize": True,
|
||||
"do_stable_layer_norm": True
|
||||
}
|
||||
elif embed_dim == 768: # base
|
||||
config = {
|
||||
"model_type": "wav2vec2",
|
||||
"embed_dim": 768,
|
||||
"num_heads": 12,
|
||||
"num_layers": 12,
|
||||
"conv_norm": False,
|
||||
"conv_bias": False,
|
||||
"do_normalize": False, # chinese-wav2vec2-base has this False
|
||||
"do_stable_layer_norm": False
|
||||
}
|
||||
else:
|
||||
raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
|
||||
elif "model.encoder.embed_positions.weight" in sd:
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
|
||||
config = {
|
||||
"model_type": "whisper3",
|
||||
}
|
||||
else:
|
||||
raise RuntimeError("ERROR: audio encoder not supported.")
|
||||
|
||||
audio_encoder = AudioEncoderModel(config)
|
||||
m, u = audio_encoder.load_sd(sd)
|
||||
if len(m) > 0:
|
||||
logging.warning("missing audio encoder: {}".format(m))
|
||||
if len(u) > 0:
|
||||
logging.warning("unexpected audio encoder: {}".format(u))
|
||||
|
||||
return audio_encoder
|
||||
252
comfy/audio_encoders/wav2vec2.py
Normal file
252
comfy/audio_encoders/wav2vec2.py
Normal file
@ -0,0 +1,252 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
|
||||
|
||||
class LayerNormConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
|
||||
self.layer_norm = operations.LayerNorm(out_channels, elementwise_affine=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))
|
||||
|
||||
class LayerGroupNormConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
|
||||
self.layer_norm = operations.GroupNorm(num_groups=out_channels, num_channels=out_channels, affine=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return torch.nn.functional.gelu(self.layer_norm(x))
|
||||
|
||||
class ConvNoNorm(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return torch.nn.functional.gelu(x)
|
||||
|
||||
|
||||
class ConvFeatureEncoder(nn.Module):
|
||||
def __init__(self, conv_dim, conv_bias=False, conv_norm=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
if conv_norm:
|
||||
self.conv_layers = nn.ModuleList([
|
||||
LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
])
|
||||
else:
|
||||
self.conv_layers = nn.ModuleList([
|
||||
LayerGroupNormConv(1, conv_dim, kernel_size=10, stride=5, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
x = x.unsqueeze(1)
|
||||
|
||||
for conv in self.conv_layers:
|
||||
x = conv(x)
|
||||
|
||||
return x.transpose(1, 2)
|
||||
|
||||
|
||||
class FeatureProjection(nn.Module):
|
||||
def __init__(self, conv_dim, embed_dim, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.layer_norm = operations.LayerNorm(conv_dim, eps=1e-05, device=device, dtype=dtype)
|
||||
self.projection = operations.Linear(conv_dim, embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.layer_norm(x)
|
||||
x = self.projection(x)
|
||||
return x
|
||||
|
||||
|
||||
class PositionalConvEmbedding(nn.Module):
|
||||
def __init__(self, embed_dim=768, kernel_size=128, groups=16):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
kernel_size=kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
groups=groups,
|
||||
)
|
||||
self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
|
||||
self.activation = nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
x = x.transpose(1, 2)
|
||||
x = self.conv(x)[:, :, :-1]
|
||||
x = self.activation(x)
|
||||
x = x.transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim=768,
|
||||
num_heads=12,
|
||||
num_layers=12,
|
||||
mlp_ratio=4.0,
|
||||
do_stable_layer_norm=True,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim)
|
||||
self.layers = nn.ModuleList([
|
||||
TransformerEncoderLayer(
|
||||
embed_dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
do_stable_layer_norm=do_stable_layer_norm,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
x = x + self.pos_conv_embed(x)
|
||||
all_x = ()
|
||||
if not self.do_stable_layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
for layer in self.layers:
|
||||
all_x += (x,)
|
||||
x = layer(x, mask)
|
||||
if self.do_stable_layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
all_x += (x,)
|
||||
return x, all_x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, embed_dim, num_heads, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = embed_dim // num_heads
|
||||
|
||||
self.k_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.v_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.q_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.out_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
assert (mask is None) # TODO?
|
||||
q = self.q_proj(x)
|
||||
k = self.k_proj(x)
|
||||
v = self.v_proj(x)
|
||||
|
||||
out = optimized_attention_masked(q, k, v, self.num_heads)
|
||||
return self.out_proj(out)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, embed_dim, mlp_ratio, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.intermediate_dense = operations.Linear(embed_dim, int(embed_dim * mlp_ratio), device=device, dtype=dtype)
|
||||
self.output_dense = operations.Linear(int(embed_dim * mlp_ratio), embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.intermediate_dense(x)
|
||||
x = torch.nn.functional.gelu(x)
|
||||
x = self.output_dense(x)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim=768,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.0,
|
||||
do_stable_layer_norm=True,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention = Attention(embed_dim, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
|
||||
self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
|
||||
self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
residual = x
|
||||
if self.do_stable_layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
x = self.attention(x, mask=mask)
|
||||
x = residual + x
|
||||
if not self.do_stable_layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
return self.final_layer_norm(x + self.feed_forward(x))
|
||||
else:
|
||||
return x + self.feed_forward(self.final_layer_norm(x))
|
||||
|
||||
|
||||
class Wav2Vec2Model(nn.Module):
|
||||
"""Complete Wav2Vec 2.0 model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim=1024,
|
||||
final_dim=256,
|
||||
num_heads=16,
|
||||
num_layers=24,
|
||||
conv_norm=True,
|
||||
conv_bias=True,
|
||||
do_normalize=True,
|
||||
do_stable_layer_norm=True,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
conv_dim = 512
|
||||
self.feature_extractor = ConvFeatureEncoder(conv_dim, conv_norm=conv_norm, conv_bias=conv_bias, device=device, dtype=dtype, operations=operations)
|
||||
self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))
|
||||
self.do_normalize = do_normalize
|
||||
|
||||
self.encoder = TransformerEncoder(
|
||||
embed_dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
num_layers=num_layers,
|
||||
do_stable_layer_norm=do_stable_layer_norm,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
|
||||
def forward(self, x, mask_time_indices=None, return_dict=False):
|
||||
x = torch.mean(x, dim=1)
|
||||
|
||||
if self.do_normalize:
|
||||
x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
|
||||
|
||||
features = self.feature_extractor(x)
|
||||
features = self.feature_projection(features)
|
||||
batch_size, seq_len, _ = features.shape
|
||||
|
||||
x, all_x = self.encoder(features)
|
||||
return x, all_x
|
||||
186
comfy/audio_encoders/whisper.py
Executable file
186
comfy/audio_encoders/whisper.py
Executable file
@ -0,0 +1,186 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from typing import Optional
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
import comfy.ops
|
||||
|
||||
class WhisperFeatureExtractor(nn.Module):
|
||||
def __init__(self, n_mels=128, device=None):
|
||||
super().__init__()
|
||||
self.sample_rate = 16000
|
||||
self.n_fft = 400
|
||||
self.hop_length = 160
|
||||
self.n_mels = n_mels
|
||||
self.chunk_length = 30
|
||||
self.n_samples = 480000
|
||||
|
||||
self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=self.sample_rate,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
norm="slaney",
|
||||
mel_scale="slaney",
|
||||
).to(device)
|
||||
|
||||
def __call__(self, audio):
|
||||
audio = torch.mean(audio, dim=1)
|
||||
batch_size = audio.shape[0]
|
||||
processed_audio = []
|
||||
|
||||
for i in range(batch_size):
|
||||
aud = audio[i]
|
||||
if aud.shape[0] > self.n_samples:
|
||||
aud = aud[:self.n_samples]
|
||||
elif aud.shape[0] < self.n_samples:
|
||||
aud = F.pad(aud, (0, self.n_samples - aud.shape[0]))
|
||||
processed_audio.append(aud)
|
||||
|
||||
audio = torch.stack(processed_audio)
|
||||
|
||||
mel_spec = self.mel_spectrogram(audio.to(self.mel_spectrogram.spectrogram.window.device))[:, :, :-1].to(audio.device)
|
||||
|
||||
log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||
log_mel_spec = torch.maximum(log_mel_spec, log_mel_spec.max() - 8.0)
|
||||
log_mel_spec = (log_mel_spec + 4.0) / 4.0
|
||||
|
||||
return log_mel_spec
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, d_model: int, n_heads: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
assert d_model % n_heads == 0
|
||||
|
||||
self.d_model = d_model
|
||||
self.n_heads = n_heads
|
||||
self.d_k = d_model // n_heads
|
||||
|
||||
self.q_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
||||
self.k_proj = operations.Linear(d_model, d_model, bias=False, dtype=dtype, device=device)
|
||||
self.v_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
||||
self.out_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, seq_len, _ = query.shape
|
||||
|
||||
q = self.q_proj(query)
|
||||
k = self.k_proj(key)
|
||||
v = self.v_proj(value)
|
||||
|
||||
attn_output = optimized_attention_masked(q, k, v, self.n_heads, mask)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
def __init__(self, d_model: int, n_heads: int, d_ff: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.self_attn = MultiHeadAttention(d_model, n_heads, dtype=dtype, device=device, operations=operations)
|
||||
self.self_attn_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
|
||||
|
||||
self.fc1 = operations.Linear(d_model, d_ff, dtype=dtype, device=device)
|
||||
self.fc2 = operations.Linear(d_ff, d_model, dtype=dtype, device=device)
|
||||
self.final_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
residual = x
|
||||
x = self.self_attn_layer_norm(x)
|
||||
x = self.self_attn(x, x, x, attention_mask)
|
||||
x = residual + x
|
||||
|
||||
residual = x
|
||||
x = self.final_layer_norm(x)
|
||||
x = self.fc1(x)
|
||||
x = F.gelu(x)
|
||||
x = self.fc2(x)
|
||||
x = residual + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class AudioEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_mels: int = 128,
|
||||
n_ctx: int = 1500,
|
||||
n_state: int = 1280,
|
||||
n_head: int = 20,
|
||||
n_layer: int = 32,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv1 = operations.Conv1d(n_mels, n_state, kernel_size=3, padding=1, dtype=dtype, device=device)
|
||||
self.conv2 = operations.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype, device=device)
|
||||
|
||||
self.embed_positions = operations.Embedding(n_ctx, n_state, dtype=dtype, device=device)
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
EncoderLayer(n_state, n_head, n_state * 4, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(n_layer)
|
||||
])
|
||||
|
||||
self.layer_norm = operations.LayerNorm(n_state, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = F.gelu(self.conv1(x))
|
||||
x = F.gelu(self.conv2(x))
|
||||
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
x = x + comfy.ops.cast_to_input(self.embed_positions.weight[:, :x.shape[1]], x)
|
||||
|
||||
all_x = ()
|
||||
for layer in self.layers:
|
||||
all_x += (x,)
|
||||
x = layer(x)
|
||||
|
||||
x = self.layer_norm(x)
|
||||
all_x += (x,)
|
||||
return x, all_x
|
||||
|
||||
|
||||
class WhisperLargeV3(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_mels: int = 128,
|
||||
n_audio_ctx: int = 1500,
|
||||
n_audio_state: int = 1280,
|
||||
n_audio_head: int = 20,
|
||||
n_audio_layer: int = 32,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.feature_extractor = WhisperFeatureExtractor(n_mels=n_mels, device=device)
|
||||
|
||||
self.encoder = AudioEncoder(
|
||||
n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def forward(self, audio):
|
||||
mel = self.feature_extractor(audio)
|
||||
x, all_x = self.encoder(mel)
|
||||
return x, all_x
|
||||
@ -413,7 +413,8 @@ class ControlNet(nn.Module):
|
||||
out_middle = []
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
if y is None:
|
||||
raise ValueError("y is None, did you try using a controlnet for SDXL on SD1?")
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x
|
||||
|
||||
@ -97,6 +97,13 @@ class LatentPreviewMethod(enum.Enum):
|
||||
Latent2RGB = "latent2rgb"
|
||||
TAESD = "taesd"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: str):
|
||||
for member in cls:
|
||||
if member.value == value:
|
||||
return member
|
||||
return None
|
||||
|
||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
||||
|
||||
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
||||
@ -105,6 +112,7 @@ cache_group = parser.add_mutually_exclusive_group()
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
@ -120,6 +128,12 @@ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force
|
||||
upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
|
||||
|
||||
|
||||
parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
|
||||
manager_group = parser.add_mutually_exclusive_group()
|
||||
manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
|
||||
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager")
|
||||
|
||||
|
||||
vram_group = parser.add_mutually_exclusive_group()
|
||||
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
|
||||
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
|
||||
@ -130,7 +144,8 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
|
||||
|
||||
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
|
||||
|
||||
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
|
||||
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||
|
||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||
|
||||
@ -143,8 +158,11 @@ class PerformanceFeature(enum.Enum):
|
||||
Fp16Accumulation = "fp16_accumulation"
|
||||
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
||||
CublasOps = "cublas_ops"
|
||||
AutoTune = "autotune"
|
||||
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||
|
||||
parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")
|
||||
|
||||
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
||||
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
|
||||
@ -156,13 +174,14 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
|
||||
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
|
||||
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
|
||||
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
|
||||
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
|
||||
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes. Also prevents the frontend from communicating with the internet.")
|
||||
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
|
||||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
|
||||
|
||||
|
||||
# The default built-in provider hosted under web/
|
||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||
|
||||
@ -212,6 +231,7 @@ database_default_path = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
|
||||
)
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -1,6 +1,59 @@
|
||||
import torch
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.ops
|
||||
import math
|
||||
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
if not (image.shape[2] == size and image.shape[3] == size):
|
||||
if crop:
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
||||
else:
|
||||
scale_size = (size, size)
|
||||
|
||||
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
||||
h = (image.shape[2] - size)//2
|
||||
w = (image.shape[3] - size)//2
|
||||
image = image[:,:,h:h+size,w:w+size]
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||
|
||||
def siglip2_flex_calc_resolution(oh, ow, patch_size, max_num_patches, eps=1e-5):
|
||||
def scale_dim(size, scale):
|
||||
scaled = math.ceil(size * scale / patch_size) * patch_size
|
||||
return max(patch_size, int(scaled))
|
||||
|
||||
# Binary search for optimal scale
|
||||
lo, hi = eps / 10, 100.0
|
||||
while hi - lo >= eps:
|
||||
mid = (lo + hi) / 2
|
||||
h, w = scale_dim(oh, mid), scale_dim(ow, mid)
|
||||
if (h // patch_size) * (w // patch_size) <= max_num_patches:
|
||||
lo = mid
|
||||
else:
|
||||
hi = mid
|
||||
|
||||
return scale_dim(oh, lo), scale_dim(ow, lo)
|
||||
|
||||
def siglip2_preprocess(image, size, patch_size, num_patches, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True):
|
||||
if size > 0:
|
||||
return clip_preprocess(image, size=size, mean=mean, std=std, crop=crop)
|
||||
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
|
||||
b, c, h, w = image.shape
|
||||
h, w = siglip2_flex_calc_resolution(h, w, patch_size, num_patches)
|
||||
|
||||
image = torch.nn.functional.interpolate(image, size=(h, w), mode="bilinear", antialias=True)
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3, 1, 1])) / std.view([3, 1, 1])
|
||||
|
||||
class CLIPAttention(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, dtype, device, operations):
|
||||
@ -61,8 +114,12 @@ class CLIPEncoder(torch.nn.Module):
|
||||
def forward(self, x, mask=None, intermediate_output=None):
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
|
||||
|
||||
all_intermediate = None
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output < 0:
|
||||
if intermediate_output == "all":
|
||||
all_intermediate = []
|
||||
intermediate_output = None
|
||||
elif intermediate_output < 0:
|
||||
intermediate_output = len(self.layers) + intermediate_output
|
||||
|
||||
intermediate = None
|
||||
@ -70,6 +127,12 @@ class CLIPEncoder(torch.nn.Module):
|
||||
x = l(x, mask, optimized_attention)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
if all_intermediate is not None:
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
|
||||
if all_intermediate is not None:
|
||||
intermediate = torch.cat(all_intermediate, dim=1)
|
||||
|
||||
return x, intermediate
|
||||
|
||||
class CLIPEmbeddings(torch.nn.Module):
|
||||
@ -97,7 +160,7 @@ class CLIPTextModel_(torch.nn.Module):
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
|
||||
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32, embeds_info=[]):
|
||||
if embeds is not None:
|
||||
x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
|
||||
else:
|
||||
@ -146,6 +209,27 @@ class CLIPTextModel(torch.nn.Module):
|
||||
out = self.text_projection(x[2])
|
||||
return (x[0], x[1], out, x[2])
|
||||
|
||||
def siglip2_pos_embed(embed_weight, embeds, orig_shape):
|
||||
embed_weight_len = round(embed_weight.shape[0] ** 0.5)
|
||||
embed_weight = comfy.ops.cast_to_input(embed_weight, embeds).movedim(1, 0).reshape(1, -1, embed_weight_len, embed_weight_len)
|
||||
embed_weight = torch.nn.functional.interpolate(embed_weight, size=orig_shape, mode="bilinear", align_corners=False, antialias=True)
|
||||
embed_weight = embed_weight.reshape(-1, embed_weight.shape[-2] * embed_weight.shape[-1]).movedim(0, 1)
|
||||
return embeds + embed_weight
|
||||
|
||||
class Siglip2Embeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", num_patches=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.patch_embedding = operations.Linear(num_channels * patch_size * patch_size, embed_dim, dtype=dtype, device=device)
|
||||
self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
|
||||
self.patch_size = patch_size
|
||||
|
||||
def forward(self, pixel_values):
|
||||
b, c, h, w = pixel_values.shape
|
||||
img = pixel_values.movedim(1, -1).reshape(b, h // self.patch_size, self.patch_size, w // self.patch_size, self.patch_size, c)
|
||||
img = img.permute(0, 1, 3, 2, 4, 5)
|
||||
img = img.reshape(b, img.shape[1] * img.shape[2], -1)
|
||||
img = self.patch_embedding(img)
|
||||
return siglip2_pos_embed(self.position_embedding.weight, img, (h // self.patch_size, w // self.patch_size))
|
||||
|
||||
class CLIPVisionEmbeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
|
||||
@ -189,8 +273,11 @@ class CLIPVision(torch.nn.Module):
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
model_type = config_dict["model_type"]
|
||||
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
|
||||
if model_type == "siglip_vision_model":
|
||||
if model_type in ["siglip2_vision_model"]:
|
||||
self.embeddings = Siglip2Embeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, num_patches=config_dict.get("num_patches", None), dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
|
||||
if model_type in ["siglip_vision_model", "siglip2_vision_model"]:
|
||||
self.pre_layrnorm = lambda a: a
|
||||
self.output_layernorm = True
|
||||
else:
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
|
||||
import os
|
||||
import torch
|
||||
import json
|
||||
import logging
|
||||
|
||||
@ -17,28 +16,12 @@ class Output:
|
||||
def __setitem__(self, key, item):
|
||||
setattr(self, key, item)
|
||||
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
if not (image.shape[2] == size and image.shape[3] == size):
|
||||
if crop:
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
||||
else:
|
||||
scale_size = (size, size)
|
||||
|
||||
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
||||
h = (image.shape[2] - size)//2
|
||||
w = (image.shape[3] - size)//2
|
||||
image = image[:,:,h:h+size,w:w+size]
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||
clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from breaking, TODO: remove eventually
|
||||
|
||||
IMAGE_ENCODERS = {
|
||||
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
|
||||
}
|
||||
|
||||
@ -50,7 +33,14 @@ class ClipVisionModel():
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
|
||||
self.model_type = config.get("model_type", "clip_vision_model")
|
||||
self.config = config.copy()
|
||||
model_class = IMAGE_ENCODERS.get(self.model_type)
|
||||
if self.model_type == "siglip_vision_model":
|
||||
self.return_all_hidden_states = True
|
||||
else:
|
||||
self.return_all_hidden_states = False
|
||||
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
@ -67,13 +57,23 @@ class ClipVisionModel():
|
||||
|
||||
def encode_image(self, image, crop=True):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
|
||||
if self.model_type == "siglip2_vision_model":
|
||||
pixel_values = comfy.clip_model.siglip2_preprocess(image.to(self.load_device), size=self.image_size, patch_size=self.config.get("patch_size", 16), num_patches=self.config.get("num_patches", 256), mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
else:
|
||||
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
|
||||
|
||||
outputs = Output()
|
||||
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
|
||||
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
|
||||
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
|
||||
outputs["image_sizes"] = [pixel_values.shape[1:]] * pixel_values.shape[0]
|
||||
if self.return_all_hidden_states:
|
||||
all_hs = out[1].to(comfy.model_management.intermediate_device())
|
||||
outputs["penultimate_hidden_states"] = all_hs[:, -2]
|
||||
outputs["all_hidden_states"] = all_hs
|
||||
else:
|
||||
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
|
||||
|
||||
outputs["mm_projected"] = out[3]
|
||||
return outputs
|
||||
|
||||
@ -113,10 +113,14 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
|
||||
embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
|
||||
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
|
||||
if embed_shape == 729:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif embed_shape == 1024:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
|
||||
patch_embedding_shape = sd["vision_model.embeddings.patch_embedding.weight"].shape
|
||||
if len(patch_embedding_shape) == 2:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip2_base_naflex.json")
|
||||
else:
|
||||
if embed_shape == 729:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
|
||||
elif embed_shape == 1024:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
|
||||
elif embed_shape == 577:
|
||||
if "multi_modal_projector.linear_1.bias" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
|
||||
@ -124,8 +128,12 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
||||
else:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||
elif "embeddings.patch_embeddings.projection.weight" in sd:
|
||||
|
||||
# Dinov2
|
||||
elif 'encoder.layer.39.layer_scale2.lambda1' in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
|
||||
elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
14
comfy/clip_vision_siglip2_base_naflex.json
Normal file
14
comfy/clip_vision_siglip2_base_naflex.json
Normal file
@ -0,0 +1,14 @@
|
||||
{
|
||||
"num_channels": 3,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1152,
|
||||
"image_size": -1,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip2_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"patch_size": 16,
|
||||
"num_patches": 256,
|
||||
"image_mean": [0.5, 0.5, 0.5],
|
||||
"image_std": [0.5, 0.5, 0.5]
|
||||
}
|
||||
@ -51,32 +51,43 @@ class ContextHandlerABC(ABC):
|
||||
|
||||
|
||||
class IndexListContextWindow(ContextWindowABC):
|
||||
def __init__(self, index_list: list[int], dim: int=0):
|
||||
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
|
||||
self.index_list = index_list
|
||||
self.context_length = len(index_list)
|
||||
self.dim = dim
|
||||
self.total_frames = total_frames
|
||||
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
|
||||
|
||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
|
||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
|
||||
if dim is None:
|
||||
dim = self.dim
|
||||
if dim == 0 and full.shape[dim] == 1:
|
||||
return full
|
||||
idx = [slice(None)] * dim + [self.index_list]
|
||||
return full[idx].to(device)
|
||||
idx = tuple([slice(None)] * dim + [self.index_list])
|
||||
window = full[idx]
|
||||
if retain_index_list:
|
||||
idx = tuple([slice(None)] * dim + [retain_index_list])
|
||||
window[idx] = full[idx]
|
||||
return window.to(device)
|
||||
|
||||
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
|
||||
if dim is None:
|
||||
dim = self.dim
|
||||
idx = [slice(None)] * dim + [self.index_list]
|
||||
idx = tuple([slice(None)] * dim + [self.index_list])
|
||||
full[idx] += to_add
|
||||
return full
|
||||
|
||||
def get_region_index(self, num_regions: int) -> int:
|
||||
region_idx = int(self.center_ratio * num_regions)
|
||||
return min(max(region_idx, 0), num_regions - 1)
|
||||
|
||||
|
||||
class IndexListCallbacks:
|
||||
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
||||
COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
|
||||
EXECUTE_START = "execute_start"
|
||||
EXECUTE_CLEANUP = "execute_cleanup"
|
||||
RESIZE_COND_ITEM = "resize_cond_item"
|
||||
|
||||
def init_callbacks(self):
|
||||
return {}
|
||||
@ -94,7 +105,8 @@ class ContextFuseMethod:
|
||||
|
||||
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
||||
class IndexListContextHandler(ContextHandlerABC):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
||||
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
|
||||
self.context_schedule = context_schedule
|
||||
self.fuse_method = fuse_method
|
||||
self.context_length = context_length
|
||||
@ -103,13 +115,18 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
self.closed_loop = closed_loop
|
||||
self.dim = dim
|
||||
self._step = 0
|
||||
self.freenoise = freenoise
|
||||
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
||||
self.split_conds_to_windows = split_conds_to_windows
|
||||
|
||||
self.callbacks = {}
|
||||
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
|
||||
if x_in.size(self.dim) > self.context_length:
|
||||
logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
|
||||
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
|
||||
if self.cond_retain_index_list:
|
||||
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -123,6 +140,11 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
return None
|
||||
# reuse or resize cond items to match context requirements
|
||||
resized_cond = []
|
||||
# if multiple conds, split based on primary region
|
||||
if self.split_conds_to_windows and len(cond_in) > 1:
|
||||
region = window.get_region_index(len(cond_in))
|
||||
logging.info(f"Splitting conds to windows; using region {region} for window {window.index_list[0]}-{window.index_list[-1]} with center ratio {window.center_ratio:.3f}")
|
||||
cond_in = [cond_in[region]]
|
||||
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
|
||||
for actual_cond in cond_in:
|
||||
resized_actual_cond = actual_cond.copy()
|
||||
@ -145,13 +167,38 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
new_cond_item = cond_item.copy()
|
||||
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
|
||||
for cond_key, cond_value in new_cond_item.items():
|
||||
# Allow callbacks to handle custom conditioning items
|
||||
handled = False
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(
|
||||
IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks
|
||||
):
|
||||
result = callback(cond_key, cond_value, window, x_in, device, new_cond_item)
|
||||
if result is not None:
|
||||
new_cond_item[cond_key] = result
|
||||
handled = True
|
||||
break
|
||||
if handled:
|
||||
continue
|
||||
if isinstance(cond_value, torch.Tensor):
|
||||
if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
|
||||
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
|
||||
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
|
||||
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
|
||||
# Handle audio_embed (temporal dim is 1)
|
||||
elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
audio_cond = cond_value.cond
|
||||
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
|
||||
# Handle vace_context (temporal dim is 3)
|
||||
elif cond_key == "vace_context" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
vace_cond = cond_value.cond
|
||||
if vace_cond.ndim >= 4 and vace_cond.size(3) == x_in.size(self.dim):
|
||||
sliced_vace = window.get_tensor(vace_cond, device, dim=3, retain_index_list=self.cond_retain_index_list)
|
||||
new_cond_item[cond_key] = cond_value._copy_with(sliced_vace)
|
||||
# if has cond that is a Tensor, check if needs to be subset
|
||||
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
|
||||
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
|
||||
(cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list))
|
||||
elif cond_key == "num_video_frames": # for SVD
|
||||
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
|
||||
new_cond_item[cond_key].cond = window.context_length
|
||||
@ -164,7 +211,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
return resized_cond
|
||||
|
||||
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
|
||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
|
||||
matches = torch.nonzero(mask)
|
||||
if torch.numel(matches) == 0:
|
||||
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
|
||||
@ -173,7 +220,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
||||
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
||||
context_windows = self.context_schedule.func(full_length, self, model_options)
|
||||
context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
|
||||
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
|
||||
return context_windows
|
||||
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
@ -250,8 +297,8 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
prev_weight = (bias_total / (bias_total + bias))
|
||||
new_weight = (bias / (bias_total + bias))
|
||||
# account for dims of tensors
|
||||
idx_window = [slice(None)] * self.dim + [idx]
|
||||
pos_window = [slice(None)] * self.dim + [pos]
|
||||
idx_window = tuple([slice(None)] * self.dim + [idx])
|
||||
pos_window = tuple([slice(None)] * self.dim + [pos])
|
||||
# apply new values
|
||||
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
|
||||
biases_final[i][idx] = bias_total + bias
|
||||
@ -287,6 +334,28 @@ def create_prepare_sampling_wrapper(model: ModelPatcher):
|
||||
)
|
||||
|
||||
|
||||
def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
|
||||
model_options = extra_args.get("model_options", None)
|
||||
if model_options is None:
|
||||
raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
||||
handler: IndexListContextHandler = model_options.get("context_handler", None)
|
||||
if handler is None:
|
||||
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
||||
if not handler.freenoise:
|
||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
|
||||
|
||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||
|
||||
|
||||
def create_sampler_sample_wrapper(model: ModelPatcher):
|
||||
model.add_wrapper_with_key(
|
||||
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
|
||||
"ContextWindows_sampler_sample",
|
||||
_sampler_sample_wrapper
|
||||
)
|
||||
|
||||
|
||||
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
|
||||
total_dims = len(x_in.shape)
|
||||
weights_tensor = torch.Tensor(weights).to(device=device)
|
||||
@ -538,3 +607,29 @@ def shift_window_to_end(window: list[int], num_frames: int):
|
||||
for i in range(len(window)):
|
||||
# 2) add end_delta to each val to slide windows to end
|
||||
window[i] = window[i] + end_delta
|
||||
|
||||
|
||||
# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
|
||||
def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
|
||||
logging.info("Context windows: Applying FreeNoise")
|
||||
generator = torch.Generator(device='cpu').manual_seed(seed)
|
||||
latent_video_length = noise.shape[dim]
|
||||
delta = context_length - context_overlap
|
||||
|
||||
for start_idx in range(0, latent_video_length - context_length, delta):
|
||||
place_idx = start_idx + context_length
|
||||
|
||||
actual_delta = min(delta, latent_video_length - place_idx)
|
||||
if actual_delta <= 0:
|
||||
break
|
||||
|
||||
list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx
|
||||
|
||||
source_slice = [slice(None)] * noise.ndim
|
||||
source_slice[dim] = list_idx
|
||||
target_slice = [slice(None)] * noise.ndim
|
||||
target_slice[dim] = slice(place_idx, place_idx + actual_delta)
|
||||
|
||||
noise[tuple(target_slice)] = noise[tuple(source_slice)]
|
||||
|
||||
return noise
|
||||
|
||||
@ -36,6 +36,7 @@ import comfy.ldm.cascade.controlnet
|
||||
import comfy.cldm.mmdit
|
||||
import comfy.ldm.hydit.controlnet
|
||||
import comfy.ldm.flux.controlnet
|
||||
import comfy.ldm.qwen_image.controlnet
|
||||
import comfy.cldm.dit_embedder
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
@ -236,11 +237,11 @@ class ControlNet(ControlBase):
|
||||
self.cond_hint = None
|
||||
compression_ratio = self.compression_ratio
|
||||
if self.vae is not None:
|
||||
compression_ratio *= self.vae.downscale_ratio
|
||||
compression_ratio *= self.vae.spacial_compression_encode()
|
||||
else:
|
||||
if self.latent_format is not None:
|
||||
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[-1] * compression_ratio, x_noisy.shape[-2] * compression_ratio, self.upscale_algorithm, "center")
|
||||
self.cond_hint = self.preprocess_image(self.cond_hint)
|
||||
if self.vae is not None:
|
||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||
@ -252,7 +253,10 @@ class ControlNet(ControlBase):
|
||||
to_concat = []
|
||||
for c in self.extra_concat_orig:
|
||||
c = c.to(self.cond_hint.device)
|
||||
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
|
||||
c = comfy.utils.common_upscale(c, self.cond_hint.shape[-1], self.cond_hint.shape[-2], self.upscale_algorithm, "center")
|
||||
if c.ndim < self.cond_hint.ndim:
|
||||
c = c.unsqueeze(2)
|
||||
c = comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[2], dim=2)
|
||||
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
|
||||
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
|
||||
|
||||
@ -306,11 +310,13 @@ class ControlLoraOps:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, input):
|
||||
weight, bias = comfy.ops.cast_bias_weight(self, input)
|
||||
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
|
||||
x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
|
||||
else:
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
x = torch.nn.functional.linear(input, weight, bias)
|
||||
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||
def __init__(
|
||||
@ -346,12 +352,13 @@ class ControlLoraOps:
|
||||
|
||||
|
||||
def forward(self, input):
|
||||
weight, bias = comfy.ops.cast_bias_weight(self, input)
|
||||
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
else:
|
||||
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
class ControlLora(ControlNet):
|
||||
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
|
||||
@ -582,6 +589,22 @@ def load_controlnet_flux_instantx(sd, model_options={}):
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
def load_controlnet_qwen_instantx(sd, model_options={}):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||
control_latent_channels = sd.get("controlnet_x_embedder.weight").shape[1]
|
||||
|
||||
extra_condition_channels = 0
|
||||
concat_mask = False
|
||||
if control_latent_channels == 68: #inpaint controlnet
|
||||
extra_condition_channels = control_latent_channels - 64
|
||||
concat_mask = True
|
||||
control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(extra_condition_channels=extra_condition_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
latent_format = comfy.latent_formats.Wan21()
|
||||
extra_conds = []
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
def convert_mistoline(sd):
|
||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||
|
||||
@ -655,8 +678,11 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
|
||||
else:
|
||||
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
|
||||
elif "transformer_blocks.0.img_mlp.net.0.proj.weight" in controlnet_data:
|
||||
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
|
||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||
|
||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||
|
||||
|
||||
144
comfy/float.py
144
comfy/float.py
@ -65,3 +65,147 @@ def stochastic_rounding(value, dtype, seed=0):
|
||||
return output
|
||||
|
||||
return value.to(dtype=dtype)
|
||||
|
||||
|
||||
# TODO: improve this?
|
||||
def stochastic_float_to_fp4_e2m1(x, generator):
|
||||
orig_shape = x.shape
|
||||
sign = torch.signbit(x).to(torch.uint8)
|
||||
|
||||
exp = torch.floor(torch.log2(x.abs()) + 1.0).clamp(0, 3)
|
||||
x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25
|
||||
|
||||
x = x.abs()
|
||||
exp = torch.floor(torch.log2(x) + 1.1925).clamp(0, 3)
|
||||
|
||||
mantissa = torch.where(
|
||||
exp > 0,
|
||||
(x / (2.0 ** (exp - 1)) - 1.0) * 2.0,
|
||||
(x * 2.0),
|
||||
out=x
|
||||
).round().to(torch.uint8)
|
||||
del x
|
||||
|
||||
exp = exp.to(torch.uint8)
|
||||
|
||||
fp4 = (sign << 3) | (exp << 1) | mantissa
|
||||
del sign, exp, mantissa
|
||||
|
||||
fp4_flat = fp4.view(-1)
|
||||
packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2]
|
||||
return packed.reshape(list(orig_shape)[:-1] + [-1])
|
||||
|
||||
|
||||
def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
|
||||
See:
|
||||
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
|
||||
|
||||
Args:
|
||||
input_matrix: Input tensor of shape (H, W)
|
||||
Returns:
|
||||
Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
|
||||
"""
|
||||
|
||||
def ceil_div(a, b):
|
||||
return (a + b - 1) // b
|
||||
|
||||
rows, cols = input_matrix.shape
|
||||
n_row_blocks = ceil_div(rows, 128)
|
||||
n_col_blocks = ceil_div(cols, 4)
|
||||
|
||||
# Calculate the padded shape
|
||||
padded_rows = n_row_blocks * 128
|
||||
padded_cols = n_col_blocks * 4
|
||||
|
||||
padded = input_matrix
|
||||
if (rows, cols) != (padded_rows, padded_cols):
|
||||
padded = torch.zeros(
|
||||
(padded_rows, padded_cols),
|
||||
device=input_matrix.device,
|
||||
dtype=input_matrix.dtype,
|
||||
)
|
||||
padded[:rows, :cols] = input_matrix
|
||||
|
||||
# Rearrange the blocks
|
||||
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
|
||||
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
|
||||
if flatten:
|
||||
return rearranged.flatten()
|
||||
|
||||
return rearranged.reshape(padded_rows, padded_cols)
|
||||
|
||||
|
||||
def stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator):
|
||||
F4_E2M1_MAX = 6.0
|
||||
F8_E4M3_MAX = 448.0
|
||||
|
||||
orig_shape = x.shape
|
||||
|
||||
block_size = 16
|
||||
|
||||
x = x.reshape(orig_shape[0], -1, block_size)
|
||||
scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
|
||||
x = x / (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1)
|
||||
|
||||
x = x.view(orig_shape).nan_to_num()
|
||||
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
|
||||
return data_lp, scaled_block_scales_fp8
|
||||
|
||||
|
||||
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
|
||||
def roundup(x: int, multiple: int) -> int:
|
||||
"""Round up x to the nearest multiple."""
|
||||
return ((x + multiple - 1) // multiple) * multiple
|
||||
|
||||
generator = torch.Generator(device=x.device)
|
||||
generator.manual_seed(seed)
|
||||
|
||||
# Handle padding
|
||||
if pad_16x:
|
||||
rows, cols = x.shape
|
||||
padded_rows = roundup(rows, 16)
|
||||
padded_cols = roundup(cols, 16)
|
||||
if padded_rows != rows or padded_cols != cols:
|
||||
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
|
||||
|
||||
x, blocked_scaled = stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator)
|
||||
return x, to_blocked(blocked_scaled, flatten=False)
|
||||
|
||||
|
||||
def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=0, block_size=4096 * 4096):
|
||||
def roundup(x: int, multiple: int) -> int:
|
||||
"""Round up x to the nearest multiple."""
|
||||
return ((x + multiple - 1) // multiple) * multiple
|
||||
|
||||
orig_shape = x.shape
|
||||
|
||||
# Handle padding
|
||||
if pad_16x:
|
||||
rows, cols = x.shape
|
||||
padded_rows = roundup(rows, 16)
|
||||
padded_cols = roundup(cols, 16)
|
||||
if padded_rows != rows or padded_cols != cols:
|
||||
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
|
||||
# Note: We update orig_shape because the output tensor logic below assumes x.shape matches
|
||||
# what we want to produce. If we pad here, we want the padded output.
|
||||
orig_shape = x.shape
|
||||
|
||||
orig_shape = list(orig_shape)
|
||||
|
||||
output_fp4 = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 2], dtype=torch.uint8, device=x.device)
|
||||
output_block = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 16], dtype=torch.float8_e4m3fn, device=x.device)
|
||||
|
||||
generator = torch.Generator(device=x.device)
|
||||
generator.manual_seed(seed)
|
||||
|
||||
num_slices = max(1, (x.numel() / block_size))
|
||||
slice_size = max(1, (round(x.shape[0] / num_slices)))
|
||||
|
||||
for i in range(0, x.shape[0], slice_size):
|
||||
fp4, block = stochastic_round_quantize_nvfp4_block(x[i: i + slice_size], per_tensor_scale, generator=generator)
|
||||
output_fp4[i:i + slice_size].copy_(fp4)
|
||||
output_block[i:i + slice_size].copy_(block)
|
||||
|
||||
return output_fp4, to_blocked(output_block, flatten=False)
|
||||
|
||||
@ -527,7 +527,8 @@ class HookKeyframeGroup:
|
||||
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
|
||||
break
|
||||
# if eval_c is outside the percent range, stop looking further
|
||||
else: break
|
||||
else:
|
||||
break
|
||||
# update steps current context is used
|
||||
self._current_used_steps += 1
|
||||
# update current timestep this was performed on
|
||||
|
||||
@ -31,6 +31,20 @@ class LayerScale(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
|
||||
|
||||
class Dinov2MLP(torch.nn.Module):
|
||||
def __init__(self, hidden_size: int, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
mlp_ratio = 4
|
||||
hidden_features = int(hidden_size * mlp_ratio)
|
||||
self.fc1 = operations.Linear(hidden_size, hidden_features, bias = True, device=device, dtype=dtype)
|
||||
self.fc2 = operations.Linear(hidden_features, hidden_size, bias = True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
hidden_state = self.fc1(hidden_state)
|
||||
hidden_state = torch.nn.functional.gelu(hidden_state)
|
||||
hidden_state = self.fc2(hidden_state)
|
||||
return hidden_state
|
||||
|
||||
class SwiGLUFFN(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
@ -50,12 +64,15 @@ class SwiGLUFFN(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2Block(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
|
||||
super().__init__()
|
||||
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
|
||||
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
|
||||
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
|
||||
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
|
||||
if use_swiglu_ffn:
|
||||
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
|
||||
else:
|
||||
self.mlp = Dinov2MLP(dim, dtype, device, operations)
|
||||
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
@ -66,9 +83,10 @@ class Dino2Block(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2Encoder(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
|
||||
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
|
||||
for _ in range(num_layers)])
|
||||
|
||||
def forward(self, x, intermediate_output=None):
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
@ -78,8 +96,8 @@ class Dino2Encoder(torch.nn.Module):
|
||||
intermediate_output = len(self.layer) + intermediate_output
|
||||
|
||||
intermediate = None
|
||||
for i, l in enumerate(self.layer):
|
||||
x = l(x, optimized_attention)
|
||||
for i, layer in enumerate(self.layer):
|
||||
x = layer(x, optimized_attention)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
return x, intermediate
|
||||
@ -128,9 +146,10 @@ class Dinov2Model(torch.nn.Module):
|
||||
dim = config_dict["hidden_size"]
|
||||
heads = config_dict["num_attention_heads"]
|
||||
layer_norm_eps = config_dict["layer_norm_eps"]
|
||||
use_swiglu_ffn = config_dict["use_swiglu_ffn"]
|
||||
|
||||
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
|
||||
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
|
||||
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
|
||||
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
|
||||
|
||||
22
comfy/image_encoders/dino2_large.json
Normal file
22
comfy/image_encoders/dino2_large.json
Normal file
@ -0,0 +1,22 @@
|
||||
{
|
||||
"hidden_size": 1024,
|
||||
"use_mask_token": true,
|
||||
"patch_size": 14,
|
||||
"image_size": 518,
|
||||
"num_channels": 3,
|
||||
"num_attention_heads": 16,
|
||||
"initializer_range": 0.02,
|
||||
"attention_probs_dropout_prob": 0.0,
|
||||
"hidden_dropout_prob": 0.0,
|
||||
"hidden_act": "gelu",
|
||||
"mlp_ratio": 4,
|
||||
"model_type": "dinov2",
|
||||
"num_hidden_layers": 24,
|
||||
"layer_norm_eps": 1e-6,
|
||||
"qkv_bias": true,
|
||||
"use_swiglu_ffn": false,
|
||||
"layerscale_value": 1.0,
|
||||
"drop_path_rate": 0.0,
|
||||
"image_mean": [0.485, 0.456, 0.406],
|
||||
"image_std": [0.229, 0.224, 0.225]
|
||||
}
|
||||
@ -74,6 +74,9 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
|
||||
|
||||
def default_noise_sampler(x, seed=None):
|
||||
if seed is not None:
|
||||
if x.device == torch.device("cpu"):
|
||||
seed += 1
|
||||
|
||||
generator = torch.Generator(device=x.device)
|
||||
generator.manual_seed(seed)
|
||||
else:
|
||||
@ -86,24 +89,24 @@ class BatchedBrownianTree:
|
||||
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
|
||||
|
||||
def __init__(self, x, t0, t1, seed=None, **kwargs):
|
||||
self.cpu_tree = True
|
||||
if "cpu" in kwargs:
|
||||
self.cpu_tree = kwargs.pop("cpu")
|
||||
self.cpu_tree = kwargs.pop("cpu", True)
|
||||
t0, t1, self.sign = self.sort(t0, t1)
|
||||
w0 = kwargs.get('w0', torch.zeros_like(x))
|
||||
w0 = kwargs.pop('w0', None)
|
||||
if w0 is None:
|
||||
w0 = torch.zeros_like(x)
|
||||
self.batched = False
|
||||
if seed is None:
|
||||
seed = torch.randint(0, 2 ** 63 - 1, []).item()
|
||||
self.batched = True
|
||||
try:
|
||||
assert len(seed) == x.shape[0]
|
||||
seed = (torch.randint(0, 2 ** 63 - 1, ()).item(),)
|
||||
elif isinstance(seed, (tuple, list)):
|
||||
if len(seed) != x.shape[0]:
|
||||
raise ValueError("Passing a list or tuple of seeds to BatchedBrownianTree requires a length matching the batch size.")
|
||||
self.batched = True
|
||||
w0 = w0[0]
|
||||
except TypeError:
|
||||
seed = [seed]
|
||||
self.batched = False
|
||||
if self.cpu_tree:
|
||||
self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
|
||||
else:
|
||||
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
|
||||
seed = (seed,)
|
||||
if self.cpu_tree:
|
||||
t0, w0, t1 = t0.detach().cpu(), w0.detach().cpu(), t1.detach().cpu()
|
||||
self.trees = tuple(torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed)
|
||||
|
||||
@staticmethod
|
||||
def sort(a, b):
|
||||
@ -111,11 +114,10 @@ class BatchedBrownianTree:
|
||||
|
||||
def __call__(self, t0, t1):
|
||||
t0, t1, sign = self.sort(t0, t1)
|
||||
device, dtype = t0.device, t0.dtype
|
||||
if self.cpu_tree:
|
||||
w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
|
||||
else:
|
||||
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
|
||||
|
||||
t0, t1 = t0.detach().cpu().float(), t1.detach().cpu().float()
|
||||
w = torch.stack([tree(t0, t1) for tree in self.trees]).to(device=device, dtype=dtype) * (self.sign * sign)
|
||||
return w if self.batched else w[0]
|
||||
|
||||
|
||||
@ -171,6 +173,16 @@ def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
|
||||
return sigmas
|
||||
|
||||
|
||||
def ei_h_phi_1(h: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute the result of h*phi_1(h) in exponential integrator methods."""
|
||||
return torch.expm1(h)
|
||||
|
||||
|
||||
def ei_h_phi_2(h: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute the result of h*phi_2(h) in exponential integrator methods."""
|
||||
return (torch.expm1(h) - h) / h
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||
@ -853,6 +865,11 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
|
||||
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""DPM-Solver++(3M) SDE."""
|
||||
@ -925,6 +942,16 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
if len(sigmas) <= 1:
|
||||
@ -1533,15 +1560,17 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
|
||||
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
||||
arXiv: https://arxiv.org/abs/2305.14267
|
||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||
"""
|
||||
if solver_type not in {"phi_1", "phi_2"}:
|
||||
raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
@ -1549,55 +1578,70 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
fac = 1 / (2 * r)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
lambda_s_1 = lambda_s + r * h
|
||||
fac = 1 / (2 * r)
|
||||
sigma_s_1 = sigma_fn(lambda_s_1)
|
||||
continue
|
||||
|
||||
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
lambda_s_1 = torch.lerp(lambda_s, lambda_t, r)
|
||||
sigma_s_1 = sigma_fn(lambda_s_1)
|
||||
|
||||
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
# 0 < r < 1
|
||||
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
|
||||
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
# Step 1
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
# Step 1
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r * h_eta) * denoised
|
||||
if inject_noise:
|
||||
sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
|
||||
x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
denoised_d = (1 - fac) * denoised + fac * denoised_2
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
# Step 2
|
||||
if solver_type == "phi_1":
|
||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||
elif solver_type == "phi_2":
|
||||
b2 = ei_h_phi_2(-h_eta) / r
|
||||
b1 = ei_h_phi_1(-h_eta) - b2
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
|
||||
|
||||
if inject_noise:
|
||||
segment_factor = (r - 1) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigmas[i + 1])
|
||||
x = x + sde_noise * sigmas[i + 1] * s_noise
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_exp_heun_2_x0(model, x, sigmas, extra_args=None, callback=None, disable=None, solver_type="phi_2"):
|
||||
"""Deterministic exponential Heun second order method in data prediction (x0) and logSNR time."""
|
||||
return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None, r=1.0, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_exp_heun_2_x0_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type="phi_2"):
|
||||
"""Stochastic exponential Heun second order method in data prediction (x0) and logSNR time."""
|
||||
return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=1.0, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
|
||||
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
|
||||
arXiv: https://arxiv.org/abs/2305.14267
|
||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
@ -1609,45 +1653,49 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
lambda_s_1 = lambda_s + r_1 * h
|
||||
lambda_s_2 = lambda_s + r_2 * h
|
||||
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
|
||||
continue
|
||||
|
||||
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
lambda_s_1 = torch.lerp(lambda_s, lambda_t, r_1)
|
||||
lambda_s_2 = torch.lerp(lambda_s, lambda_t, r_2)
|
||||
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
|
||||
|
||||
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
# 0 < r_1 < r_2 < 1
|
||||
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
|
||||
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
# Step 1
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
# Step 1
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r_1 * h_eta) * denoised
|
||||
if inject_noise:
|
||||
sde_noise = (-2 * r_1 * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
|
||||
x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
|
||||
if inject_noise:
|
||||
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
|
||||
# Step 2
|
||||
a3_2 = r_2 / r_1 * ei_h_phi_2(-r_2 * h_eta)
|
||||
a3_1 = ei_h_phi_1(-r_2 * h_eta) - a3_2
|
||||
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * (a3_1 * denoised + a3_2 * denoised_2)
|
||||
if inject_noise:
|
||||
segment_factor = (r_1 - r_2) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigma_s_2)
|
||||
x_3 = x_3 + sde_noise * sigma_s_2 * s_noise
|
||||
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
|
||||
|
||||
# Step 3
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
|
||||
# Step 3
|
||||
b3 = ei_h_phi_2(-h_eta) / r_2
|
||||
b1 = ei_h_phi_1(-h_eta) - b3
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b3 * denoised_3)
|
||||
if inject_noise:
|
||||
segment_factor = (r_2 - 1) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_2, sigmas[i + 1])
|
||||
x = x + sde_noise * sigmas[i + 1] * s_noise
|
||||
return x
|
||||
|
||||
|
||||
@ -1731,7 +1779,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
||||
# Predictor
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
x_pred = denoised
|
||||
else:
|
||||
tau_t = tau_func(sigmas[i + 1])
|
||||
curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
|
||||
@ -1752,7 +1800,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
||||
if tau_t > 0 and s_noise > 0:
|
||||
noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
|
||||
x_pred = x_pred + noise
|
||||
return x
|
||||
return x_pred
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@ -6,6 +6,7 @@ class LatentFormat:
|
||||
latent_dimensions = 2
|
||||
latent_rgb_factors = None
|
||||
latent_rgb_factors_bias = None
|
||||
latent_rgb_factors_reshape = None
|
||||
taesd_decoder_name = None
|
||||
|
||||
def process_in(self, latent):
|
||||
@ -178,6 +179,54 @@ class Flux(SD3):
|
||||
def process_out(self, latent):
|
||||
return (latent / self.scale_factor) + self.shift_factor
|
||||
|
||||
class Flux2(LatentFormat):
|
||||
latent_channels = 128
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors =[
|
||||
[0.0058, 0.0113, 0.0073],
|
||||
[0.0495, 0.0443, 0.0836],
|
||||
[-0.0099, 0.0096, 0.0644],
|
||||
[0.2144, 0.3009, 0.3652],
|
||||
[0.0166, -0.0039, -0.0054],
|
||||
[0.0157, 0.0103, -0.0160],
|
||||
[-0.0398, 0.0902, -0.0235],
|
||||
[-0.0052, 0.0095, 0.0109],
|
||||
[-0.3527, -0.2712, -0.1666],
|
||||
[-0.0301, -0.0356, -0.0180],
|
||||
[-0.0107, 0.0078, 0.0013],
|
||||
[0.0746, 0.0090, -0.0941],
|
||||
[0.0156, 0.0169, 0.0070],
|
||||
[-0.0034, -0.0040, -0.0114],
|
||||
[0.0032, 0.0181, 0.0080],
|
||||
[-0.0939, -0.0008, 0.0186],
|
||||
[0.0018, 0.0043, 0.0104],
|
||||
[0.0284, 0.0056, -0.0127],
|
||||
[-0.0024, -0.0022, -0.0030],
|
||||
[0.1207, -0.0026, 0.0065],
|
||||
[0.0128, 0.0101, 0.0142],
|
||||
[0.0137, -0.0072, -0.0007],
|
||||
[0.0095, 0.0092, -0.0059],
|
||||
[0.0000, -0.0077, -0.0049],
|
||||
[-0.0465, -0.0204, -0.0312],
|
||||
[0.0095, 0.0012, -0.0066],
|
||||
[0.0290, -0.0034, 0.0025],
|
||||
[0.0220, 0.0169, -0.0048],
|
||||
[-0.0332, -0.0457, -0.0468],
|
||||
[-0.0085, 0.0389, 0.0609],
|
||||
[-0.0076, 0.0003, -0.0043],
|
||||
[-0.0111, -0.0460, -0.0614],
|
||||
]
|
||||
|
||||
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
|
||||
self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent
|
||||
|
||||
def process_out(self, latent):
|
||||
return latent
|
||||
|
||||
class Mochi(LatentFormat):
|
||||
latent_channels = 12
|
||||
latent_dimensions = 3
|
||||
@ -358,6 +407,11 @@ class LTXV(LatentFormat):
|
||||
|
||||
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
|
||||
|
||||
class LTXAV(LTXV):
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = None
|
||||
self.latent_rgb_factors_bias = None
|
||||
|
||||
class HunyuanVideo(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
@ -382,6 +436,7 @@ class HunyuanVideo(LatentFormat):
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
|
||||
taesd_decoder_name = "taehv"
|
||||
|
||||
class Cosmos1CV8x8x8(LatentFormat):
|
||||
latent_channels = 16
|
||||
@ -445,7 +500,7 @@ class Wan21(LatentFormat):
|
||||
]).view(1, self.latent_channels, 1, 1, 1)
|
||||
|
||||
|
||||
self.taesd_decoder_name = None #TODO
|
||||
self.taesd_decoder_name = "lighttaew2_1"
|
||||
|
||||
def process_in(self, latent):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
@ -516,6 +571,7 @@ class Wan22(Wan21):
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
self.taesd_decoder_name = "lighttaew2_2"
|
||||
self.latents_mean = torch.tensor([
|
||||
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
|
||||
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
|
||||
@ -533,11 +589,155 @@ class Wan22(Wan21):
|
||||
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
|
||||
]).view(1, self.latent_channels, 1, 1, 1)
|
||||
|
||||
class HunyuanImage21(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 2
|
||||
scale_factor = 0.75289
|
||||
|
||||
latent_rgb_factors = [
|
||||
[-0.0154, -0.0397, -0.0521],
|
||||
[ 0.0005, 0.0093, 0.0006],
|
||||
[-0.0805, -0.0773, -0.0586],
|
||||
[-0.0494, -0.0487, -0.0498],
|
||||
[-0.0212, -0.0076, -0.0261],
|
||||
[-0.0179, -0.0417, -0.0505],
|
||||
[ 0.0158, 0.0310, 0.0239],
|
||||
[ 0.0409, 0.0516, 0.0201],
|
||||
[ 0.0350, 0.0553, 0.0036],
|
||||
[-0.0447, -0.0327, -0.0479],
|
||||
[-0.0038, -0.0221, -0.0365],
|
||||
[-0.0423, -0.0718, -0.0654],
|
||||
[ 0.0039, 0.0368, 0.0104],
|
||||
[ 0.0655, 0.0217, 0.0122],
|
||||
[ 0.0490, 0.1638, 0.2053],
|
||||
[ 0.0932, 0.0829, 0.0650],
|
||||
[-0.0186, -0.0209, -0.0135],
|
||||
[-0.0080, -0.0076, -0.0148],
|
||||
[-0.0284, -0.0201, 0.0011],
|
||||
[-0.0642, -0.0294, -0.0777],
|
||||
[-0.0035, 0.0076, -0.0140],
|
||||
[ 0.0519, 0.0731, 0.0887],
|
||||
[-0.0102, 0.0095, 0.0704],
|
||||
[ 0.0068, 0.0218, -0.0023],
|
||||
[-0.0726, -0.0486, -0.0519],
|
||||
[ 0.0260, 0.0295, 0.0263],
|
||||
[ 0.0250, 0.0333, 0.0341],
|
||||
[ 0.0168, -0.0120, -0.0174],
|
||||
[ 0.0226, 0.1037, 0.0114],
|
||||
[ 0.2577, 0.1906, 0.1604],
|
||||
[-0.0646, -0.0137, -0.0018],
|
||||
[-0.0112, 0.0309, 0.0358],
|
||||
[-0.0347, 0.0146, -0.0481],
|
||||
[ 0.0234, 0.0179, 0.0201],
|
||||
[ 0.0157, 0.0313, 0.0225],
|
||||
[ 0.0423, 0.0675, 0.0524],
|
||||
[-0.0031, 0.0027, -0.0255],
|
||||
[ 0.0447, 0.0555, 0.0330],
|
||||
[-0.0152, 0.0103, 0.0299],
|
||||
[-0.0755, -0.0489, -0.0635],
|
||||
[ 0.0853, 0.0788, 0.1017],
|
||||
[-0.0272, -0.0294, -0.0471],
|
||||
[ 0.0440, 0.0400, -0.0137],
|
||||
[ 0.0335, 0.0317, -0.0036],
|
||||
[-0.0344, -0.0621, -0.0984],
|
||||
[-0.0127, -0.0630, -0.0620],
|
||||
[-0.0648, 0.0360, 0.0924],
|
||||
[-0.0781, -0.0801, -0.0409],
|
||||
[ 0.0363, 0.0613, 0.0499],
|
||||
[ 0.0238, 0.0034, 0.0041],
|
||||
[-0.0135, 0.0258, 0.0310],
|
||||
[ 0.0614, 0.1086, 0.0589],
|
||||
[ 0.0428, 0.0350, 0.0205],
|
||||
[ 0.0153, 0.0173, -0.0018],
|
||||
[-0.0288, -0.0455, -0.0091],
|
||||
[ 0.0344, 0.0109, -0.0157],
|
||||
[-0.0205, -0.0247, -0.0187],
|
||||
[ 0.0487, 0.0126, 0.0064],
|
||||
[-0.0220, -0.0013, 0.0074],
|
||||
[-0.0203, -0.0094, -0.0048],
|
||||
[-0.0719, 0.0429, -0.0442],
|
||||
[ 0.1042, 0.0497, 0.0356],
|
||||
[-0.0659, -0.0578, -0.0280],
|
||||
[-0.0060, -0.0322, -0.0234]]
|
||||
|
||||
latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206]
|
||||
|
||||
class HunyuanImage21Refiner(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 3
|
||||
scale_factor = 1.03682
|
||||
|
||||
def process_in(self, latent):
|
||||
out = latent * self.scale_factor
|
||||
out = torch.cat((out[:, :, :1], out), dim=2)
|
||||
out = out.permute(0, 2, 1, 3, 4)
|
||||
b, f_times_2, c, h, w = out.shape
|
||||
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
|
||||
out = out.permute(0, 2, 1, 3, 4).contiguous()
|
||||
return out
|
||||
|
||||
def process_out(self, latent):
|
||||
z = latent / self.scale_factor
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
b, f, c, h, w = z.shape
|
||||
z = z.reshape(b, f, 2, c // 2, h, w)
|
||||
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
z = z[:, :, 1:]
|
||||
return z
|
||||
|
||||
class HunyuanVideo15(LatentFormat):
|
||||
latent_rgb_factors = [
|
||||
[ 0.0568, -0.0521, -0.0131],
|
||||
[ 0.0014, 0.0735, 0.0326],
|
||||
[ 0.0186, 0.0531, -0.0138],
|
||||
[-0.0031, 0.0051, 0.0288],
|
||||
[ 0.0110, 0.0556, 0.0432],
|
||||
[-0.0041, -0.0023, -0.0485],
|
||||
[ 0.0530, 0.0413, 0.0253],
|
||||
[ 0.0283, 0.0251, 0.0339],
|
||||
[ 0.0277, -0.0372, -0.0093],
|
||||
[ 0.0393, 0.0944, 0.1131],
|
||||
[ 0.0020, 0.0251, 0.0037],
|
||||
[-0.0017, 0.0012, 0.0234],
|
||||
[ 0.0468, 0.0436, 0.0203],
|
||||
[ 0.0354, 0.0439, -0.0233],
|
||||
[ 0.0090, 0.0123, 0.0346],
|
||||
[ 0.0382, 0.0029, 0.0217],
|
||||
[ 0.0261, -0.0300, 0.0030],
|
||||
[-0.0088, -0.0220, -0.0283],
|
||||
[-0.0272, -0.0121, -0.0363],
|
||||
[-0.0664, -0.0622, 0.0144],
|
||||
[ 0.0414, 0.0479, 0.0529],
|
||||
[ 0.0355, 0.0612, -0.0247],
|
||||
[ 0.0147, 0.0264, 0.0174],
|
||||
[ 0.0438, 0.0038, 0.0542],
|
||||
[ 0.0431, -0.0573, -0.0033],
|
||||
[-0.0162, -0.0211, -0.0406],
|
||||
[-0.0487, -0.0295, -0.0393],
|
||||
[ 0.0005, -0.0109, 0.0253],
|
||||
[ 0.0296, 0.0591, 0.0353],
|
||||
[ 0.0119, 0.0181, -0.0306],
|
||||
[-0.0085, -0.0362, 0.0229],
|
||||
[ 0.0005, -0.0106, 0.0242]
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
|
||||
latent_channels = 32
|
||||
latent_dimensions = 3
|
||||
scale_factor = 1.03682
|
||||
taesd_decoder_name = "lighttaehy1_5"
|
||||
|
||||
class Hunyuan3Dv2(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
scale_factor = 0.9990943042622529
|
||||
|
||||
class Hunyuan3Dv2_1(LatentFormat):
|
||||
scale_factor = 1.0039506158752403
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
class Hunyuan3Dv2mini(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
@ -546,3 +746,20 @@ class Hunyuan3Dv2mini(LatentFormat):
|
||||
class ACEAudio(LatentFormat):
|
||||
latent_channels = 8
|
||||
latent_dimensions = 2
|
||||
|
||||
class ChromaRadiance(LatentFormat):
|
||||
latent_channels = 3
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
# R G B
|
||||
[ 1.0, 0.0, 0.0 ],
|
||||
[ 0.0, 1.0, 0.0 ],
|
||||
[ 0.0, 0.0, 1.0 ]
|
||||
]
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent
|
||||
|
||||
def process_out(self, latent):
|
||||
return latent
|
||||
|
||||
@ -133,6 +133,7 @@ class Attention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options={},
|
||||
**cross_attention_kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.processor(
|
||||
@ -140,6 +141,7 @@ class Attention(nn.Module):
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
@ -366,6 +368,7 @@ class CustomerAttnProcessor2_0:
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
transformer_options={},
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@ -433,7 +436,7 @@ class CustomerAttnProcessor2_0:
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
hidden_states = optimized_attention(
|
||||
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
|
||||
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options,
|
||||
).to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
@ -697,6 +700,7 @@ class LinearTransformerBlock(nn.Module):
|
||||
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
temb: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
):
|
||||
|
||||
N = hidden_states.shape[0]
|
||||
@ -720,6 +724,7 @@ class LinearTransformerBlock(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
else:
|
||||
attn_output, _ = self.attn(
|
||||
@ -729,6 +734,7 @@ class LinearTransformerBlock(nn.Module):
|
||||
encoder_attention_mask=None,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=None,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
if self.use_adaln_single:
|
||||
@ -743,6 +749,7 @@ class LinearTransformerBlock(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
|
||||
@ -19,6 +19,7 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
|
||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||
from .attention import LinearTransformerBlock, t2i_modulate
|
||||
@ -313,6 +314,7 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
output_length: int = 0,
|
||||
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||
transformer_options={},
|
||||
):
|
||||
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
|
||||
temb = self.t_block(embedded_timestep)
|
||||
@ -338,12 +340,34 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
|
||||
temb=temb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
output = self.final_layer(hidden_states, embedded_timestep, output_length)
|
||||
return output
|
||||
|
||||
def forward(
|
||||
def forward(self,
|
||||
x,
|
||||
timestep,
|
||||
attention_mask=None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
text_attention_mask: Optional[torch.LongTensor] = None,
|
||||
speaker_embeds: Optional[torch.FloatTensor] = None,
|
||||
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||
lyric_mask: Optional[torch.LongTensor] = None,
|
||||
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||
lyrics_strength=1.0,
|
||||
**kwargs
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x, timestep, attention_mask, context, text_attention_mask, speaker_embeds, lyric_token_idx, lyric_mask, block_controlnet_hidden_states,
|
||||
controlnet_scale, lyrics_strength, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x,
|
||||
timestep,
|
||||
@ -371,6 +395,7 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
|
||||
output_length = hidden_states.shape[-1]
|
||||
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
output = self.decode(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
@ -380,6 +405,7 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
output_length=output_length,
|
||||
block_controlnet_hidden_states=block_controlnet_hidden_states,
|
||||
controlnet_scale=controlnet_scale,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@ -23,8 +23,6 @@ class MusicDCAE(torch.nn.Module):
|
||||
else:
|
||||
self.source_sample_rate = source_sample_rate
|
||||
|
||||
# self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.Normalize(0.5, 0.5),
|
||||
])
|
||||
@ -37,10 +35,6 @@ class MusicDCAE(torch.nn.Module):
|
||||
self.scale_factor = 0.1786
|
||||
self.shift_factor = -1.9091
|
||||
|
||||
def load_audio(self, audio_path):
|
||||
audio, sr = torchaudio.load(audio_path)
|
||||
return audio, sr
|
||||
|
||||
def forward_mel(self, audios):
|
||||
mels = []
|
||||
for i in range(len(audios)):
|
||||
@ -73,10 +67,8 @@ class MusicDCAE(torch.nn.Module):
|
||||
latent = self.dcae.encoder(mel.unsqueeze(0))
|
||||
latents.append(latent)
|
||||
latents = torch.cat(latents, dim=0)
|
||||
# latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
|
||||
latents = (latents - self.shift_factor) * self.scale_factor
|
||||
return latents
|
||||
# return latents, latent_lengths
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, latents, audio_lengths=None, sr=None):
|
||||
@ -91,9 +83,7 @@ class MusicDCAE(torch.nn.Module):
|
||||
wav = self.vocoder.decode(mels[0]).squeeze(1)
|
||||
|
||||
if sr is not None:
|
||||
# resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
|
||||
wav = torchaudio.functional.resample(wav, 44100, sr)
|
||||
# wav = resampler(wav)
|
||||
else:
|
||||
sr = 44100
|
||||
pred_wavs.append(wav)
|
||||
@ -101,7 +91,6 @@ class MusicDCAE(torch.nn.Module):
|
||||
if audio_lengths is not None:
|
||||
pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
|
||||
return torch.stack(pred_wavs)
|
||||
# return sr, pred_wavs
|
||||
|
||||
def forward(self, audios, audio_lengths=None, sr=None):
|
||||
latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
|
||||
|
||||
@ -298,7 +298,8 @@ class Attention(nn.Module):
|
||||
mask = None,
|
||||
context_mask = None,
|
||||
rotary_pos_emb = None,
|
||||
causal = None
|
||||
causal = None,
|
||||
transformer_options={},
|
||||
):
|
||||
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
|
||||
|
||||
@ -363,7 +364,7 @@ class Attention(nn.Module):
|
||||
heads_per_kv_head = h // kv_h
|
||||
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
|
||||
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True)
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
|
||||
out = self.to_out(out)
|
||||
|
||||
if mask is not None:
|
||||
@ -488,7 +489,8 @@ class TransformerBlock(nn.Module):
|
||||
global_cond=None,
|
||||
mask = None,
|
||||
context_mask = None,
|
||||
rotary_pos_emb = None
|
||||
rotary_pos_emb = None,
|
||||
transformer_options={}
|
||||
):
|
||||
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
|
||||
|
||||
@ -498,12 +500,12 @@ class TransformerBlock(nn.Module):
|
||||
residual = x
|
||||
x = self.pre_norm(x)
|
||||
x = x * (1 + scale_self) + shift_self
|
||||
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
|
||||
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
|
||||
x = x * torch.sigmoid(1 - gate_self)
|
||||
x = x + residual
|
||||
|
||||
if context is not None:
|
||||
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
|
||||
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
|
||||
|
||||
if self.conformer is not None:
|
||||
x = x + self.conformer(x)
|
||||
@ -517,10 +519,10 @@ class TransformerBlock(nn.Module):
|
||||
x = x + residual
|
||||
|
||||
else:
|
||||
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
|
||||
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
|
||||
|
||||
if context is not None:
|
||||
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
|
||||
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
|
||||
|
||||
if self.conformer is not None:
|
||||
x = x + self.conformer(x)
|
||||
@ -606,7 +608,8 @@ class ContinuousTransformer(nn.Module):
|
||||
return_info = False,
|
||||
**kwargs
|
||||
):
|
||||
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
batch, seq, device = *x.shape[:2], x.device
|
||||
context = kwargs["context"]
|
||||
|
||||
@ -632,7 +635,7 @@ class ContinuousTransformer(nn.Module):
|
||||
# Attention layers
|
||||
|
||||
if self.rotary_pos_emb is not None:
|
||||
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
|
||||
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=torch.float, device=x.device)
|
||||
else:
|
||||
rotary_pos_emb = None
|
||||
|
||||
@ -645,13 +648,13 @@ class ContinuousTransformer(nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
|
||||
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options)
|
||||
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||
|
||||
if return_info:
|
||||
|
||||
@ -9,6 +9,7 @@ import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.ops
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
@ -84,7 +85,7 @@ class SingleAttention(nn.Module):
|
||||
)
|
||||
|
||||
#@torch.compile()
|
||||
def forward(self, c):
|
||||
def forward(self, c, transformer_options={}):
|
||||
|
||||
bsz, seqlen1, _ = c.shape
|
||||
|
||||
@ -94,7 +95,7 @@ class SingleAttention(nn.Module):
|
||||
v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
|
||||
q, k = self.q_norm1(q), self.k_norm1(k)
|
||||
|
||||
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
|
||||
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
|
||||
c = self.w1o(output)
|
||||
return c
|
||||
|
||||
@ -143,7 +144,7 @@ class DoubleAttention(nn.Module):
|
||||
|
||||
|
||||
#@torch.compile()
|
||||
def forward(self, c, x):
|
||||
def forward(self, c, x, transformer_options={}):
|
||||
|
||||
bsz, seqlen1, _ = c.shape
|
||||
bsz, seqlen2, _ = x.shape
|
||||
@ -167,7 +168,7 @@ class DoubleAttention(nn.Module):
|
||||
torch.cat([cv, xv], dim=1),
|
||||
)
|
||||
|
||||
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
|
||||
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
c, x = output.split([seqlen1, seqlen2], dim=1)
|
||||
c = self.w1o(c)
|
||||
@ -206,7 +207,7 @@ class MMDiTBlock(nn.Module):
|
||||
self.is_last = is_last
|
||||
|
||||
#@torch.compile()
|
||||
def forward(self, c, x, global_cond, **kwargs):
|
||||
def forward(self, c, x, global_cond, transformer_options={}, **kwargs):
|
||||
|
||||
cres, xres = c, x
|
||||
|
||||
@ -224,7 +225,7 @@ class MMDiTBlock(nn.Module):
|
||||
x = modulate(self.normX1(x), xshift_msa, xscale_msa)
|
||||
|
||||
# attention
|
||||
c, x = self.attn(c, x)
|
||||
c, x = self.attn(c, x, transformer_options=transformer_options)
|
||||
|
||||
|
||||
c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
|
||||
@ -254,13 +255,13 @@ class DiTBlock(nn.Module):
|
||||
self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
#@torch.compile()
|
||||
def forward(self, cx, global_cond, **kwargs):
|
||||
def forward(self, cx, global_cond, transformer_options={}, **kwargs):
|
||||
cxres = cx
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
|
||||
global_cond
|
||||
).chunk(6, dim=1)
|
||||
cx = modulate(self.norm1(cx), shift_msa, scale_msa)
|
||||
cx = self.attn(cx)
|
||||
cx = self.attn(cx, transformer_options=transformer_options)
|
||||
cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
|
||||
mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
|
||||
cx = gate_mlp.unsqueeze(1) * mlpout
|
||||
@ -436,6 +437,13 @@ class MMDiT(nn.Module):
|
||||
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
|
||||
|
||||
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, transformer_options={}, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
# patchify x, add PE
|
||||
b, c, h, w = x.shape
|
||||
@ -465,13 +473,14 @@ class MMDiT(nn.Module):
|
||||
out = {}
|
||||
out["txt"], out["img"] = layer(args["txt"],
|
||||
args["img"],
|
||||
args["vec"])
|
||||
args["vec"],
|
||||
transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
c = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
c, x = layer(c, x, global_cond, **kwargs)
|
||||
c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs)
|
||||
|
||||
if len(self.single_layers) > 0:
|
||||
c_len = c.size(1)
|
||||
@ -480,13 +489,13 @@ class MMDiT(nn.Module):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = layer(args["img"], args["vec"])
|
||||
out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
cx = out["img"]
|
||||
else:
|
||||
cx = layer(cx, global_cond, **kwargs)
|
||||
cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs)
|
||||
|
||||
x = cx[:, c_len:]
|
||||
|
||||
|
||||
@ -32,12 +32,12 @@ class OptimizedAttention(nn.Module):
|
||||
|
||||
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, q, k, v):
|
||||
def forward(self, q, k, v, transformer_options={}):
|
||||
q = self.to_q(q)
|
||||
k = self.to_k(k)
|
||||
v = self.to_v(v)
|
||||
|
||||
out = optimized_attention(q, k, v, self.heads)
|
||||
out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)
|
||||
|
||||
return self.out_proj(out)
|
||||
|
||||
@ -47,13 +47,13 @@ class Attention2D(nn.Module):
|
||||
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
|
||||
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, kv, self_attn=False):
|
||||
def forward(self, x, kv, self_attn=False, transformer_options={}):
|
||||
orig_shape = x.shape
|
||||
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
|
||||
if self_attn:
|
||||
kv = torch.cat([x, kv], dim=1)
|
||||
# x = self.attn(x, kv, kv, need_weights=False)[0]
|
||||
x = self.attn(x, kv, kv)
|
||||
x = self.attn(x, kv, kv, transformer_options=transformer_options)
|
||||
x = x.permute(0, 2, 1).view(*orig_shape)
|
||||
return x
|
||||
|
||||
@ -114,9 +114,9 @@ class AttnBlock(nn.Module):
|
||||
operations.Linear(c_cond, c, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x, kv):
|
||||
def forward(self, x, kv, transformer_options={}):
|
||||
kv = self.kv_mapper(kv)
|
||||
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
|
||||
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@ -173,7 +173,7 @@ class StageB(nn.Module):
|
||||
clip = self.clip_norm(clip)
|
||||
return clip
|
||||
|
||||
def _down_encode(self, x, r_embed, clip):
|
||||
def _down_encode(self, x, r_embed, clip, transformer_options={}):
|
||||
level_outputs = []
|
||||
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
||||
for down_block, downscaler, repmap in block_group:
|
||||
@ -187,7 +187,7 @@ class StageB(nn.Module):
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
x = block(x, clip, transformer_options=transformer_options)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
@ -199,7 +199,7 @@ class StageB(nn.Module):
|
||||
level_outputs.insert(0, x)
|
||||
return level_outputs
|
||||
|
||||
def _up_decode(self, level_outputs, r_embed, clip):
|
||||
def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}):
|
||||
x = level_outputs[0]
|
||||
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||
@ -216,7 +216,7 @@ class StageB(nn.Module):
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
x = block(x, clip, transformer_options=transformer_options)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
@ -228,7 +228,7 @@ class StageB(nn.Module):
|
||||
x = upscaler(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
|
||||
def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs):
|
||||
if pixels is None:
|
||||
pixels = x.new_zeros(x.size(0), 3, 8, 8)
|
||||
|
||||
@ -245,8 +245,8 @@ class StageB(nn.Module):
|
||||
nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
|
||||
x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
|
||||
align_corners=True)
|
||||
level_outputs = self._down_encode(x, r_embed, clip)
|
||||
x = self._up_decode(level_outputs, r_embed, clip)
|
||||
level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options)
|
||||
x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options)
|
||||
return self.clf(x)
|
||||
|
||||
def update_weights_ema(self, src_model, beta=0.999):
|
||||
|
||||
@ -182,7 +182,7 @@ class StageC(nn.Module):
|
||||
clip = self.clip_norm(clip)
|
||||
return clip
|
||||
|
||||
def _down_encode(self, x, r_embed, clip, cnet=None):
|
||||
def _down_encode(self, x, r_embed, clip, cnet=None, transformer_options={}):
|
||||
level_outputs = []
|
||||
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
||||
for down_block, downscaler, repmap in block_group:
|
||||
@ -201,7 +201,7 @@ class StageC(nn.Module):
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
x = block(x, clip, transformer_options=transformer_options)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
@ -213,7 +213,7 @@ class StageC(nn.Module):
|
||||
level_outputs.insert(0, x)
|
||||
return level_outputs
|
||||
|
||||
def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
|
||||
def _up_decode(self, level_outputs, r_embed, clip, cnet=None, transformer_options={}):
|
||||
x = level_outputs[0]
|
||||
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||
@ -235,7 +235,7 @@ class StageC(nn.Module):
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
x = block(x, clip, transformer_options=transformer_options)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
@ -247,7 +247,7 @@ class StageC(nn.Module):
|
||||
x = upscaler(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
|
||||
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, transformer_options={}, **kwargs):
|
||||
# Process the conditioning embeddings
|
||||
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
|
||||
for c in self.t_conds:
|
||||
@ -262,8 +262,8 @@ class StageC(nn.Module):
|
||||
|
||||
# Model Blocks
|
||||
x = self.embedding(x)
|
||||
level_outputs = self._down_encode(x, r_embed, clip, cnet)
|
||||
x = self._up_decode(level_outputs, r_embed, clip, cnet)
|
||||
level_outputs = self._down_encode(x, r_embed, clip, cnet, transformer_options=transformer_options)
|
||||
x = self._up_decode(level_outputs, r_embed, clip, cnet, transformer_options=transformer_options)
|
||||
return self.clf(x)
|
||||
|
||||
def update_weights_ema(self, src_model, beta=0.999):
|
||||
|
||||
@ -1,15 +1,15 @@
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from comfy.ldm.flux.math import attention
|
||||
from comfy.ldm.flux.layers import (
|
||||
MLPEmbedder,
|
||||
RMSNorm,
|
||||
QKNorm,
|
||||
SelfAttention,
|
||||
ModulationOut,
|
||||
)
|
||||
|
||||
# TODO: remove this in a few months
|
||||
SingleStreamBlock = None
|
||||
DoubleStreamBlock = None
|
||||
|
||||
|
||||
class ChromaModulationOut(ModulationOut):
|
||||
@ -48,124 +48,6 @@ class Approximator(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2),
|
||||
pe=pe, mask=attn_mask)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
|
||||
img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
|
||||
|
||||
# calculate the txt bloks
|
||||
txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
|
||||
txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||
|
||||
return img, txt
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
"""
|
||||
A DiT block with parallel linear layers as described in
|
||||
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float = None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
head_dim = hidden_size // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
# qkv and mlp_in
|
||||
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
|
||||
# proj and mlp_out
|
||||
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
|
||||
mod = vec
|
||||
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x.addcmul_(mod.gate, output)
|
||||
if x.dtype == torch.float16:
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
@ -5,17 +5,18 @@ from dataclasses import dataclass
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from comfy.ldm.flux.layers import (
|
||||
EmbedND,
|
||||
timestep_embedding,
|
||||
DoubleStreamBlock,
|
||||
SingleStreamBlock,
|
||||
)
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
LastLayer,
|
||||
SingleStreamBlock,
|
||||
Approximator,
|
||||
ChromaModulationOut,
|
||||
)
|
||||
@ -39,7 +40,8 @@ class ChromaParams:
|
||||
out_dim: int
|
||||
hidden_dim: int
|
||||
n_layers: int
|
||||
|
||||
txt_ids_dims: list
|
||||
vec_in_dim: int
|
||||
|
||||
|
||||
|
||||
@ -89,6 +91,7 @@ class Chroma(nn.Module):
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
modulation=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
@ -97,7 +100,7 @@ class Chroma(nn.Module):
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=False, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
@ -150,8 +153,6 @@ class Chroma(nn.Module):
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
@ -179,7 +180,10 @@ class Chroma(nn.Module):
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if i not in self.skip_mmdit:
|
||||
double_mod = (
|
||||
self.get_modulations(mod_vectors, "double_img", idx=i),
|
||||
@ -192,14 +196,16 @@ class Chroma(nn.Module):
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": double_mod,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
@ -208,7 +214,8 @@ class Chroma(nn.Module):
|
||||
txt=txt,
|
||||
vec=double_mod,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask)
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@ -219,7 +226,10 @@ class Chroma(nn.Module):
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if i not in self.skip_dit:
|
||||
single_mod = self.get_modulations(mod_vectors, "single", idx=i)
|
||||
if ("single_block", i) in blocks_replace:
|
||||
@ -228,17 +238,19 @@ class Chroma(nn.Module):
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": single_mod,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
|
||||
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
@ -248,16 +260,27 @@ class Chroma(nn.Module):
|
||||
img[:, txt.shape[1] :, ...] += add
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
final_mod = self.get_modulations(mod_vectors, "final")
|
||||
img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||
if hasattr(self, "final_layer"):
|
||||
final_mod = self.get_modulations(mod_vectors, "final")
|
||||
img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, guidance, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)
|
||||
|
||||
if img.ndim != 3 or context.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
h_len = ((h + (self.patch_size // 2)) // self.patch_size)
|
||||
w_len = ((w + (self.patch_size // 2)) // self.patch_size)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
|
||||
206
comfy/ldm/chroma_radiance/layers.py
Normal file
206
comfy/ldm/chroma_radiance/layers.py
Normal file
@ -0,0 +1,206 @@
|
||||
# Adapted from https://github.com/lodestone-rock/flow
|
||||
from functools import lru_cache
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from comfy.ldm.flux.layers import RMSNorm
|
||||
|
||||
|
||||
class NerfEmbedder(nn.Module):
|
||||
"""
|
||||
An embedder module that combines input features with a 2D positional
|
||||
encoding that mimics the Discrete Cosine Transform (DCT).
|
||||
|
||||
This module takes an input tensor of shape (B, P^2, C), where P is the
|
||||
patch size, and enriches it with positional information before projecting
|
||||
it to a new hidden size.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
hidden_size_input: int,
|
||||
max_freqs: int,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
"""
|
||||
Initializes the NerfEmbedder.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of channels in the input tensor.
|
||||
hidden_size_input (int): The desired dimension of the output embedding.
|
||||
max_freqs (int): The number of frequency components to use for both
|
||||
the x and y dimensions of the positional encoding.
|
||||
The total number of positional features will be max_freqs^2.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.max_freqs = max_freqs
|
||||
self.hidden_size_input = hidden_size_input
|
||||
|
||||
# A linear layer to project the concatenated input features and
|
||||
# positional encodings to the final output dimension.
|
||||
self.embedder = nn.Sequential(
|
||||
operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
@lru_cache(maxsize=4)
|
||||
def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
"""
|
||||
Generates and caches 2D DCT-like positional embeddings for a given patch size.
|
||||
|
||||
The LRU cache is a performance optimization that avoids recomputing the
|
||||
same positional grid on every forward pass.
|
||||
|
||||
Args:
|
||||
patch_size (int): The side length of the square input patch.
|
||||
device: The torch device to create the tensors on.
|
||||
dtype: The torch dtype for the tensors.
|
||||
|
||||
Returns:
|
||||
A tensor of shape (1, patch_size^2, max_freqs^2) containing the
|
||||
positional embeddings.
|
||||
"""
|
||||
# Create normalized 1D coordinate grids from 0 to 1.
|
||||
pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
|
||||
pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
|
||||
|
||||
# Create a 2D meshgrid of coordinates.
|
||||
pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
|
||||
|
||||
# Reshape positions to be broadcastable with frequencies.
|
||||
# Shape becomes (patch_size^2, 1, 1).
|
||||
pos_x = pos_x.reshape(-1, 1, 1)
|
||||
pos_y = pos_y.reshape(-1, 1, 1)
|
||||
|
||||
# Create a 1D tensor of frequency values from 0 to max_freqs-1.
|
||||
freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device)
|
||||
|
||||
# Reshape frequencies to be broadcastable for creating 2D basis functions.
|
||||
# freqs_x shape: (1, max_freqs, 1)
|
||||
# freqs_y shape: (1, 1, max_freqs)
|
||||
freqs_x = freqs[None, :, None]
|
||||
freqs_y = freqs[None, None, :]
|
||||
|
||||
# A custom weighting coefficient, not part of standard DCT.
|
||||
# This seems to down-weight the contribution of higher-frequency interactions.
|
||||
coeffs = (1 + freqs_x * freqs_y) ** -1
|
||||
|
||||
# Calculate the 1D cosine basis functions for x and y coordinates.
|
||||
# This is the core of the DCT formulation.
|
||||
dct_x = torch.cos(pos_x * freqs_x * torch.pi)
|
||||
dct_y = torch.cos(pos_y * freqs_y * torch.pi)
|
||||
|
||||
# Combine the 1D basis functions to create 2D basis functions by element-wise
|
||||
# multiplication, and apply the custom coefficients. Broadcasting handles the
|
||||
# combination of all (pos_x, freqs_x) with all (pos_y, freqs_y).
|
||||
# The result is flattened into a feature vector for each position.
|
||||
dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2)
|
||||
|
||||
return dct
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for the embedder.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): The input tensor of shape (B, P^2, C).
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor of shape (B, P^2, hidden_size_input).
|
||||
"""
|
||||
# Get the batch size, number of pixels, and number of channels.
|
||||
B, P2, C = inputs.shape
|
||||
|
||||
# Infer the patch side length from the number of pixels (P^2).
|
||||
patch_size = int(P2 ** 0.5)
|
||||
|
||||
input_dtype = inputs.dtype
|
||||
inputs = inputs.to(dtype=self.dtype)
|
||||
|
||||
# Fetch the pre-computed or cached positional embeddings.
|
||||
dct = self.fetch_pos(patch_size, inputs.device, self.dtype)
|
||||
|
||||
# Repeat the positional embeddings for each item in the batch.
|
||||
dct = dct.repeat(B, 1, 1)
|
||||
|
||||
# Concatenate the original input features with the positional embeddings
|
||||
# along the feature dimension.
|
||||
inputs = torch.cat((inputs, dct), dim=-1)
|
||||
|
||||
# Project the combined tensor to the target hidden size.
|
||||
return self.embedder(inputs).to(dtype=input_dtype)
|
||||
|
||||
|
||||
class NerfGLUBlock(nn.Module):
|
||||
"""
|
||||
A NerfBlock using a Gated Linear Unit (GLU) like MLP.
|
||||
"""
|
||||
def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
# The total number of parameters for the MLP is increased to accommodate
|
||||
# the gate, value, and output projection matrices.
|
||||
# We now need to generate parameters for 3 matrices.
|
||||
total_params = 3 * hidden_size_x**2 * mlp_ratio
|
||||
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
|
||||
self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
|
||||
def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_x, hidden_size_x = x.shape
|
||||
mlp_params = self.param_generator(s)
|
||||
|
||||
# Split the generated parameters into three parts for the gate, value, and output projection.
|
||||
fc1_gate_params, fc1_value_params, fc2_params = mlp_params.chunk(3, dim=-1)
|
||||
|
||||
# Reshape the parameters into matrices for batch matrix multiplication.
|
||||
fc1_gate = fc1_gate_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio)
|
||||
fc1_value = fc1_value_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio)
|
||||
fc2 = fc2_params.view(batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x)
|
||||
|
||||
# Normalize the generated weight matrices as in the original implementation.
|
||||
fc1_gate = torch.nn.functional.normalize(fc1_gate, dim=-2)
|
||||
fc1_value = torch.nn.functional.normalize(fc1_value, dim=-2)
|
||||
fc2 = torch.nn.functional.normalize(fc2, dim=-2)
|
||||
|
||||
res_x = x
|
||||
x = self.norm(x)
|
||||
|
||||
# Apply the final output projection.
|
||||
x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2)
|
||||
|
||||
return x + res_x
|
||||
|
||||
|
||||
class NerfFinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
|
||||
# So we temporarily move the channel dimension to the end for the norm operation.
|
||||
return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1)
|
||||
|
||||
|
||||
class NerfFinalLayerConv(nn.Module):
|
||||
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.conv = operations.Conv2d(
|
||||
in_channels=hidden_size,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
|
||||
# So we temporarily move the channel dimension to the end for the norm operation.
|
||||
return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1))
|
||||
335
comfy/ldm/chroma_radiance/model.py
Normal file
335
comfy/ldm/chroma_radiance/model.py
Normal file
@ -0,0 +1,335 @@
|
||||
# Credits:
|
||||
# Original Flux code can be found on: https://github.com/black-forest-labs/flux
|
||||
# Chroma Radiance adaption referenced from https://github.com/lodestone-rock/flow
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import repeat
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from comfy.ldm.flux.layers import EmbedND, DoubleStreamBlock, SingleStreamBlock
|
||||
|
||||
from comfy.ldm.chroma.model import Chroma, ChromaParams
|
||||
from comfy.ldm.chroma.layers import (
|
||||
Approximator,
|
||||
)
|
||||
from .layers import (
|
||||
NerfEmbedder,
|
||||
NerfGLUBlock,
|
||||
NerfFinalLayer,
|
||||
NerfFinalLayerConv,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChromaRadianceParams(ChromaParams):
|
||||
patch_size: int
|
||||
nerf_hidden_size: int
|
||||
nerf_mlp_ratio: int
|
||||
nerf_depth: int
|
||||
nerf_max_freqs: int
|
||||
# Setting nerf_tile_size to 0 disables tiling.
|
||||
nerf_tile_size: int
|
||||
# Currently one of linear (legacy) or conv.
|
||||
nerf_final_head_type: str
|
||||
# None means use the same dtype as the model.
|
||||
nerf_embedder_dtype: Optional[torch.dtype]
|
||||
use_x0: bool
|
||||
|
||||
class ChromaRadiance(Chroma):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
if operations is None:
|
||||
raise RuntimeError("Attempt to create ChromaRadiance object without setting operations")
|
||||
nn.Module.__init__(self)
|
||||
self.dtype = dtype
|
||||
params = ChromaRadianceParams(**kwargs)
|
||||
self.params = params
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = params.out_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
)
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.in_dim = params.in_dim
|
||||
self.out_dim = params.out_dim
|
||||
self.hidden_dim = params.hidden_dim
|
||||
self.n_layers = params.n_layers
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in_patch = operations.Conv2d(
|
||||
params.in_channels,
|
||||
params.hidden_size,
|
||||
kernel_size=params.patch_size,
|
||||
stride=params.patch_size,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
|
||||
# set as nn identity for now, will overwrite it later.
|
||||
self.distilled_guidance_layer = Approximator(
|
||||
in_dim=self.in_dim,
|
||||
hidden_dim=self.hidden_dim,
|
||||
out_dim=self.out_dim,
|
||||
n_layers=self.n_layers,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
modulation=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
modulation=False,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
# pixel channel concat with DCT
|
||||
self.nerf_image_embedder = NerfEmbedder(
|
||||
in_channels=params.in_channels,
|
||||
hidden_size_input=params.nerf_hidden_size,
|
||||
max_freqs=params.nerf_max_freqs,
|
||||
dtype=params.nerf_embedder_dtype or dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.nerf_blocks = nn.ModuleList([
|
||||
NerfGLUBlock(
|
||||
hidden_size_s=params.hidden_size,
|
||||
hidden_size_x=params.nerf_hidden_size,
|
||||
mlp_ratio=params.nerf_mlp_ratio,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
) for _ in range(params.nerf_depth)
|
||||
])
|
||||
|
||||
if params.nerf_final_head_type == "linear":
|
||||
self.nerf_final_layer = NerfFinalLayer(
|
||||
params.nerf_hidden_size,
|
||||
out_channels=params.in_channels,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
elif params.nerf_final_head_type == "conv":
|
||||
self.nerf_final_layer_conv = NerfFinalLayerConv(
|
||||
params.nerf_hidden_size,
|
||||
out_channels=params.in_channels,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
else:
|
||||
errstr = f"Unsupported nerf_final_head_type {params.nerf_final_head_type}"
|
||||
raise ValueError(errstr)
|
||||
|
||||
self.skip_mmdit = []
|
||||
self.skip_dit = []
|
||||
self.lite = False
|
||||
|
||||
if params.use_x0:
|
||||
self.register_buffer("__x0__", torch.tensor([]))
|
||||
|
||||
@property
|
||||
def _nerf_final_layer(self) -> nn.Module:
|
||||
if self.params.nerf_final_head_type == "linear":
|
||||
return self.nerf_final_layer
|
||||
if self.params.nerf_final_head_type == "conv":
|
||||
return self.nerf_final_layer_conv
|
||||
# Impossible to get here as we raise an error on unexpected types on initialization.
|
||||
raise NotImplementedError
|
||||
|
||||
def img_in(self, img: Tensor) -> Tensor:
|
||||
img = self.img_in_patch(img) # -> [B, Hidden, H/P, W/P]
|
||||
# flatten into a sequence for the transformer.
|
||||
return img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden]
|
||||
|
||||
def forward_nerf(
|
||||
self,
|
||||
img_orig: Tensor,
|
||||
img_out: Tensor,
|
||||
params: ChromaRadianceParams,
|
||||
) -> Tensor:
|
||||
B, C, H, W = img_orig.shape
|
||||
num_patches = img_out.shape[1]
|
||||
patch_size = params.patch_size
|
||||
|
||||
# Store the raw pixel values of each patch for the NeRF head later.
|
||||
# unfold creates patches: [B, C * P * P, NumPatches]
|
||||
nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
|
||||
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
|
||||
|
||||
# Reshape for per-patch processing
|
||||
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
|
||||
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
|
||||
|
||||
if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
|
||||
# Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
|
||||
# the tile size.
|
||||
img_dct = self.forward_tiled_nerf(nerf_hidden, nerf_pixels, B, C, num_patches, patch_size, params)
|
||||
else:
|
||||
# Get DCT-encoded pixel embeddings [pixel-dct]
|
||||
img_dct = self.nerf_image_embedder(nerf_pixels)
|
||||
|
||||
# Pass through the dynamic MLP blocks (the NeRF)
|
||||
for block in self.nerf_blocks:
|
||||
img_dct = block(img_dct, nerf_hidden)
|
||||
|
||||
# Reassemble the patches into the final image.
|
||||
img_dct = img_dct.transpose(1, 2) # -> [B*NumPatches, C, P*P]
|
||||
# Reshape to combine with batch dimension for fold
|
||||
img_dct = img_dct.reshape(B, num_patches, -1) # -> [B, NumPatches, C*P*P]
|
||||
img_dct = img_dct.transpose(1, 2) # -> [B, C*P*P, NumPatches]
|
||||
img_dct = nn.functional.fold(
|
||||
img_dct,
|
||||
output_size=(H, W),
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
)
|
||||
return self._nerf_final_layer(img_dct)
|
||||
|
||||
def forward_tiled_nerf(
|
||||
self,
|
||||
nerf_hidden: Tensor,
|
||||
nerf_pixels: Tensor,
|
||||
batch: int,
|
||||
channels: int,
|
||||
num_patches: int,
|
||||
patch_size: int,
|
||||
params: ChromaRadianceParams,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Processes the NeRF head in tiles to save memory.
|
||||
nerf_hidden has shape [B, L, D]
|
||||
nerf_pixels has shape [B, L, C * P * P]
|
||||
"""
|
||||
tile_size = params.nerf_tile_size
|
||||
output_tiles = []
|
||||
# Iterate over the patches in tiles. The dimension L (num_patches) is at index 1.
|
||||
for i in range(0, num_patches, tile_size):
|
||||
end = min(i + tile_size, num_patches)
|
||||
|
||||
# Slice the current tile from the input tensors
|
||||
nerf_hidden_tile = nerf_hidden[i * batch:end * batch]
|
||||
nerf_pixels_tile = nerf_pixels[i * batch:end * batch]
|
||||
|
||||
# get DCT-encoded pixel embeddings [pixel-dct]
|
||||
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
|
||||
|
||||
# pass through the dynamic MLP blocks (the NeRF)
|
||||
for block in self.nerf_blocks:
|
||||
img_dct_tile = block(img_dct_tile, nerf_hidden_tile)
|
||||
|
||||
output_tiles.append(img_dct_tile)
|
||||
|
||||
# Concatenate the processed tiles along the patch dimension
|
||||
return torch.cat(output_tiles, dim=0)
|
||||
|
||||
def radiance_get_override_params(self, overrides: dict) -> ChromaRadianceParams:
|
||||
params = self.params
|
||||
if not overrides:
|
||||
return params
|
||||
params_dict = {k: getattr(params, k) for k in params.__dataclass_fields__}
|
||||
nullable_keys = frozenset(("nerf_embedder_dtype",))
|
||||
bad_keys = tuple(k for k in overrides if k not in params_dict)
|
||||
if bad_keys:
|
||||
e = f"Unknown key(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
|
||||
raise ValueError(e)
|
||||
bad_keys = tuple(
|
||||
k
|
||||
for k, v in overrides.items()
|
||||
if not isinstance(v, type(getattr(params, k))) and (v is not None or k not in nullable_keys)
|
||||
)
|
||||
if bad_keys:
|
||||
e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
|
||||
raise ValueError(e)
|
||||
# At this point it's all valid keys and values so we can merge with the existing params.
|
||||
params_dict |= overrides
|
||||
return params.__class__(**params_dict)
|
||||
|
||||
def _apply_x0_residual(self, predicted, noisy, timesteps):
|
||||
|
||||
# non zero during training to prevent 0 div
|
||||
eps = 0.0
|
||||
return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
timestep: Tensor,
|
||||
context: Tensor,
|
||||
guidance: Optional[Tensor],
|
||||
control: Optional[dict]=None,
|
||||
transformer_options: dict={},
|
||||
**kwargs: dict,
|
||||
) -> Tensor:
|
||||
bs, c, h, w = x.shape
|
||||
img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
|
||||
if img.ndim != 4:
|
||||
raise ValueError("Input img tensor must be in [B, C, H, W] format.")
|
||||
if context.ndim != 3:
|
||||
raise ValueError("Input txt tensors must have 3 dimensions.")
|
||||
|
||||
params = self.radiance_get_override_params(transformer_options.get("chroma_radiance_options", {}))
|
||||
|
||||
h_len = (img.shape[-2] // self.patch_size)
|
||||
w_len = (img.shape[-1] // self.patch_size)
|
||||
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
|
||||
img_out = self.forward_orig(
|
||||
img,
|
||||
img_ids,
|
||||
context,
|
||||
txt_ids,
|
||||
timestep,
|
||||
guidance,
|
||||
control,
|
||||
transformer_options,
|
||||
attn_mask=kwargs.get("attention_mask", None),
|
||||
)
|
||||
|
||||
out = self.forward_nerf(img, img_out, params)[:, :, :h, :w]
|
||||
|
||||
# If x0 variant → v-pred, just return this instead
|
||||
if hasattr(self, "__x0__"):
|
||||
out = self._apply_x0_residual(out, img, timestep)
|
||||
return out
|
||||
|
||||
@ -176,6 +176,7 @@ class Attention(nn.Module):
|
||||
context=None,
|
||||
mask=None,
|
||||
rope_emb=None,
|
||||
transformer_options={},
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -184,7 +185,7 @@ class Attention(nn.Module):
|
||||
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
|
||||
"""
|
||||
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
|
||||
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
|
||||
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
out = rearrange(out, " b n s c -> s b (n c)")
|
||||
return self.to_out(out)
|
||||
@ -546,6 +547,7 @@ class VideoAttn(nn.Module):
|
||||
context: Optional[torch.Tensor] = None,
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for video attention.
|
||||
@ -571,6 +573,7 @@ class VideoAttn(nn.Module):
|
||||
context_M_B_D,
|
||||
crossattn_mask,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
|
||||
return x_T_H_W_B_D
|
||||
@ -665,6 +668,7 @@ class DITBuildingBlock(nn.Module):
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for dynamically configured blocks with adaptive normalization.
|
||||
@ -702,6 +706,7 @@ class DITBuildingBlock(nn.Module):
|
||||
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
|
||||
context=None,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
elif self.block_type in ["cross_attn", "ca"]:
|
||||
x = x + gate_1_1_1_B_D * self.block(
|
||||
@ -709,6 +714,7 @@ class DITBuildingBlock(nn.Module):
|
||||
context=crossattn_emb,
|
||||
crossattn_mask=crossattn_mask,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown block type: {self.block_type}")
|
||||
@ -784,6 +790,7 @@ class GeneralDITTransformerBlock(nn.Module):
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
for block in self.blocks:
|
||||
x = block(
|
||||
@ -793,5 +800,6 @@ class GeneralDITTransformerBlock(nn.Module):
|
||||
crossattn_mask,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
return x
|
||||
|
||||
@ -27,6 +27,8 @@ from torchvision import transforms
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
import comfy.patcher_extension
|
||||
|
||||
from .blocks import (
|
||||
FinalLayer,
|
||||
GeneralDITTransformerBlock,
|
||||
@ -435,6 +437,42 @@ class GeneralDIT(nn.Module):
|
||||
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||
condition_video_augment_sigma: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x,
|
||||
timesteps,
|
||||
context,
|
||||
attention_mask,
|
||||
fps,
|
||||
image_size,
|
||||
padding_mask,
|
||||
scalar_feature,
|
||||
data_type,
|
||||
latent_condition,
|
||||
latent_condition_sigma,
|
||||
condition_video_augment_sigma,
|
||||
**kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
# crossattn_emb: torch.Tensor,
|
||||
# crossattn_mask: Optional[torch.Tensor] = None,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
image_size: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
scalar_feature: Optional[torch.Tensor] = None,
|
||||
data_type: Optional[DataType] = DataType.VIDEO,
|
||||
latent_condition: Optional[torch.Tensor] = None,
|
||||
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||
condition_video_augment_sigma: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -482,6 +520,7 @@ class GeneralDIT(nn.Module):
|
||||
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
|
||||
), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"
|
||||
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
for _, block in self.blocks.items():
|
||||
assert (
|
||||
self.blocks["block0"].x_format == block.x_format
|
||||
@ -496,6 +535,7 @@ class GeneralDIT(nn.Module):
|
||||
crossattn_mask,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
|
||||
|
||||
@ -11,6 +11,7 @@ import math
|
||||
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
|
||||
from torchvision import transforms
|
||||
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
@ -43,7 +44,7 @@ class GPT2FeedForward(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
|
||||
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
|
||||
"""Computes multi-head attention using PyTorch's native implementation.
|
||||
|
||||
This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
|
||||
@ -70,7 +71,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H
|
||||
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
|
||||
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
|
||||
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
|
||||
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)
|
||||
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
@ -179,8 +180,8 @@ class Attention(nn.Module):
|
||||
|
||||
return q, k, v
|
||||
|
||||
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
||||
result = self.attn_op(q, k, v) # [B, S, H, D]
|
||||
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
|
||||
result = self.attn_op(q, k, v, transformer_options=transformer_options) # [B, S, H, D]
|
||||
return self.output_dropout(self.output_proj(result))
|
||||
|
||||
def forward(
|
||||
@ -188,6 +189,7 @@ class Attention(nn.Module):
|
||||
x: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
rope_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -195,7 +197,7 @@ class Attention(nn.Module):
|
||||
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
|
||||
"""
|
||||
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
|
||||
return self.compute_attention(q, k, v)
|
||||
return self.compute_attention(q, k, v, transformer_options=transformer_options)
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
@ -458,6 +460,7 @@ class Block(nn.Module):
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
if extra_per_block_pos_emb is not None:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
|
||||
@ -511,6 +514,7 @@ class Block(nn.Module):
|
||||
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
None,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
),
|
||||
"b (t h w) d -> b t h w d",
|
||||
t=T,
|
||||
@ -524,6 +528,7 @@ class Block(nn.Module):
|
||||
layer_norm_cross_attn: Callable,
|
||||
_scale_cross_attn_B_T_1_1_D: torch.Tensor,
|
||||
_shift_cross_attn_B_T_1_1_D: torch.Tensor,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
_normalized_x_B_T_H_W_D = _fn(
|
||||
_x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
|
||||
@ -533,6 +538,7 @@ class Block(nn.Module):
|
||||
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
crossattn_emb,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
),
|
||||
"b (t h w) d -> b t h w d",
|
||||
t=T,
|
||||
@ -546,6 +552,7 @@ class Block(nn.Module):
|
||||
self.layer_norm_cross_attn,
|
||||
scale_cross_attn_B_T_1_1_D,
|
||||
shift_cross_attn_B_T_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
|
||||
|
||||
@ -805,7 +812,21 @@ class MiniTrainDIT(nn.Module):
|
||||
)
|
||||
return x_B_C_Tt_Hp_Wp
|
||||
|
||||
def forward(
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x, timesteps, context, fps, padding_mask, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
@ -850,6 +871,7 @@ class MiniTrainDIT(nn.Module):
|
||||
"rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
|
||||
"adaln_lora_B_T_3D": adaln_lora_B_T_3D,
|
||||
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
|
||||
"transformer_options": kwargs.get("transformer_options", {}),
|
||||
}
|
||||
for block in self.blocks:
|
||||
x_B_T_H_W_D = block(
|
||||
|
||||
@ -48,15 +48,44 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
|
||||
return embedding
|
||||
|
||||
class MLPEmbedder(nn.Module):
|
||||
def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
|
||||
def __init__(self, in_dim: int, hidden_dim: int, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
self.in_layer = operations.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.silu = nn.SiLU()
|
||||
self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.out_layer(self.silu(self.in_layer(x)))
|
||||
|
||||
class YakMLP(nn.Module):
|
||||
def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
|
||||
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
|
||||
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||
if yak_mlp:
|
||||
return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
|
||||
if mlp_silu_act:
|
||||
return nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
|
||||
SiLUActivation(),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
|
||||
)
|
||||
else:
|
||||
return nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
@ -80,14 +109,14 @@ class QKNorm(torch.nn.Module):
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
|
||||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
self.proj = operations.Linear(dim, dim, bias=proj_bias, dtype=dtype, device=device)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -98,11 +127,11 @@ class ModulationOut:
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
|
||||
def __init__(self, dim: int, double: bool, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.is_double = double
|
||||
self.multiplier = 6 if double else 3
|
||||
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
|
||||
self.lin = operations.Linear(dim, self.multiplier * dim, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, vec: Tensor) -> tuple:
|
||||
if vec.ndim == 2:
|
||||
@ -129,77 +158,107 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
|
||||
return tensor
|
||||
|
||||
|
||||
class SiLUActivation(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gate_fn = nn.SiLU()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
return self.gate_fn(x1) * x2
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
self.modulation = modulation
|
||||
|
||||
if self.modulation:
|
||||
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
if self.modulation:
|
||||
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
|
||||
if self.modulation:
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
else:
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
del img_modulated
|
||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
del img_qkv
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
del txt_modulated
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
del txt_qkv
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
if self.flipped_img_txt:
|
||||
q = torch.cat((img_q, txt_q), dim=2)
|
||||
del img_q, txt_q
|
||||
k = torch.cat((img_k, txt_k), dim=2)
|
||||
del img_k, txt_k
|
||||
v = torch.cat((img_v, txt_v), dim=2)
|
||||
del img_v, txt_v
|
||||
# run actual attention
|
||||
attn = attention(torch.cat((img_q, txt_q), dim=2),
|
||||
torch.cat((img_k, txt_k), dim=2),
|
||||
torch.cat((img_v, txt_v), dim=2),
|
||||
pe=pe, mask=attn_mask)
|
||||
attn = attention(q, k, v,
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
|
||||
else:
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
del txt_q, img_q
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
del txt_k, img_k
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
del txt_v, img_v
|
||||
# run actual attention
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2),
|
||||
pe=pe, mask=attn_mask)
|
||||
attn = attention(q, k, v,
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
|
||||
# calculate the img bloks
|
||||
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
|
||||
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||
del img_attn
|
||||
img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
|
||||
del txt_attn
|
||||
txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
@ -220,6 +279,10 @@ class SingleStreamBlock(nn.Module):
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float = None,
|
||||
modulation=True,
|
||||
mlp_silu_act=False,
|
||||
bias=True,
|
||||
yak_mlp=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
@ -231,30 +294,55 @@ class SingleStreamBlock(nn.Module):
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
|
||||
self.mlp_hidden_dim_first = self.mlp_hidden_dim
|
||||
self.yak_mlp = yak_mlp
|
||||
if mlp_silu_act:
|
||||
self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
|
||||
self.mlp_act = SiLUActivation()
|
||||
else:
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
|
||||
if self.yak_mlp:
|
||||
self.mlp_hidden_dim_first *= 2
|
||||
self.mlp_act = nn.SiLU()
|
||||
|
||||
# qkv and mlp_in
|
||||
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
|
||||
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
|
||||
# proj and mlp_out
|
||||
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
|
||||
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||
if modulation:
|
||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.modulation = None
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
|
||||
if self.modulation:
|
||||
mod, _ = self.modulation(vec)
|
||||
else:
|
||||
mod = vec
|
||||
|
||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
del qkv
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
if self.yak_mlp:
|
||||
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
|
||||
else:
|
||||
mlp = self.mlp_act(mlp)
|
||||
output = self.linear2(torch.cat((attn, mlp), 2))
|
||||
x += apply_mod(output, mod.gate, None, modulation_dims)
|
||||
if x.dtype == torch.float16:
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
@ -262,11 +350,11 @@ class SingleStreamBlock(nn.Module):
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=bias, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=bias, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
|
||||
if vec.ndim == 2:
|
||||
|
||||
@ -4,23 +4,16 @@ from torch import Tensor
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
import logging
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
|
||||
q_shape = q.shape
|
||||
k_shape = k.shape
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
|
||||
if pe is not None:
|
||||
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
|
||||
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
|
||||
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
|
||||
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
||||
|
||||
q, k = apply_rope(q, k, pe)
|
||||
heads = q.shape[1]
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
assert dim % 2 == 0
|
||||
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
|
||||
@ -36,10 +29,19 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
return out.to(dtype=torch.float32, device=pos.device)
|
||||
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
try:
|
||||
import comfy.quant_ops
|
||||
apply_rope = comfy.quant_ops.ck.apply_rope
|
||||
apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
||||
except:
|
||||
logging.warning("No comfy kitchen, using old apply_rope functions.")
|
||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
|
||||
@ -6,6 +6,7 @@ import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
@ -14,6 +15,8 @@ from .layers import (
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
Modulation,
|
||||
RMSNorm
|
||||
)
|
||||
|
||||
@dataclass
|
||||
@ -32,6 +35,14 @@ class FluxParams:
|
||||
patch_size: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
txt_ids_dims: list
|
||||
global_modulation: bool = False
|
||||
mlp_silu_act: bool = False
|
||||
ops_bias: bool = True
|
||||
default_ref_method: str = "offset"
|
||||
ref_index_scale: float = 1.0
|
||||
yak_mlp: bool = False
|
||||
txt_norm: bool = False
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
@ -57,13 +68,22 @@ class Flux(nn.Module):
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
|
||||
if params.vec_in_dim is not None:
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.vector_in = None
|
||||
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
|
||||
)
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
|
||||
|
||||
if params.txt_norm:
|
||||
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.txt_norm = None
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
@ -72,6 +92,10 @@ class Flux(nn.Module):
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
modulation=params.global_modulation is False,
|
||||
mlp_silu_act=params.mlp_silu_act,
|
||||
proj_bias=params.ops_bias,
|
||||
yak_mlp=params.yak_mlp,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
@ -80,13 +104,30 @@ class Flux(nn.Module):
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
if params.global_modulation:
|
||||
self.double_stream_modulation_img = Modulation(
|
||||
self.hidden_size,
|
||||
double=True,
|
||||
bias=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.double_stream_modulation_txt = Modulation(
|
||||
self.hidden_size,
|
||||
double=True,
|
||||
bias=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.single_stream_modulation = Modulation(
|
||||
self.hidden_size, double=False, bias=False, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
@ -102,9 +143,7 @@ class Flux(nn.Module):
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
|
||||
if y is None:
|
||||
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
@ -116,9 +155,27 @@ class Flux(nn.Module):
|
||||
if guidance is not None:
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||
|
||||
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
||||
if self.vector_in is not None:
|
||||
if y is None:
|
||||
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
|
||||
if self.txt_norm is not None:
|
||||
txt = self.txt_norm(txt)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
vec_orig = vec
|
||||
if self.params.global_modulation:
|
||||
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
|
||||
|
||||
if "post_input" in patches:
|
||||
for p in patches["post_input"]:
|
||||
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
|
||||
img = out["img"]
|
||||
txt = out["txt"]
|
||||
img_ids = out["img_ids"]
|
||||
txt_ids = out["txt_ids"]
|
||||
|
||||
if img_ids is not None:
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
@ -126,7 +183,10 @@ class Flux(nn.Module):
|
||||
pe = None
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@ -134,14 +194,16 @@ class Flux(nn.Module):
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
@ -150,52 +212,61 @@ class Flux(nn.Module):
|
||||
txt=txt,
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask)
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
if i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
img += add
|
||||
img[:, :add.shape[1]] += add
|
||||
|
||||
if img.dtype == torch.float16:
|
||||
img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
if self.params.global_modulation:
|
||||
vec, _ = self.single_stream_modulation(vec_orig)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img[:, txt.shape[1] :, ...] += add
|
||||
img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0):
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
@ -207,30 +278,55 @@ class Flux(nn.Module):
|
||||
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
steps_h = h_len
|
||||
steps_w = w_len
|
||||
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
if rope_options is not None:
|
||||
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||
|
||||
index += rope_options.get("shift_t", 0.0)
|
||||
h_offset += rope_options.get("shift_y", 0.0)
|
||||
w_offset += rope_options.get("shift_x", 0.0)
|
||||
|
||||
img_ids = torch.zeros((steps_h, steps_w, len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=torch.float32).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=torch.float32).unsqueeze(0)
|
||||
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, guidance, ref_latents, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h_orig, w_orig = x.shape
|
||||
patch_size = self.patch_size
|
||||
|
||||
h_len = ((h_orig + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
||||
img, img_ids = self.process_img(x)
|
||||
img, img_ids = self.process_img(x, transformer_options=transformer_options)
|
||||
img_tokens = img.shape[1]
|
||||
if ref_latents is not None:
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
index_ref_method = kwargs.get("ref_latents_method", "offset") == "index"
|
||||
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
|
||||
for ref in ref_latents:
|
||||
if index_ref_method:
|
||||
index += 1
|
||||
if ref_latents_method == "index":
|
||||
index += self.params.ref_index_scale
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
elif ref_latents_method == "uxo":
|
||||
index = 0
|
||||
h_offset = h_len * patch_size + h
|
||||
w_offset = w_len * patch_size + w
|
||||
h += ref.shape[-2]
|
||||
w += ref.shape[-1]
|
||||
else:
|
||||
index = 1
|
||||
h_offset = 0
|
||||
@ -246,7 +342,12 @@ class Flux(nn.Module):
|
||||
img = torch.cat([img, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
|
||||
|
||||
if len(self.params.txt_ids_dims) > 0:
|
||||
for i in self.params.txt_ids_dims:
|
||||
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
|
||||
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
out = out[:, :img_tokens]
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]
|
||||
|
||||
@ -109,6 +109,7 @@ class AsymmetricAttention(nn.Module):
|
||||
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
|
||||
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
|
||||
crop_y,
|
||||
transformer_options={},
|
||||
**rope_rotation,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
rope_cos = rope_rotation.get("rope_cos")
|
||||
@ -143,7 +144,7 @@ class AsymmetricAttention(nn.Module):
|
||||
|
||||
xy = optimized_attention(q,
|
||||
k,
|
||||
v, self.num_heads, skip_reshape=True)
|
||||
v, self.num_heads, skip_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
|
||||
x = self.proj_x(x)
|
||||
@ -224,6 +225,7 @@ class AsymmetricJointBlock(nn.Module):
|
||||
x: torch.Tensor,
|
||||
c: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
transformer_options={},
|
||||
**attn_kwargs,
|
||||
):
|
||||
"""Forward pass of a block.
|
||||
@ -256,6 +258,7 @@ class AsymmetricJointBlock(nn.Module):
|
||||
y,
|
||||
scale_x=scale_msa_x,
|
||||
scale_y=scale_msa_y,
|
||||
transformer_options=transformer_options,
|
||||
**attn_kwargs,
|
||||
)
|
||||
|
||||
@ -524,10 +527,11 @@ class AsymmDiTJoint(nn.Module):
|
||||
args["txt"],
|
||||
rope_cos=args["rope_cos"],
|
||||
rope_sin=args["rope_sin"],
|
||||
crop_y=args["num_tokens"]
|
||||
crop_y=args["num_tokens"],
|
||||
transformer_options=args["transformer_options"]
|
||||
)
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
y_feat = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
@ -538,6 +542,7 @@ class AsymmDiTJoint(nn.Module):
|
||||
rope_cos=rope_cos,
|
||||
rope_sin=rope_sin,
|
||||
crop_y=num_tokens,
|
||||
transformer_options=transformer_options,
|
||||
) # (B, M, D), (B, L, D)
|
||||
del y_feat # Final layers don't use dense text features.
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ from comfy.ldm.flux.layers import LastLayer
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
@ -71,8 +72,8 @@ class TimestepEmbed(nn.Module):
|
||||
return t_emb
|
||||
|
||||
|
||||
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
|
||||
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
|
||||
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, transformer_options={}):
|
||||
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2], transformer_options=transformer_options)
|
||||
|
||||
|
||||
class HiDreamAttnProcessor_flashattn:
|
||||
@ -85,6 +86,7 @@ class HiDreamAttnProcessor_flashattn:
|
||||
image_tokens_masks: Optional[torch.FloatTensor] = None,
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
@ -132,7 +134,7 @@ class HiDreamAttnProcessor_flashattn:
|
||||
query = torch.cat([query_1, query_2], dim=-1)
|
||||
key = torch.cat([key_1, key_2], dim=-1)
|
||||
|
||||
hidden_states = attention(query, key, value)
|
||||
hidden_states = attention(query, key, value, transformer_options=transformer_options)
|
||||
|
||||
if not attn.single:
|
||||
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
|
||||
@ -198,6 +200,7 @@ class HiDreamAttention(nn.Module):
|
||||
image_tokens_masks: torch.FloatTensor = None,
|
||||
norm_text_tokens: torch.FloatTensor = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
) -> torch.Tensor:
|
||||
return self.processor(
|
||||
self,
|
||||
@ -205,6 +208,7 @@ class HiDreamAttention(nn.Module):
|
||||
image_tokens_masks = image_tokens_masks,
|
||||
text_tokens = norm_text_tokens,
|
||||
rope = rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
|
||||
@ -405,7 +409,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
|
||||
transformer_options={},
|
||||
) -> torch.FloatTensor:
|
||||
wtype = image_tokens.dtype
|
||||
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
|
||||
@ -418,6 +422,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
|
||||
norm_image_tokens,
|
||||
image_tokens_masks,
|
||||
rope = rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
image_tokens = gate_msa_i * attn_output_i + image_tokens
|
||||
|
||||
@ -482,6 +487,7 @@ class HiDreamImageTransformerBlock(nn.Module):
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
) -> torch.FloatTensor:
|
||||
wtype = image_tokens.dtype
|
||||
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
|
||||
@ -499,6 +505,7 @@ class HiDreamImageTransformerBlock(nn.Module):
|
||||
image_tokens_masks,
|
||||
norm_text_tokens,
|
||||
rope = rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
image_tokens = gate_msa_i * attn_output_i + image_tokens
|
||||
@ -549,6 +556,7 @@ class HiDreamImageBlock(nn.Module):
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: torch.FloatTensor = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
) -> torch.FloatTensor:
|
||||
return self.block(
|
||||
image_tokens,
|
||||
@ -556,6 +564,7 @@ class HiDreamImageBlock(nn.Module):
|
||||
text_tokens,
|
||||
adaln_input,
|
||||
rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
|
||||
@ -692,7 +701,23 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
||||
raise NotImplementedError
|
||||
return x, x_masks, img_sizes
|
||||
|
||||
def forward(
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states_llama3=None,
|
||||
image_cond=None,
|
||||
control = None,
|
||||
transformer_options = {},
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, t, y, context, encoder_hidden_states_llama3, image_cond, control, transformer_options)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
@ -769,6 +794,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
||||
text_tokens = cur_encoder_hidden_states,
|
||||
adaln_input = adaln_input,
|
||||
rope = rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
|
||||
block_id += 1
|
||||
@ -792,6 +818,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
||||
text_tokens=None,
|
||||
adaln_input=adaln_input,
|
||||
rope=rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
hidden_states = hidden_states[:, :hidden_states_seq_len]
|
||||
block_id += 1
|
||||
|
||||
@ -7,6 +7,7 @@ from comfy.ldm.flux.layers import (
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
)
|
||||
import comfy.patcher_extension
|
||||
|
||||
|
||||
class Hunyuan3Dv2(nn.Module):
|
||||
@ -67,6 +68,13 @@ class Hunyuan3Dv2(nn.Module):
|
||||
self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, guidance, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
|
||||
x = x.movedim(-1, -2)
|
||||
timestep = 1.0 - timestep
|
||||
txt = context
|
||||
@ -91,14 +99,16 @@ class Hunyuan3Dv2(nn.Module):
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
@ -107,7 +117,8 @@ class Hunyuan3Dv2(nn.Module):
|
||||
txt=txt,
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask)
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
@ -118,17 +129,19 @@ class Hunyuan3Dv2(nn.Module):
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
img = img[:, txt.shape[1]:, ...]
|
||||
img = self.final_layer(img, vec)
|
||||
|
||||
@ -4,81 +4,458 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
from typing import Union, Tuple, List, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
from einops import repeat, rearrange
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import logging
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def generate_dense_grid_points(
|
||||
bbox_min: np.ndarray,
|
||||
bbox_max: np.ndarray,
|
||||
octree_resolution: int,
|
||||
indexing: str = "ij",
|
||||
):
|
||||
length = bbox_max - bbox_min
|
||||
num_cells = octree_resolution
|
||||
def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True):
|
||||
|
||||
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
|
||||
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
|
||||
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
|
||||
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
|
||||
xyz = np.stack((xs, ys, zs), axis=-1)
|
||||
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
|
||||
# manually create the pointer vector
|
||||
assert src.size(0) == batch.numel()
|
||||
|
||||
return xyz, grid_size, length
|
||||
batch_size = int(batch.max()) + 1
|
||||
deg = src.new_zeros(batch_size, dtype = torch.long)
|
||||
|
||||
deg.scatter_add_(0, batch, torch.ones_like(batch))
|
||||
|
||||
ptr_vec = deg.new_zeros(batch_size + 1)
|
||||
torch.cumsum(deg, 0, out=ptr_vec[1:])
|
||||
|
||||
#return fps_sampling(src, ptr_vec, ratio)
|
||||
sampled_indicies = []
|
||||
|
||||
for b in range(batch_size):
|
||||
# start and the end of each batch
|
||||
start, end = ptr_vec[b].item(), ptr_vec[b + 1].item()
|
||||
# points from the point cloud
|
||||
points = src[start:end]
|
||||
|
||||
num_points = points.size(0)
|
||||
num_samples = max(1, math.ceil(num_points * sampling_ratio))
|
||||
|
||||
selected = torch.zeros(num_samples, device = src.device, dtype = torch.long)
|
||||
distances = torch.full((num_points,), float("inf"), device = src.device)
|
||||
|
||||
# select a random start point
|
||||
if start_random:
|
||||
farthest = torch.randint(0, num_points, (1,), device = src.device)
|
||||
else:
|
||||
farthest = torch.tensor([0], device = src.device, dtype = torch.long)
|
||||
|
||||
for i in range(num_samples):
|
||||
selected[i] = farthest
|
||||
centroid = points[farthest].squeeze(0)
|
||||
dist = torch.norm(points - centroid, dim = 1) # compute euclidean distance
|
||||
distances = torch.minimum(distances, dist)
|
||||
farthest = torch.argmax(distances)
|
||||
|
||||
sampled_indicies.append(torch.arange(start, end)[selected])
|
||||
|
||||
return torch.cat(sampled_indicies, dim = 0)
|
||||
class PointCrossAttention(nn.Module):
|
||||
def __init__(self,
|
||||
num_latents: int,
|
||||
downsample_ratio: float,
|
||||
pc_size: int,
|
||||
pc_sharpedge_size: int,
|
||||
point_feats: int,
|
||||
width: int,
|
||||
heads: int,
|
||||
layers: int,
|
||||
fourier_embedder,
|
||||
normal_pe: bool = False,
|
||||
qkv_bias: bool = False,
|
||||
use_ln_post: bool = True,
|
||||
qk_norm: bool = True):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.fourier_embedder = fourier_embedder
|
||||
|
||||
self.pc_size = pc_size
|
||||
self.normal_pe = normal_pe
|
||||
self.downsample_ratio = downsample_ratio
|
||||
self.pc_sharpedge_size = pc_sharpedge_size
|
||||
self.num_latents = num_latents
|
||||
self.point_feats = point_feats
|
||||
|
||||
self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
|
||||
|
||||
self.cross_attn = ResidualCrossAttentionBlock(
|
||||
width = width,
|
||||
heads = heads,
|
||||
qkv_bias = qkv_bias,
|
||||
qk_norm = qk_norm
|
||||
)
|
||||
|
||||
self.self_attn = None
|
||||
if layers > 0:
|
||||
self.self_attn = Transformer(
|
||||
width = width,
|
||||
heads = heads,
|
||||
qkv_bias = qkv_bias,
|
||||
qk_norm = qk_norm,
|
||||
layers = layers
|
||||
)
|
||||
|
||||
if use_ln_post:
|
||||
self.ln_post = nn.LayerNorm(width)
|
||||
else:
|
||||
self.ln_post = None
|
||||
|
||||
def sample_points_and_latents(self, point_cloud: torch.Tensor, features: torch.Tensor):
|
||||
|
||||
"""
|
||||
Subsample points randomly from the point cloud (input_pc)
|
||||
Further sample the subsampled points to get query_pc
|
||||
take the fourier embeddings for both input and query pc
|
||||
|
||||
Mental Note: FPS-sampled points (query_pc) act as latent tokens that attend to and learn from the broader context in input_pc.
|
||||
Goal: get a smaller represenation (query_pc) to represent the entire scence structure by learning from a broader subset (input_pc).
|
||||
More computationally efficient.
|
||||
|
||||
Features are additional information for each point in the cloud
|
||||
"""
|
||||
|
||||
B, _, D = point_cloud.shape
|
||||
|
||||
num_latents = int(self.num_latents)
|
||||
|
||||
num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
|
||||
num_sharpedge_query = num_latents - num_random_query
|
||||
|
||||
# Split random and sharpedge surface points
|
||||
random_pc, sharpedge_pc = torch.split(point_cloud, [self.pc_size, self.pc_sharpedge_size], dim=1)
|
||||
|
||||
# assert statements
|
||||
assert random_pc.shape[1] <= self.pc_size, "Random surface points size must be less than or equal to pc_size"
|
||||
assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, "Sharpedge surface points size must be less than or equal to pc_sharpedge_size"
|
||||
|
||||
input_random_pc_size = int(num_random_query * self.downsample_ratio)
|
||||
random_query_pc, random_input_pc, random_idx_pc, random_idx_query = \
|
||||
self.subsample(pc = random_pc, num_query = num_random_query, input_pc_size = input_random_pc_size)
|
||||
|
||||
input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
|
||||
|
||||
if input_sharpedge_pc_size == 0:
|
||||
sharpedge_input_pc = torch.zeros(B, 0, D, dtype = random_input_pc.dtype).to(point_cloud.device)
|
||||
sharpedge_query_pc = torch.zeros(B, 0, D, dtype= random_query_pc.dtype).to(point_cloud.device)
|
||||
|
||||
else:
|
||||
sharpedge_query_pc, sharpedge_input_pc, sharpedge_idx_pc, sharpedge_idx_query = \
|
||||
self.subsample(pc = sharpedge_pc, num_query = num_sharpedge_query, input_pc_size = input_sharpedge_pc_size)
|
||||
|
||||
# concat the random and sharpedges
|
||||
query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim = 1)
|
||||
input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim = 1)
|
||||
|
||||
query = self.fourier_embedder(query_pc)
|
||||
data = self.fourier_embedder(input_pc)
|
||||
|
||||
if self.point_feats > 0:
|
||||
random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim = 1)
|
||||
|
||||
input_random_surface_features, query_random_features = \
|
||||
self.handle_features(features = random_surface_features, idx_pc = random_idx_pc, batch_size = B,
|
||||
input_pc_size = input_random_pc_size, idx_query = random_idx_query)
|
||||
|
||||
if input_sharpedge_pc_size == 0:
|
||||
input_sharpedge_surface_features = torch.zeros(B, 0, self.point_feats,
|
||||
dtype = input_random_surface_features.dtype, device = point_cloud.device)
|
||||
|
||||
query_sharpedge_features = torch.zeros(B, 0, self.point_feats,
|
||||
dtype = query_random_features.dtype, device = point_cloud.device)
|
||||
else:
|
||||
|
||||
input_sharpedge_surface_features, query_sharpedge_features = \
|
||||
self.handle_features(idx_pc = sharpedge_idx_pc, features = sharpedge_surface_features,
|
||||
batch_size = B, idx_query = sharpedge_idx_query, input_pc_size = input_sharpedge_pc_size)
|
||||
|
||||
query_features = torch.cat([query_random_features, query_sharpedge_features], dim = 1)
|
||||
input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim = 1)
|
||||
|
||||
if self.normal_pe:
|
||||
# apply the fourier embeddings on the first 3 dims (xyz)
|
||||
input_features_pe = self.fourier_embedder(input_features[..., :3])
|
||||
query_features_pe = self.fourier_embedder(query_features[..., :3])
|
||||
# replace the first 3 dims with the new PE ones
|
||||
input_features = torch.cat([input_features_pe, input_features[..., :3]], dim = -1)
|
||||
query_features = torch.cat([query_features_pe, query_features[..., :3]], dim = -1)
|
||||
|
||||
# concat at the channels dim
|
||||
query = torch.cat([query, query_features], dim = -1)
|
||||
data = torch.cat([data, input_features], dim = -1)
|
||||
|
||||
# don't return pc_info to avoid unnecessary memory usuage
|
||||
return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1])
|
||||
|
||||
def forward(self, point_cloud: torch.Tensor, features: torch.Tensor):
|
||||
|
||||
query, data = self.sample_points_and_latents(point_cloud = point_cloud, features = features)
|
||||
|
||||
# apply projections
|
||||
query = self.input_proj(query)
|
||||
data = self.input_proj(data)
|
||||
|
||||
# apply cross attention between query and data
|
||||
latents = self.cross_attn(query, data)
|
||||
|
||||
if self.self_attn is not None:
|
||||
latents = self.self_attn(latents)
|
||||
|
||||
if self.ln_post is not None:
|
||||
latents = self.ln_post(latents)
|
||||
|
||||
return latents
|
||||
|
||||
|
||||
class VanillaVolumeDecoder:
|
||||
def subsample(self, pc, num_query, input_pc_size: int):
|
||||
|
||||
"""
|
||||
num_query: number of points to keep after FPS
|
||||
input_pc_size: number of points to select before FPS
|
||||
"""
|
||||
|
||||
B, _, D = pc.shape
|
||||
query_ratio = num_query / input_pc_size
|
||||
|
||||
# random subsampling of points inside the point cloud
|
||||
idx_pc = torch.randperm(pc.shape[1], device = pc.device)[:input_pc_size]
|
||||
input_pc = pc[:, idx_pc, :]
|
||||
|
||||
# flatten to allow applying fps across the whole batch
|
||||
flattent_input_pc = input_pc.view(B * input_pc_size, D)
|
||||
|
||||
# construct a batch_down tensor to tell fps
|
||||
# which points belong to which batch
|
||||
N_down = int(flattent_input_pc.shape[0] / B)
|
||||
batch_down = torch.arange(B).to(pc.device)
|
||||
batch_down = torch.repeat_interleave(batch_down, N_down)
|
||||
|
||||
idx_query = fps(flattent_input_pc, batch_down, sampling_ratio = query_ratio)
|
||||
query_pc = flattent_input_pc[idx_query].view(B, -1, D)
|
||||
|
||||
return query_pc, input_pc, idx_pc, idx_query
|
||||
|
||||
def handle_features(self, features, idx_pc, input_pc_size, batch_size: int, idx_query):
|
||||
|
||||
B = batch_size
|
||||
|
||||
input_surface_features = features[:, idx_pc, :]
|
||||
flattent_input_features = input_surface_features.view(B * input_pc_size, -1)
|
||||
query_features = flattent_input_features[idx_query].view(B, -1,
|
||||
flattent_input_features.shape[-1])
|
||||
|
||||
return input_surface_features, query_features
|
||||
|
||||
def normalize_mesh(mesh, scale = 0.9999):
|
||||
"""Normalize mesh to fit in [-scale, scale]. Translate mesh so its center is [0,0,0]"""
|
||||
|
||||
bbox = mesh.bounds
|
||||
center = (bbox[1] + bbox[0]) / 2
|
||||
|
||||
max_extent = (bbox[1] - bbox[0]).max()
|
||||
mesh.apply_translation(-center)
|
||||
mesh.apply_scale((2 * scale) / max_extent)
|
||||
|
||||
return mesh
|
||||
|
||||
def sample_pointcloud(mesh, num = 200000):
|
||||
""" Uniformly sample points from the surface of the mesh """
|
||||
|
||||
points, face_idx = mesh.sample(num, return_index = True)
|
||||
normals = mesh.face_normals[face_idx]
|
||||
return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32))
|
||||
|
||||
def detect_sharp_edges(mesh, threshold=0.985):
|
||||
"""Return edge indices (a, b) that lie on sharp boundaries of the mesh."""
|
||||
|
||||
V, F = mesh.vertices, mesh.faces
|
||||
VN, FN = mesh.vertex_normals, mesh.face_normals
|
||||
|
||||
sharp_mask = np.ones(V.shape[0])
|
||||
for i in range(3):
|
||||
indices = F[:, i]
|
||||
alignment = np.einsum('ij,ij->i', VN[indices], FN)
|
||||
dot_stack = np.stack((sharp_mask[indices], alignment), axis=-1)
|
||||
sharp_mask[indices] = np.min(dot_stack, axis=-1)
|
||||
|
||||
edge_a = np.concatenate([F[:, 0], F[:, 1], F[:, 2]])
|
||||
edge_b = np.concatenate([F[:, 1], F[:, 2], F[:, 0]])
|
||||
sharp_edges = (sharp_mask[edge_a] < threshold) & (sharp_mask[edge_b] < threshold)
|
||||
|
||||
return edge_a[sharp_edges], edge_b[sharp_edges]
|
||||
|
||||
|
||||
def sharp_sample_pointcloud(mesh, num = 16384):
|
||||
""" Sample points preferentially from sharp edges in the mesh. """
|
||||
|
||||
edge_a, edge_b = detect_sharp_edges(mesh)
|
||||
V, VN = mesh.vertices, mesh.vertex_normals
|
||||
|
||||
va, vb = V[edge_a], V[edge_b]
|
||||
na, nb = VN[edge_a], VN[edge_b]
|
||||
|
||||
edge_lengths = np.linalg.norm(vb - va, axis=-1)
|
||||
weights = edge_lengths / edge_lengths.sum()
|
||||
|
||||
indices = np.searchsorted(np.cumsum(weights), np.random.rand(num))
|
||||
t = np.random.rand(num, 1)
|
||||
|
||||
samples = t * va[indices] + (1 - t) * vb[indices]
|
||||
normals = t * na[indices] + (1 - t) * nb[indices]
|
||||
|
||||
return samples.astype(np.float32), normals.astype(np.float32)
|
||||
|
||||
def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag = True, device = "cuda"):
|
||||
"""Load a surface with optional sharp-edge annotations from a trimesh mesh."""
|
||||
|
||||
import trimesh
|
||||
|
||||
try:
|
||||
mesh_full = trimesh.util.concatenate(mesh.dump())
|
||||
except Exception:
|
||||
mesh_full = trimesh.util.concatenate(mesh)
|
||||
|
||||
mesh_full = normalize_mesh(mesh_full)
|
||||
|
||||
faces = mesh_full.faces
|
||||
vertices = mesh_full.vertices
|
||||
origin_face_count = faces.shape[0]
|
||||
|
||||
mesh_surface = trimesh.Trimesh(vertices=vertices, faces=faces[:origin_face_count])
|
||||
mesh_fill = trimesh.Trimesh(vertices=vertices, faces=faces[origin_face_count:])
|
||||
|
||||
area_surface = mesh_surface.area
|
||||
area_fill = mesh_fill.area
|
||||
total_area = area_surface + area_fill
|
||||
|
||||
sample_num = 499712 // 2
|
||||
fill_ratio = area_fill / total_area if total_area > 0 else 0
|
||||
|
||||
num_fill = int(sample_num * fill_ratio)
|
||||
num_surface = sample_num - num_fill
|
||||
|
||||
surf_pts, surf_normals = sample_pointcloud(mesh_surface, num_surface)
|
||||
fill_pts, fill_normals = (torch.zeros(0, 3), torch.zeros(0, 3)) if num_fill == 0 else sample_pointcloud(mesh_fill, num_fill)
|
||||
|
||||
sharp_pts, sharp_normals = sharp_sample_pointcloud(mesh_surface, sample_num)
|
||||
|
||||
def assemble_tensor(points, normals, label=None):
|
||||
|
||||
data = torch.cat([points, normals], dim=1).half().to(device)
|
||||
|
||||
if label is not None:
|
||||
label_tensor = torch.full((data.shape[0], 1), float(label), dtype=torch.float16).to(device)
|
||||
data = torch.cat([data, label_tensor], dim=1)
|
||||
|
||||
return data
|
||||
|
||||
surface = assemble_tensor(torch.cat([surf_pts.to(device), fill_pts.to(device)], dim=0),
|
||||
torch.cat([surf_normals.to(device), fill_normals.to(device)], dim=0),
|
||||
label = 0 if sharpedge_flag else None)
|
||||
|
||||
sharp_surface = assemble_tensor(torch.from_numpy(sharp_pts), torch.from_numpy(sharp_normals),
|
||||
label = 1 if sharpedge_flag else None)
|
||||
|
||||
rng = np.random.default_rng()
|
||||
|
||||
surface = surface[rng.choice(surface.shape[0], num_points, replace = False)]
|
||||
sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace = False)]
|
||||
|
||||
full = torch.cat([surface, sharp_surface], dim = 0).unsqueeze(0)
|
||||
|
||||
return full
|
||||
|
||||
class SharpEdgeSurfaceLoader:
|
||||
""" Load mesh surface and sharp edge samples. """
|
||||
|
||||
def __init__(self, num_uniform_points = 8192, num_sharp_points = 8192):
|
||||
|
||||
self.num_uniform_points = num_uniform_points
|
||||
self.num_sharp_points = num_sharp_points
|
||||
self.total_points = num_uniform_points + num_sharp_points
|
||||
|
||||
def __call__(self, mesh_input, device = "cuda"):
|
||||
mesh = self._load_mesh(mesh_input)
|
||||
return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device = device)
|
||||
|
||||
@staticmethod
|
||||
def _load_mesh(mesh_input):
|
||||
import trimesh
|
||||
|
||||
if isinstance(mesh_input, str):
|
||||
mesh = trimesh.load(mesh_input, force="mesh", merge_primitives = True)
|
||||
else:
|
||||
mesh = mesh_input
|
||||
|
||||
if isinstance(mesh, trimesh.Scene):
|
||||
combined = None
|
||||
for obj in mesh.geometry.values():
|
||||
combined = obj if combined is None else combined + obj
|
||||
return combined
|
||||
|
||||
return mesh
|
||||
|
||||
class DiagonalGaussianDistribution:
|
||||
def __init__(self, params: torch.Tensor, feature_dim: int = -1):
|
||||
|
||||
# divide quant channels (8) into mean and log variance
|
||||
self.mean, self.logvar = torch.chunk(params, 2, dim = feature_dim)
|
||||
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
|
||||
def sample(self):
|
||||
|
||||
eps = torch.randn_like(self.std)
|
||||
z = self.mean + eps * self.std
|
||||
|
||||
return z
|
||||
|
||||
################################################
|
||||
# Volume Decoder
|
||||
################################################
|
||||
|
||||
class VanillaVolumeDecoder():
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
latents: torch.FloatTensor,
|
||||
geo_decoder: Callable,
|
||||
bounds: Union[Tuple[float], List[float], float] = 1.01,
|
||||
num_chunks: int = 10000,
|
||||
octree_resolution: int = None,
|
||||
enable_pbar: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
device = latents.device
|
||||
dtype = latents.dtype
|
||||
batch_size = latents.shape[0]
|
||||
def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds = 1.01,
|
||||
num_chunks: int = 10_000, enable_pbar: bool = True, **kwargs):
|
||||
|
||||
# 1. generate query points
|
||||
if isinstance(bounds, float):
|
||||
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
|
||||
|
||||
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
|
||||
xyz_samples, grid_size, length = generate_dense_grid_points(
|
||||
bbox_min=bbox_min,
|
||||
bbox_max=bbox_max,
|
||||
octree_resolution=octree_resolution,
|
||||
indexing="ij"
|
||||
)
|
||||
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
|
||||
bbox_min, bbox_max = torch.tensor(bounds[:3]), torch.tensor(bounds[3:])
|
||||
|
||||
x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype = torch.float32)
|
||||
y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype = torch.float32)
|
||||
z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype = torch.float32)
|
||||
|
||||
[xs, ys, zs] = torch.meshgrid(x, y, z, indexing = "ij")
|
||||
xyz = torch.stack((xs, ys, zs), axis=-1).to(latents.device, dtype = latents.dtype).contiguous().reshape(-1, 3)
|
||||
grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]
|
||||
|
||||
# 2. latents to 3d volume
|
||||
batch_logits = []
|
||||
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
|
||||
for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding",
|
||||
disable=not enable_pbar):
|
||||
chunk_queries = xyz_samples[start: start + num_chunks, :]
|
||||
chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
|
||||
logits = geo_decoder(queries=chunk_queries, latents=latents)
|
||||
|
||||
chunk_queries = xyz[start: start + num_chunks, :]
|
||||
chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1)
|
||||
logits = geo_decoder(queries = chunk_queries, latents = latents)
|
||||
batch_logits.append(logits)
|
||||
|
||||
grid_logits = torch.cat(batch_logits, dim=1)
|
||||
grid_logits = grid_logits.view((batch_size, *grid_size)).float()
|
||||
grid_logits = torch.cat(batch_logits, dim = 1)
|
||||
grid_logits = grid_logits.view((latents.shape[0], *grid_size)).float()
|
||||
|
||||
return grid_logits
|
||||
|
||||
|
||||
class FourierEmbedder(nn.Module):
|
||||
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
|
||||
each feature dimension of `x[..., i]` into:
|
||||
@ -175,13 +552,11 @@ class FourierEmbedder(nn.Module):
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class CrossAttentionProcessor:
|
||||
def __call__(self, attn, q, k, v):
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v)
|
||||
return out
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
@ -232,38 +607,41 @@ class MLP(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
|
||||
|
||||
|
||||
class QKVMultiheadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
heads: int,
|
||||
n_data = None,
|
||||
width=None,
|
||||
qk_norm=False,
|
||||
norm_layer=ops.LayerNorm
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.n_data = n_data
|
||||
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
|
||||
self.attn_processor = CrossAttentionProcessor()
|
||||
|
||||
def forward(self, q, kv):
|
||||
|
||||
_, n_ctx, _ = q.shape
|
||||
bs, n_data, width = kv.shape
|
||||
|
||||
attn_ch = width // self.heads // 2
|
||||
q = q.view(bs, n_ctx, self.heads, -1)
|
||||
|
||||
kv = kv.view(bs, n_data, self.heads, -1)
|
||||
k, v = torch.split(kv, attn_ch, dim=-1)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
||||
out = self.attn_processor(self, q, k, v)
|
||||
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
return out
|
||||
|
||||
q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
|
||||
return out
|
||||
|
||||
class MultiheadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
@ -306,7 +684,6 @@ class MultiheadCrossAttention(nn.Module):
|
||||
x = self.c_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualCrossAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -366,7 +743,7 @@ class QKVMultiheadAttention(nn.Module):
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
||||
q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
|
||||
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
return out
|
||||
|
||||
@ -383,8 +760,7 @@ class MultiheadAttention(nn.Module):
|
||||
drop_path_rate: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.heads = heads
|
||||
|
||||
self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
|
||||
self.c_proj = ops.Linear(width, width)
|
||||
self.attention = QKVMultiheadAttention(
|
||||
@ -491,7 +867,7 @@ class CrossAttentionDecoder(nn.Module):
|
||||
self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
|
||||
if self.downsample_ratio != 1:
|
||||
self.latents_proj = ops.Linear(width * downsample_ratio, width)
|
||||
if self.enable_ln_post == False:
|
||||
if not self.enable_ln_post:
|
||||
qk_norm = False
|
||||
self.cross_attn_decoder = ResidualCrossAttentionBlock(
|
||||
width=width,
|
||||
@ -522,28 +898,44 @@ class CrossAttentionDecoder(nn.Module):
|
||||
|
||||
class ShapeVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
embed_dim: int,
|
||||
width: int,
|
||||
heads: int,
|
||||
num_decoder_layers: int,
|
||||
geo_decoder_downsample_ratio: int = 1,
|
||||
geo_decoder_mlp_expand_ratio: int = 4,
|
||||
geo_decoder_ln_post: bool = True,
|
||||
num_freqs: int = 8,
|
||||
include_pi: bool = True,
|
||||
qkv_bias: bool = True,
|
||||
qk_norm: bool = False,
|
||||
label_type: str = "binary",
|
||||
drop_path_rate: float = 0.0,
|
||||
scale_factor: float = 1.0,
|
||||
self,
|
||||
*,
|
||||
num_latents: int = 4096,
|
||||
embed_dim: int = 64,
|
||||
width: int = 1024,
|
||||
heads: int = 16,
|
||||
num_decoder_layers: int = 16,
|
||||
num_encoder_layers: int = 8,
|
||||
pc_size: int = 81920,
|
||||
pc_sharpedge_size: int = 0,
|
||||
point_feats: int = 4,
|
||||
downsample_ratio: int = 20,
|
||||
geo_decoder_downsample_ratio: int = 1,
|
||||
geo_decoder_mlp_expand_ratio: int = 4,
|
||||
geo_decoder_ln_post: bool = True,
|
||||
num_freqs: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = True,
|
||||
drop_path_rate: float = 0.0,
|
||||
include_pi: bool = False,
|
||||
scale_factor: float = 1.0039506158752403,
|
||||
label_type: str = "binary",
|
||||
):
|
||||
super().__init__()
|
||||
self.geo_decoder_ln_post = geo_decoder_ln_post
|
||||
|
||||
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
|
||||
|
||||
self.encoder = PointCrossAttention(layers = num_encoder_layers,
|
||||
num_latents = num_latents,
|
||||
downsample_ratio = downsample_ratio,
|
||||
heads = heads,
|
||||
pc_size = pc_size,
|
||||
width = width,
|
||||
point_feats = point_feats,
|
||||
fourier_embedder = self.fourier_embedder,
|
||||
pc_sharpedge_size = pc_sharpedge_size)
|
||||
|
||||
self.post_kl = ops.Linear(embed_dim, width)
|
||||
|
||||
self.transformer = Transformer(
|
||||
@ -583,5 +975,14 @@ class ShapeVAE(nn.Module):
|
||||
grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
|
||||
return grid_logits.movedim(-2, -1)
|
||||
|
||||
def encode(self, x):
|
||||
return None
|
||||
def encode(self, surface):
|
||||
|
||||
pc, feats = surface[:, :, :3], surface[:, :, 3:]
|
||||
latents = self.encoder(pc, feats)
|
||||
|
||||
moments = self.pre_kl(latents)
|
||||
posterior = DiagonalGaussianDistribution(moments, feature_dim = -1)
|
||||
|
||||
latents = posterior.sample()
|
||||
|
||||
return latents
|
||||
|
||||
659
comfy/ldm/hunyuan3dv2_1/hunyuandit.py
Normal file
659
comfy/ldm/hunyuan3dv2_1/hunyuandit.py
Normal file
@ -0,0 +1,659 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
|
||||
class GELU(nn.Module):
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int, operations, device, dtype):
|
||||
super().__init__()
|
||||
self.proj = operations.Linear(dim_in, dim_out, device = device, dtype = dtype)
|
||||
|
||||
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
if gate.device.type == "mps":
|
||||
return F.gelu(gate.to(dtype = torch.float32)).to(dtype = gate.dtype)
|
||||
|
||||
return F.gelu(gate)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = self.gelu(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
|
||||
def __init__(self, dim: int, dim_out = None, mult: int = 4,
|
||||
dropout: float = 0.0, inner_dim = None, operations = None, device = None, dtype = None):
|
||||
|
||||
super().__init__()
|
||||
if inner_dim is None:
|
||||
inner_dim = int(dim * mult)
|
||||
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
|
||||
act_fn = GELU(dim, inner_dim, operations = operations, device = device, dtype = dtype)
|
||||
|
||||
self.net = nn.ModuleList([])
|
||||
self.net.append(act_fn)
|
||||
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
self.net.append(operations.Linear(inner_dim, dim_out, device = device, dtype = dtype))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
for module in self.net:
|
||||
hidden_states = module(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
class AddAuxLoss(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, loss):
|
||||
# do nothing in forward (no computation)
|
||||
ctx.requires_aux_loss = loss.requires_grad
|
||||
ctx.dtype = loss.dtype
|
||||
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
# add the aux loss gradients
|
||||
grad_loss = None
|
||||
# put the aux grad the same as the main grad loss
|
||||
# aux grad contributes equally
|
||||
if ctx.requires_aux_loss:
|
||||
grad_loss = torch.ones(1, dtype = ctx.dtype, device = grad_output.device)
|
||||
|
||||
return grad_output, grad_loss
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
|
||||
def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01, device = None, dtype = None):
|
||||
|
||||
super().__init__()
|
||||
self.top_k = num_experts_per_tok
|
||||
self.n_routed_experts = num_experts
|
||||
|
||||
self.alpha = aux_loss_alpha
|
||||
|
||||
self.gating_dim = embed_dim
|
||||
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), device = device, dtype = dtype))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
# flatten hidden states
|
||||
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
|
||||
|
||||
# get logits and pass it to softmax
|
||||
logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), bias = None)
|
||||
scores = logits.softmax(dim = -1)
|
||||
|
||||
topk_weight, topk_idx = torch.topk(scores, k = self.top_k, dim = -1, sorted = False)
|
||||
|
||||
if self.training and self.alpha > 0.0:
|
||||
scores_for_aux = scores
|
||||
|
||||
# used bincount instead of one hot encoding
|
||||
counts = torch.bincount(topk_idx.view(-1), minlength = self.n_routed_experts).float()
|
||||
ce = counts / topk_idx.numel() # normalized expert usage
|
||||
|
||||
# mean expert score
|
||||
Pi = scores_for_aux.mean(0)
|
||||
|
||||
# expert balance loss
|
||||
aux_loss = (Pi * ce * self.n_routed_experts).sum() * self.alpha
|
||||
else:
|
||||
aux_loss = None
|
||||
|
||||
return topk_idx, topk_weight, aux_loss
|
||||
|
||||
class MoEBlock(nn.Module):
|
||||
def __init__(self, dim, num_experts: int = 6, moe_top_k: int = 2, dropout: float = 0.0,
|
||||
ff_inner_dim: int = None, operations = None, device = None, dtype = None):
|
||||
super().__init__()
|
||||
|
||||
self.moe_top_k = moe_top_k
|
||||
self.num_experts = num_experts
|
||||
|
||||
self.experts = nn.ModuleList([
|
||||
FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
|
||||
for _ in range(num_experts)
|
||||
])
|
||||
|
||||
self.gate = MoEGate(dim, num_experts = num_experts, num_experts_per_tok = moe_top_k, device = device, dtype = dtype)
|
||||
self.shared_experts = FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, hidden_states) -> torch.Tensor:
|
||||
|
||||
identity = hidden_states
|
||||
orig_shape = hidden_states.shape
|
||||
topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
flat_topk_idx = topk_idx.view(-1)
|
||||
|
||||
if self.training:
|
||||
|
||||
hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim = 0)
|
||||
y = torch.empty_like(hidden_states, dtype = hidden_states.dtype)
|
||||
|
||||
for i, expert in enumerate(self.experts):
|
||||
tmp = expert(hidden_states[flat_topk_idx == i])
|
||||
y[flat_topk_idx == i] = tmp.to(hidden_states.dtype)
|
||||
|
||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim = 1)
|
||||
y = y.view(*orig_shape)
|
||||
|
||||
y = AddAuxLoss.apply(y, aux_loss)
|
||||
else:
|
||||
y = self.moe_infer(hidden_states, flat_expert_indices = flat_topk_idx,flat_expert_weights = topk_weight.view(-1, 1)).view(*orig_shape)
|
||||
|
||||
y = y + self.shared_experts(identity)
|
||||
|
||||
return y
|
||||
|
||||
@torch.no_grad()
|
||||
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
|
||||
|
||||
expert_cache = torch.zeros_like(x)
|
||||
idxs = flat_expert_indices.argsort()
|
||||
|
||||
# no need for .numpy().cpu() here
|
||||
tokens_per_expert = flat_expert_indices.bincount().cumsum(0)
|
||||
token_idxs = idxs // self.moe_top_k
|
||||
|
||||
for i, end_idx in enumerate(tokens_per_expert):
|
||||
|
||||
start_idx = 0 if i == 0 else tokens_per_expert[i-1]
|
||||
|
||||
if start_idx == end_idx:
|
||||
continue
|
||||
|
||||
expert = self.experts[i]
|
||||
exp_token_idx = token_idxs[start_idx:end_idx]
|
||||
|
||||
expert_tokens = x[exp_token_idx]
|
||||
expert_out = expert(expert_tokens)
|
||||
|
||||
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
|
||||
|
||||
# use index_add_ with a 1-D index tensor directly avoids building a large [N, D] index map and extra memcopy required by scatter_reduce_
|
||||
# + avoid dtype conversion
|
||||
expert_cache.index_add_(0, exp_token_idx, expert_out)
|
||||
|
||||
return expert_cache
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels: int, downscale_freq_shift: float = 0.0,
|
||||
scale: float = 1.0, max_period: int = 10000):
|
||||
super().__init__()
|
||||
|
||||
self.num_channels = num_channels
|
||||
half_dim = num_channels // 2
|
||||
|
||||
# precompute the “inv_freq” vector once
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
half_dim, dtype=torch.float32
|
||||
) / (half_dim - downscale_freq_shift)
|
||||
|
||||
inv_freq = torch.exp(exponent)
|
||||
|
||||
# pad
|
||||
if num_channels % 2 == 1:
|
||||
# we’ll pad a zero at the end of the cos-half
|
||||
inv_freq = torch.cat([inv_freq, inv_freq.new_zeros(1)])
|
||||
|
||||
# register to buffer so it moves with the device
|
||||
self.register_buffer("inv_freq", inv_freq, persistent = False)
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, timesteps: torch.Tensor):
|
||||
|
||||
x = timesteps.float().unsqueeze(1) * self.inv_freq.to(timesteps.device).unsqueeze(0)
|
||||
|
||||
|
||||
# fused CUDA kernels for sin and cos
|
||||
sin_emb = x.sin()
|
||||
cos_emb = x.cos()
|
||||
|
||||
emb = torch.cat([sin_emb, cos_emb], dim = 1)
|
||||
|
||||
# scale factor
|
||||
if self.scale != 1.0:
|
||||
emb = emb * self.scale
|
||||
|
||||
# If we padded inv_freq for odd, emb is already wide enough; otherwise:
|
||||
if emb.shape[1] > self.num_channels:
|
||||
emb = emb[:, :self.num_channels]
|
||||
|
||||
return emb
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
def __init__(self, hidden_size, frequency_embedding_size = 256, cond_proj_dim = None, operations = None, device = None, dtype = None):
|
||||
super().__init__()
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, frequency_embedding_size, bias=True, device = device, dtype = dtype),
|
||||
nn.GELU(),
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=True, device = device, dtype = dtype),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
if cond_proj_dim is not None:
|
||||
self.cond_proj = operations.Linear(cond_proj_dim, frequency_embedding_size, bias=False, device = device, dtype = dtype)
|
||||
|
||||
self.time_embed = Timesteps(hidden_size)
|
||||
|
||||
def forward(self, timesteps, condition):
|
||||
|
||||
timestep_embed = self.time_embed(timesteps).type(self.mlp[0].weight.dtype)
|
||||
|
||||
if condition is not None:
|
||||
cond_embed = self.cond_proj(condition)
|
||||
timestep_embed = timestep_embed + cond_embed
|
||||
|
||||
time_conditioned = self.mlp(timestep_embed)
|
||||
|
||||
# for broadcasting with image tokens
|
||||
return time_conditioned.unsqueeze(1)
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, *, width: int, operations = None, device = None, dtype = None):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.fc1 = operations.Linear(width, width * 4, device = device, dtype = dtype)
|
||||
self.fc2 = operations.Linear(width * 4, width, device = device, dtype = dtype)
|
||||
self.gelu = nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc2(self.gelu(self.fc1(x)))
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
qdim,
|
||||
kdim,
|
||||
num_heads,
|
||||
qkv_bias=True,
|
||||
qk_norm=False,
|
||||
norm_layer=nn.LayerNorm,
|
||||
use_fp16: bool = False,
|
||||
operations = None,
|
||||
dtype = None,
|
||||
device = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.qdim = qdim
|
||||
self.kdim = kdim
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = self.qdim // num_heads
|
||||
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.to_q = operations.Linear(qdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
|
||||
self.to_k = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
|
||||
self.to_v = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
|
||||
|
||||
if use_fp16:
|
||||
eps = 1.0 / 65504
|
||||
else:
|
||||
eps = 1e-6
|
||||
|
||||
if norm_layer == nn.LayerNorm:
|
||||
norm_layer = operations.LayerNorm
|
||||
else:
|
||||
norm_layer = operations.RMSNorm
|
||||
|
||||
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||
self.out_proj = operations.Linear(qdim, qdim, bias=True, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, x, y):
|
||||
|
||||
b, s1, _ = x.shape
|
||||
_, s2, _ = y.shape
|
||||
|
||||
y = y.to(next(self.to_k.parameters()).dtype)
|
||||
|
||||
q = self.to_q(x)
|
||||
k = self.to_k(y)
|
||||
v = self.to_v(y)
|
||||
|
||||
kv = torch.cat((k, v), dim=-1)
|
||||
split_size = kv.shape[-1] // self.num_heads // 2
|
||||
|
||||
kv = kv.view(1, -1, self.num_heads, split_size * 2)
|
||||
k, v = torch.split(kv, split_size, dim=-1)
|
||||
|
||||
q = q.view(b, s1, self.num_heads, self.head_dim)
|
||||
k = k.view(b, s2, self.num_heads, self.head_dim)
|
||||
v = v.reshape(b, s2, self.num_heads * self.head_dim)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
x = optimized_attention(
|
||||
q.reshape(b, s1, self.num_heads * self.head_dim),
|
||||
k.reshape(b, s2, self.num_heads * self.head_dim),
|
||||
v,
|
||||
heads=self.num_heads,
|
||||
)
|
||||
|
||||
out = self.out_proj(x)
|
||||
|
||||
return out
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
qkv_bias = True,
|
||||
qk_norm = False,
|
||||
norm_layer = nn.LayerNorm,
|
||||
use_fp16: bool = False,
|
||||
operations = None,
|
||||
device = None,
|
||||
dtype = None
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = self.dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.to_q = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
|
||||
self.to_k = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
|
||||
self.to_v = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
|
||||
|
||||
if use_fp16:
|
||||
eps = 1.0 / 65504
|
||||
else:
|
||||
eps = 1e-6
|
||||
|
||||
if norm_layer == nn.LayerNorm:
|
||||
norm_layer = operations.LayerNorm
|
||||
else:
|
||||
norm_layer = operations.RMSNorm
|
||||
|
||||
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||
self.out_proj = operations.Linear(dim, dim, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, _ = x.shape
|
||||
|
||||
query = self.to_q(x)
|
||||
key = self.to_k(x)
|
||||
value = self.to_v(x)
|
||||
|
||||
qkv_combined = torch.cat((query, key, value), dim=-1)
|
||||
split_size = qkv_combined.shape[-1] // self.num_heads // 3
|
||||
|
||||
qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3)
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
|
||||
query = query.reshape(B, N, self.num_heads, self.head_dim)
|
||||
key = key.reshape(B, N, self.num_heads, self.head_dim)
|
||||
value = value.reshape(B, N, self.num_heads * self.head_dim)
|
||||
|
||||
query = self.q_norm(query)
|
||||
key = self.k_norm(key)
|
||||
|
||||
x = optimized_attention(
|
||||
query.reshape(B, N, self.num_heads * self.head_dim),
|
||||
key.reshape(B, N, self.num_heads * self.head_dim),
|
||||
value,
|
||||
heads=self.num_heads,
|
||||
)
|
||||
|
||||
x = self.out_proj(x)
|
||||
return x
|
||||
|
||||
class HunYuanDiTBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
c_emb_size,
|
||||
num_heads,
|
||||
text_states_dim=1024,
|
||||
qk_norm=False,
|
||||
norm_layer=nn.LayerNorm,
|
||||
qk_norm_layer=True,
|
||||
qkv_bias=True,
|
||||
skip_connection=True,
|
||||
timested_modulate=False,
|
||||
use_moe: bool = False,
|
||||
num_experts: int = 8,
|
||||
moe_top_k: int = 2,
|
||||
use_fp16: bool = False,
|
||||
operations = None,
|
||||
device = None, dtype = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# eps can't be 1e-6 in fp16 mode because of numerical stability issues
|
||||
if use_fp16:
|
||||
eps = 1.0 / 65504
|
||||
else:
|
||||
eps = 1e-6
|
||||
|
||||
self.norm1 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
|
||||
self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
|
||||
norm_layer=qk_norm_layer, use_fp16 = use_fp16, device = device, dtype = dtype, operations = operations)
|
||||
|
||||
self.norm2 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
|
||||
self.timested_modulate = timested_modulate
|
||||
if self.timested_modulate:
|
||||
self.default_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(c_emb_size, hidden_size, bias=True, device = device, dtype = dtype)
|
||||
)
|
||||
|
||||
self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm, norm_layer=qk_norm_layer, use_fp16 = use_fp16,
|
||||
device = device, dtype = dtype, operations = operations)
|
||||
|
||||
self.norm3 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
|
||||
if skip_connection:
|
||||
self.skip_norm = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, device = device, dtype = dtype)
|
||||
else:
|
||||
self.skip_linear = None
|
||||
|
||||
self.use_moe = use_moe
|
||||
|
||||
if self.use_moe:
|
||||
self.moe = MoEBlock(
|
||||
hidden_size,
|
||||
num_experts = num_experts,
|
||||
moe_top_k = moe_top_k,
|
||||
dropout = 0.0,
|
||||
ff_inner_dim = int(hidden_size * 4.0),
|
||||
device = device, dtype = dtype,
|
||||
operations = operations
|
||||
)
|
||||
else:
|
||||
self.mlp = MLP(width=hidden_size, operations=operations, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, hidden_states, conditioning=None, text_states=None, skip_tensor=None):
|
||||
|
||||
if self.skip_linear is not None:
|
||||
combined = torch.cat([skip_tensor, hidden_states], dim=-1)
|
||||
hidden_states = self.skip_linear(combined)
|
||||
hidden_states = self.skip_norm(hidden_states)
|
||||
|
||||
# self attention
|
||||
if self.timested_modulate:
|
||||
modulation_shift = self.default_modulation(conditioning).unsqueeze(dim=1)
|
||||
hidden_states = hidden_states + modulation_shift
|
||||
|
||||
self_attn_out = self.attn1(self.norm1(hidden_states))
|
||||
hidden_states = hidden_states + self_attn_out
|
||||
|
||||
# cross attention
|
||||
hidden_states = hidden_states + self.attn2(self.norm2(hidden_states), text_states)
|
||||
|
||||
# MLP Layer
|
||||
mlp_input = self.norm3(hidden_states)
|
||||
|
||||
if self.use_moe:
|
||||
hidden_states = hidden_states + self.moe(mlp_input)
|
||||
else:
|
||||
hidden_states = hidden_states + self.mlp(mlp_input)
|
||||
|
||||
return hidden_states
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
|
||||
def __init__(self, final_hidden_size, out_channels, operations, use_fp16: bool = False, device = None, dtype = None):
|
||||
super().__init__()
|
||||
|
||||
if use_fp16:
|
||||
eps = 1.0 / 65504
|
||||
else:
|
||||
eps = 1e-6
|
||||
|
||||
self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
self.linear = operations.Linear(final_hidden_size, out_channels, bias = True, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm_final(x)
|
||||
x = x[:, 1:]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class HunYuanDiTPlain(nn.Module):
|
||||
|
||||
# init with the defaults values from https://huggingface.co/tencent/Hunyuan3D-2.1/blob/main/hunyuan3d-dit-v2-1/config.yaml
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 64,
|
||||
hidden_size: int = 2048,
|
||||
context_dim: int = 1024,
|
||||
depth: int = 21,
|
||||
num_heads: int = 16,
|
||||
qk_norm: bool = True,
|
||||
qkv_bias: bool = False,
|
||||
num_moe_layers: int = 6,
|
||||
guidance_cond_proj_dim = 2048,
|
||||
norm_type = 'layer',
|
||||
num_experts: int = 8,
|
||||
moe_top_k: int = 2,
|
||||
use_fp16: bool = False,
|
||||
dtype = None,
|
||||
device = None,
|
||||
operations = None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
self.dtype = dtype
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.depth = depth
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
norm = operations.LayerNorm if norm_type == 'layer' else operations.RMSNorm
|
||||
qk_norm = operations.RMSNorm
|
||||
|
||||
self.context_dim = context_dim
|
||||
self.guidance_cond_proj_dim = guidance_cond_proj_dim
|
||||
|
||||
self.x_embedder = operations.Linear(in_channels, hidden_size, bias = True, device = device, dtype = dtype)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim = guidance_cond_proj_dim, device = device, dtype = dtype, operations = operations)
|
||||
|
||||
|
||||
# HUnYuanDiT Blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
HunYuanDiTBlock(hidden_size=hidden_size,
|
||||
c_emb_size=hidden_size,
|
||||
num_heads=num_heads,
|
||||
text_states_dim=context_dim,
|
||||
qk_norm=qk_norm,
|
||||
norm_layer = norm,
|
||||
qk_norm_layer = qk_norm,
|
||||
skip_connection=layer > depth // 2,
|
||||
qkv_bias=qkv_bias,
|
||||
use_moe=True if depth - layer <= num_moe_layers else False,
|
||||
num_experts=num_experts,
|
||||
moe_top_k=moe_top_k,
|
||||
use_fp16 = use_fp16,
|
||||
device = device, dtype = dtype, operations = operations)
|
||||
for layer in range(depth)
|
||||
])
|
||||
|
||||
self.depth = depth
|
||||
|
||||
self.final_layer = FinalLayer(hidden_size, self.out_channels, use_fp16 = use_fp16, operations = operations, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, x, t, context, transformer_options = {}, **kwargs):
|
||||
|
||||
x = x.movedim(-1, -2)
|
||||
uncond_emb, cond_emb = context.chunk(2, dim = 0)
|
||||
|
||||
context = torch.cat([cond_emb, uncond_emb], dim = 0)
|
||||
main_condition = context
|
||||
|
||||
t = 1.0 - t
|
||||
|
||||
time_embedded = self.t_embedder(t, condition = kwargs.get('guidance_cond'))
|
||||
|
||||
x = x.to(dtype = next(self.x_embedder.parameters()).dtype)
|
||||
x_embedded = self.x_embedder(x)
|
||||
|
||||
combined = torch.cat([time_embedded, x_embedded], dim=1)
|
||||
|
||||
def block_wrap(args):
|
||||
return block(
|
||||
args["x"],
|
||||
args["t"],
|
||||
args["cond"],
|
||||
skip_tensor=args.get("skip"),)
|
||||
|
||||
skip_stack = []
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for idx, block in enumerate(self.blocks):
|
||||
if idx <= self.depth // 2:
|
||||
skip_input = None
|
||||
else:
|
||||
skip_input = skip_stack.pop()
|
||||
|
||||
if ("block", idx) in blocks_replace:
|
||||
|
||||
combined = blocks_replace[("block", idx)](
|
||||
{
|
||||
"x": combined,
|
||||
"t": time_embedded,
|
||||
"cond": main_condition,
|
||||
"skip": skip_input,
|
||||
},
|
||||
{"original_block": block_wrap},
|
||||
)
|
||||
else:
|
||||
combined = block(combined, time_embedded, main_condition, skip_tensor=skip_input)
|
||||
|
||||
if idx < self.depth // 2:
|
||||
skip_stack.append(combined)
|
||||
|
||||
output = self.final_layer(combined)
|
||||
output = output.movedim(-2, -1) * (-1.0)
|
||||
|
||||
cond_emb, uncond_emb = output.chunk(2, dim = 0)
|
||||
return torch.cat([uncond_emb, cond_emb])
|
||||
@ -1,11 +1,11 @@
|
||||
#Based on Flux code because of weird hunyuan video code license.
|
||||
|
||||
import torch
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.flux.layers
|
||||
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from einops import repeat
|
||||
|
||||
@ -39,6 +39,11 @@ class HunyuanVideoParams:
|
||||
patch_size: list
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
byt5: bool
|
||||
meanflow: bool
|
||||
use_cond_type_embedding: bool
|
||||
vision_in_dim: int
|
||||
meanflow_sum: bool
|
||||
|
||||
|
||||
class SelfAttentionRef(nn.Module):
|
||||
@ -77,13 +82,13 @@ class TokenRefinerBlock(nn.Module):
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x, c, mask):
|
||||
def forward(self, x, c, mask, transformer_options={}):
|
||||
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
|
||||
norm_x = self.norm1(x)
|
||||
qkv = self.self_attn.qkv(norm_x)
|
||||
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
|
||||
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
|
||||
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
|
||||
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
|
||||
@ -114,14 +119,14 @@ class IndividualTokenRefiner(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x, c, mask):
|
||||
def forward(self, x, c, mask, transformer_options={}):
|
||||
m = None
|
||||
if mask is not None:
|
||||
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
|
||||
m = m + m.transpose(2, 3)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, c, m)
|
||||
x = block(x, c, m, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
@ -149,17 +154,45 @@ class TokenRefiner(nn.Module):
|
||||
x,
|
||||
timesteps,
|
||||
mask,
|
||||
transformer_options={},
|
||||
):
|
||||
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
|
||||
# m = mask.float().unsqueeze(-1)
|
||||
# c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
|
||||
c = x.sum(dim=1) / x.shape[1]
|
||||
if x.dtype == torch.float16:
|
||||
c = x.float().sum(dim=1) / x.shape[1]
|
||||
else:
|
||||
c = x.sum(dim=1) / x.shape[1]
|
||||
|
||||
c = t + self.c_embedder(c.to(x.dtype))
|
||||
x = self.input_embedder(x)
|
||||
x = self.individual_token_refiner(x, c, mask)
|
||||
x = self.individual_token_refiner(x, c, mask, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
class ByT5Mapper(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_res=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.layernorm = operations.LayerNorm(in_dim, dtype=dtype, device=device)
|
||||
self.fc1 = operations.Linear(in_dim, hidden_dim, dtype=dtype, device=device)
|
||||
self.fc2 = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
|
||||
self.fc3 = operations.Linear(out_dim, out_dim1, dtype=dtype, device=device)
|
||||
self.use_res = use_res
|
||||
self.act_fn = nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res:
|
||||
res = x
|
||||
x = self.layernorm(x)
|
||||
x = self.fc1(x)
|
||||
x = self.act_fn(x)
|
||||
x = self.fc2(x)
|
||||
x2 = self.act_fn(x)
|
||||
x2 = self.fc3(x2)
|
||||
if self.use_res:
|
||||
x2 = x2 + res
|
||||
return x2
|
||||
|
||||
class HunyuanVideo(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
@ -168,11 +201,15 @@ class HunyuanVideo(nn.Module):
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
params = HunyuanVideoParams(**kwargs)
|
||||
self.params = params
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = params.out_channels
|
||||
self.use_cond_type_embedding = params.use_cond_type_embedding
|
||||
self.vision_in_dim = params.vision_in_dim
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
@ -184,9 +221,13 @@ class HunyuanVideo(nn.Module):
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
|
||||
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
|
||||
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=len(self.patch_size) == 3, dtype=dtype, device=device, operations=operations)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
if params.vec_in_dim is not None:
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.vector_in = None
|
||||
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
|
||||
)
|
||||
@ -214,9 +255,38 @@ class HunyuanVideo(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
if params.byt5:
|
||||
self.byt5_in = ByT5Mapper(
|
||||
in_dim=1472,
|
||||
out_dim=2048,
|
||||
hidden_dim=2048,
|
||||
out_dim1=self.hidden_size,
|
||||
use_res=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
else:
|
||||
self.byt5_in = None
|
||||
|
||||
if params.meanflow:
|
||||
self.time_r_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.time_r_in = None
|
||||
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
# HunyuanVideo 1.5 specific modules
|
||||
if self.vision_in_dim is not None:
|
||||
from comfy.ldm.wan.model import MLPProj
|
||||
self.vision_in = MLPProj(in_dim=self.vision_in_dim, out_dim=self.hidden_size, operation_settings=operation_settings)
|
||||
else:
|
||||
self.vision_in = None
|
||||
if self.use_cond_type_embedding:
|
||||
# 0: text_encoder feature 1: byt5 feature 2: vision_encoder feature
|
||||
self.cond_type_embedding = nn.Embedding(3, self.hidden_size)
|
||||
else:
|
||||
self.cond_type_embedding = None
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
img: Tensor,
|
||||
@ -225,10 +295,13 @@ class HunyuanVideo(nn.Module):
|
||||
txt_ids: Tensor,
|
||||
txt_mask: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
y: Tensor = None,
|
||||
txt_byt5=None,
|
||||
clip_fea=None,
|
||||
guidance: Tensor = None,
|
||||
guiding_frame_index=None,
|
||||
ref_latent=None,
|
||||
disable_time_r=False,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
) -> Tensor:
|
||||
@ -239,6 +312,14 @@ class HunyuanVideo(nn.Module):
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
|
||||
|
||||
if (self.time_r_in is not None) and (not disable_time_r):
|
||||
w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved
|
||||
if len(w) > 0:
|
||||
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
|
||||
timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
|
||||
vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
|
||||
vec = (vec + vec_r) if self.params.meanflow_sum else (vec + vec_r) / 2
|
||||
|
||||
if ref_latent is not None:
|
||||
ref_latent_ids = self.img_ids(ref_latent)
|
||||
ref_latent = self.img_in(ref_latent)
|
||||
@ -249,13 +330,17 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
if guiding_frame_index is not None:
|
||||
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
|
||||
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
|
||||
if self.vector_in is not None:
|
||||
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
|
||||
else:
|
||||
vec = torch.cat([(token_replace_vec).unsqueeze(1), (vec).unsqueeze(1)], dim=1)
|
||||
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
|
||||
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
|
||||
modulation_dims_txt = [(0, None, 1)]
|
||||
else:
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
if self.vector_in is not None:
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
modulation_dims = None
|
||||
modulation_dims_txt = None
|
||||
|
||||
@ -266,7 +351,32 @@ class HunyuanVideo(nn.Module):
|
||||
if txt_mask is not None and not torch.is_floating_point(txt_mask):
|
||||
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
|
||||
|
||||
txt = self.txt_in(txt, timesteps, txt_mask)
|
||||
txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)
|
||||
|
||||
if self.cond_type_embedding is not None:
|
||||
self.cond_type_embedding.to(txt.device)
|
||||
cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long))
|
||||
txt = txt + cond_emb.to(txt.dtype)
|
||||
|
||||
if self.byt5_in is not None and txt_byt5 is not None:
|
||||
txt_byt5 = self.byt5_in(txt_byt5)
|
||||
if self.cond_type_embedding is not None:
|
||||
cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long))
|
||||
txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype)
|
||||
txt = torch.cat((txt_byt5, txt), dim=1) # byt5 first for HunyuanVideo1.5
|
||||
else:
|
||||
txt = torch.cat((txt, txt_byt5), dim=1)
|
||||
txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
|
||||
|
||||
if clip_fea is not None:
|
||||
txt_vision_states = self.vision_in(clip_fea)
|
||||
if self.cond_type_embedding is not None:
|
||||
cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device))
|
||||
txt_vision_states = txt_vision_states + cond_emb
|
||||
txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1)
|
||||
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
|
||||
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
@ -280,18 +390,21 @@ class HunyuanVideo(nn.Module):
|
||||
attn_mask = None
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt, 'transformer_options': transformer_options}, {"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt, transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@ -302,17 +415,20 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
img = torch.cat((img, txt), 1)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims, 'transformer_options': transformer_options}, {"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims, transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
@ -327,12 +443,16 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
shape = initial_shape[-3:]
|
||||
shape = initial_shape[-len(self.patch_size):]
|
||||
for i in range(len(shape)):
|
||||
shape[i] = shape[i] // self.patch_size[i]
|
||||
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
|
||||
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
|
||||
if img.ndim == 8:
|
||||
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
|
||||
else:
|
||||
img = img.permute(0, 3, 1, 4, 2, 5)
|
||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3])
|
||||
return img
|
||||
|
||||
def img_ids(self, x):
|
||||
@ -347,9 +467,30 @@ class HunyuanVideo(nn.Module):
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
||||
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
img_ids = self.img_ids(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
|
||||
def img_ids_2d(self, x):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
h_len = ((h + (patch_size[0] // 2)) // patch_size[0])
|
||||
w_len = ((w + (patch_size[1] // 2)) // patch_size[1])
|
||||
img_ids = torch.zeros((h_len, w_len, 2), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 0] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
return repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, txt_byt5, clip_fea, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
bs = x.shape[0]
|
||||
if len(self.patch_size) == 3:
|
||||
img_ids = self.img_ids(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
img_ids = self.img_ids_2d(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, clip_fea, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
|
||||
return out
|
||||
|
||||
122
comfy/ldm/hunyuan_video/upsampler.py
Normal file
122
comfy/ldm/hunyuan_video/upsampler.py
Normal file
@ -0,0 +1,122 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
|
||||
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
|
||||
import comfy.model_management
|
||||
import comfy.model_patcher
|
||||
|
||||
class SRResidualCausalBlock3D(nn.Module):
|
||||
def __init__(self, channels: int):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
VideoConv3d(channels, channels, kernel_size=3),
|
||||
nn.SiLU(inplace=True),
|
||||
VideoConv3d(channels, channels, kernel_size=3),
|
||||
nn.SiLU(inplace=True),
|
||||
VideoConv3d(channels, channels, kernel_size=3),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + self.block(x)
|
||||
|
||||
class SRModel3DV2(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
hidden_channels: int = 64,
|
||||
num_blocks: int = 6,
|
||||
global_residual: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_conv = VideoConv3d(in_channels, hidden_channels, kernel_size=3)
|
||||
self.blocks = nn.ModuleList([SRResidualCausalBlock3D(hidden_channels) for _ in range(num_blocks)])
|
||||
self.out_conv = VideoConv3d(hidden_channels, out_channels, kernel_size=3)
|
||||
self.global_residual = bool(global_residual)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
residual = x
|
||||
y = self.in_conv(x)
|
||||
for blk in self.blocks:
|
||||
y = blk(y)
|
||||
y = self.out_conv(y)
|
||||
if self.global_residual and (y.shape == residual.shape):
|
||||
y = y + residual
|
||||
return y
|
||||
|
||||
|
||||
class Upsampler(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
z_channels: int,
|
||||
out_channels: int,
|
||||
block_out_channels: tuple[int, ...],
|
||||
num_res_blocks: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.block_out_channels = block_out_channels
|
||||
self.z_channels = z_channels
|
||||
|
||||
ch = block_out_channels[0]
|
||||
self.conv_in = VideoConv3d(z_channels, ch, kernel_size=3)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_shortcut=False,
|
||||
conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
for j in range(num_res_blocks + 1)])
|
||||
ch = tgt
|
||||
self.up.append(stage)
|
||||
|
||||
self.norm_out = RMS_norm(ch)
|
||||
self.conv_out = VideoConv3d(ch, out_channels, kernel_size=3)
|
||||
|
||||
def forward(self, z):
|
||||
"""
|
||||
Args:
|
||||
z: (B, C, T, H, W)
|
||||
target_shape: (H, W)
|
||||
"""
|
||||
# z to block_in
|
||||
repeats = self.block_out_channels[0] // (self.z_channels)
|
||||
x = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)
|
||||
|
||||
# upsampling
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
|
||||
out = self.conv_out(F.silu(self.norm_out(x)))
|
||||
return out
|
||||
|
||||
UPSAMPLERS = {
|
||||
"720p": SRModel3DV2,
|
||||
"1080p": Upsampler,
|
||||
}
|
||||
|
||||
class HunyuanVideo15SRModel():
|
||||
def __init__(self, model_type, config):
|
||||
self.load_device = comfy.model_management.vae_device()
|
||||
offload_device = comfy.model_management.vae_offload_device()
|
||||
self.dtype = comfy.model_management.vae_dtype(self.load_device)
|
||||
self.model_class = UPSAMPLERS.get(model_type)
|
||||
self.model = self.model_class(**config).eval()
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=True)
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
def resample_latent(self, latent):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
return self.model(latent.to(self.load_device))
|
||||
136
comfy/ldm/hunyuan_video/vae.py
Normal file
136
comfy/ldm/hunyuan_video/vae.py
Normal file
@ -0,0 +1,136 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class PixelShuffle2D(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, op=ops.Conv2d):
|
||||
super().__init__()
|
||||
self.conv = op(in_dim, out_dim >> 2, 3, 1, 1)
|
||||
self.ratio = (in_dim << 2) // out_dim
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
h2, w2 = h >> 1, w >> 1
|
||||
y = self.conv(x).view(b, -1, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, -1, h2, w2)
|
||||
r = x.view(b, c, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, c << 2, h2, w2)
|
||||
return y + r.view(b, y.shape[1], self.ratio, h2, w2).mean(2)
|
||||
|
||||
|
||||
class PixelUnshuffle2D(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, op=ops.Conv2d):
|
||||
super().__init__()
|
||||
self.conv = op(in_dim, out_dim << 2, 3, 1, 1)
|
||||
self.scale = (out_dim << 2) // in_dim
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
h2, w2 = h << 1, w << 1
|
||||
y = self.conv(x).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
|
||||
r = x.repeat_interleave(self.scale, 1).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
|
||||
return y + r
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
|
||||
ffactor_spatial, downsample_match_channel=True, **_):
|
||||
super().__init__()
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.conv_in = ops.Conv2d(in_channels, block_out_channels[0], 3, 1, 1)
|
||||
|
||||
self.down = nn.ModuleList()
|
||||
ch = block_out_channels[0]
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=ops.Conv2d)
|
||||
for j in range(num_res_blocks)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
|
||||
stage.downsample = PixelShuffle2D(ch, nxt, ops.Conv2d)
|
||||
ch = nxt
|
||||
self.down.append(stage)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
|
||||
|
||||
self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
|
||||
self.conv_out = ops.Conv2d(ch, z_channels << 1, 3, 1, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
|
||||
for stage in self.down:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
if hasattr(stage, 'downsample'):
|
||||
x = stage.downsample(x)
|
||||
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
b, c, h, w = x.shape
|
||||
grp = c // (self.z_channels << 1)
|
||||
skip = x.view(b, c // grp, grp, h, w).mean(2)
|
||||
|
||||
return self.conv_out(F.silu(self.norm_out(x))) + skip
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
|
||||
ffactor_spatial, upsample_match_channel=True, **_):
|
||||
super().__init__()
|
||||
block_out_channels = block_out_channels[::-1]
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
|
||||
ch = block_out_channels[0]
|
||||
self.conv_in = ops.Conv2d(z_channels, ch, 3, 1, 1)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=ops.Conv2d)
|
||||
for j in range(num_res_blocks + 1)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
|
||||
stage.upsample = PixelUnshuffle2D(ch, nxt, ops.Conv2d)
|
||||
ch = nxt
|
||||
self.up.append(stage)
|
||||
|
||||
self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
|
||||
self.conv_out = ops.Conv2d(ch, out_channels, 3, 1, 1)
|
||||
|
||||
def forward(self, z):
|
||||
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
if hasattr(stage, 'upsample'):
|
||||
x = stage.upsample(x)
|
||||
|
||||
return self.conv_out(F.silu(self.norm_out(x)))
|
||||
313
comfy/ldm/hunyuan_video/vae_refiner.py
Normal file
313
comfy/ldm/hunyuan_video/vae_refiner.py
Normal file
@ -0,0 +1,313 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed
|
||||
import comfy.ops
|
||||
import comfy.ldm.models.autoencoder
|
||||
import comfy.model_management
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
shape = (dim, 1, 1, 1)
|
||||
self.scale = dim**0.5
|
||||
self.gamma = nn.Parameter(torch.empty(shape))
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)
|
||||
|
||||
class DnSmpl(nn.Module):
|
||||
def __init__(self, ic, oc, tds, refiner_vae, op):
|
||||
super().__init__()
|
||||
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
|
||||
assert oc % fct == 0
|
||||
self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1)
|
||||
self.refiner_vae = refiner_vae
|
||||
|
||||
self.tds = tds
|
||||
self.gs = fct * ic // oc
|
||||
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
r1 = 2 if self.tds else 1
|
||||
h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
|
||||
|
||||
if self.tds and self.refiner_vae and conv_carry_in is None:
|
||||
|
||||
hf = h[:, :, :1, :, :]
|
||||
b, c, f, ht, wd = hf.shape
|
||||
hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
|
||||
hf = hf.permute(0, 4, 6, 1, 2, 3, 5)
|
||||
hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
|
||||
hf = torch.cat([hf, hf], dim=1)
|
||||
|
||||
h = h[:, :, 1:, :, :]
|
||||
|
||||
xf = x[:, :, :1, :, :]
|
||||
b, ci, f, ht, wd = xf.shape
|
||||
xf = xf.reshape(b, ci, f, ht // 2, 2, wd // 2, 2)
|
||||
xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
|
||||
xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
|
||||
B, C, T, H, W = xf.shape
|
||||
xf = xf.view(B, hf.shape[1], self.gs // 2, T, H, W).mean(dim=2)
|
||||
|
||||
x = x[:, :, 1:, :, :]
|
||||
|
||||
if h.shape[2] == 0:
|
||||
return hf + xf
|
||||
|
||||
b, c, frms, ht, wd = h.shape
|
||||
nf = frms // r1
|
||||
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
|
||||
|
||||
b, ci, frms, ht, wd = x.shape
|
||||
nf = frms // r1
|
||||
x = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
x = x.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
x = x.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
|
||||
B, C, T, H, W = x.shape
|
||||
x = x.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
|
||||
|
||||
if self.tds and self.refiner_vae and conv_carry_in is None:
|
||||
h = torch.cat([hf, h], dim=2)
|
||||
x = torch.cat([xf, x], dim=2)
|
||||
|
||||
return h + x
|
||||
|
||||
|
||||
class UpSmpl(nn.Module):
|
||||
def __init__(self, ic, oc, tus, refiner_vae, op):
|
||||
super().__init__()
|
||||
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
|
||||
self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
|
||||
self.refiner_vae = refiner_vae
|
||||
|
||||
self.tus = tus
|
||||
self.rp = fct * oc // ic
|
||||
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
r1 = 2 if self.tus else 1
|
||||
h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
|
||||
|
||||
if self.tus and self.refiner_vae and conv_carry_in is None:
|
||||
hf = h[:, :, :1, :, :]
|
||||
b, c, f, ht, wd = hf.shape
|
||||
nc = c // (2 * 2)
|
||||
hf = hf.reshape(b, 2, 2, nc, f, ht, wd)
|
||||
hf = hf.permute(0, 3, 4, 5, 1, 6, 2)
|
||||
hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
|
||||
hf = hf[:, : hf.shape[1] // 2]
|
||||
|
||||
h = h[:, :, 1:, :, :]
|
||||
|
||||
xf = x[:, :, :1, :, :]
|
||||
b, ci, f, ht, wd = xf.shape
|
||||
xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1)
|
||||
b, c, f, ht, wd = xf.shape
|
||||
nc = c // (2 * 2)
|
||||
xf = xf.reshape(b, 2, 2, nc, f, ht, wd)
|
||||
xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
|
||||
xf = xf.reshape(b, nc, f, ht * 2, wd * 2)
|
||||
|
||||
x = x[:, :, 1:, :, :]
|
||||
|
||||
b, c, frms, ht, wd = h.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
x = x.repeat_interleave(repeats=self.rp, dim=1)
|
||||
b, c, frms, ht, wd = x.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
x = x.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
x = x.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
if self.tus and self.refiner_vae and conv_carry_in is None:
|
||||
h = torch.cat([hf, h], dim=2)
|
||||
x = torch.cat([xf, x], dim=2)
|
||||
|
||||
return h + x
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
|
||||
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
|
||||
super().__init__()
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.ffactor_temporal = ffactor_temporal
|
||||
|
||||
self.refiner_vae = refiner_vae
|
||||
if self.refiner_vae:
|
||||
conv_op = CarriedConv3d
|
||||
norm_op = RMS_norm
|
||||
else:
|
||||
conv_op = ops.Conv3d
|
||||
norm_op = Normalize
|
||||
|
||||
self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1)
|
||||
|
||||
self.down = nn.ModuleList()
|
||||
ch = block_out_channels[0]
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
depth_temporal = ((ffactor_spatial // self.ffactor_temporal) >> 1).bit_length()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=conv_op, norm_op=norm_op)
|
||||
for j in range(num_res_blocks)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
|
||||
stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
|
||||
ch = nxt
|
||||
self.down.append(stage)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
|
||||
self.norm_out = norm_op(ch)
|
||||
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
|
||||
|
||||
self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()
|
||||
|
||||
def forward(self, x):
|
||||
if not self.refiner_vae and x.shape[2] == 1:
|
||||
x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
|
||||
|
||||
if self.refiner_vae:
|
||||
xl = [x[:, :, :1, :, :]]
|
||||
if x.shape[2] > self.ffactor_temporal:
|
||||
xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // self.ffactor_temporal) * self.ffactor_temporal, :, :], self.ffactor_temporal * 2, dim=2)
|
||||
x = xl
|
||||
else:
|
||||
x = [x]
|
||||
out = []
|
||||
|
||||
conv_carry_in = None
|
||||
|
||||
for i, x1 in enumerate(x):
|
||||
conv_carry_out = []
|
||||
if i == len(x) - 1:
|
||||
conv_carry_out = None
|
||||
|
||||
x1 = [ x1 ]
|
||||
x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
|
||||
|
||||
for stage in self.down:
|
||||
for blk in stage.block:
|
||||
x1 = blk(x1, None, conv_carry_in, conv_carry_out)
|
||||
if hasattr(stage, 'downsample'):
|
||||
x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
|
||||
|
||||
out.append(x1)
|
||||
conv_carry_in = conv_carry_out
|
||||
|
||||
out = torch_cat_if_needed(out, dim=2)
|
||||
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
|
||||
del out
|
||||
|
||||
b, c, t, h, w = x.shape
|
||||
grp = c // (self.z_channels << 1)
|
||||
skip = x.view(b, c // grp, grp, t, h, w).mean(2)
|
||||
|
||||
out = conv_carry_causal_3d([F.silu(self.norm_out(x))], self.conv_out) + skip
|
||||
|
||||
if self.refiner_vae:
|
||||
out = self.regul(out)[0]
|
||||
|
||||
return out
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
|
||||
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_):
|
||||
super().__init__()
|
||||
block_out_channels = block_out_channels[::-1]
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
|
||||
self.refiner_vae = refiner_vae
|
||||
if self.refiner_vae:
|
||||
conv_op = CarriedConv3d
|
||||
norm_op = RMS_norm
|
||||
else:
|
||||
conv_op = ops.Conv3d
|
||||
norm_op = Normalize
|
||||
|
||||
ch = block_out_channels[0]
|
||||
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
depth_temporal = (ffactor_temporal >> 1).bit_length()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=conv_op, norm_op=norm_op)
|
||||
for j in range(num_res_blocks + 1)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
|
||||
stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
|
||||
ch = nxt
|
||||
self.up.append(stage)
|
||||
|
||||
self.norm_out = norm_op(ch)
|
||||
self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z):
|
||||
x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
if self.refiner_vae:
|
||||
x = torch.split(x, 2, dim=2)
|
||||
else:
|
||||
x = [ x ]
|
||||
out = []
|
||||
|
||||
conv_carry_in = None
|
||||
|
||||
for i, x1 in enumerate(x):
|
||||
conv_carry_out = []
|
||||
if i == len(x) - 1:
|
||||
conv_carry_out = None
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x1 = blk(x1, None, conv_carry_in, conv_carry_out)
|
||||
if hasattr(stage, 'upsample'):
|
||||
x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
|
||||
|
||||
x1 = [ F.silu(self.norm_out(x1)) ]
|
||||
x1 = conv_carry_causal_3d(x1, self.conv_out, conv_carry_in, conv_carry_out)
|
||||
out.append(x1)
|
||||
conv_carry_in = conv_carry_out
|
||||
del x
|
||||
|
||||
out = torch_cat_if_needed(out, dim=2)
|
||||
|
||||
if not self.refiner_vae:
|
||||
if z.shape[-3] == 1:
|
||||
out = out[:, :, -1:]
|
||||
|
||||
return out
|
||||
|
||||
413
comfy/ldm/kandinsky5/model.py
Normal file
413
comfy/ldm/kandinsky5/model.py
Normal file
@ -0,0 +1,413 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import math
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
|
||||
def attention(q, k, v, heads, transformer_options={}):
|
||||
return optimized_attention(
|
||||
q.transpose(1, 2),
|
||||
k.transpose(1, 2),
|
||||
v.transpose(1, 2),
|
||||
heads=heads,
|
||||
skip_reshape=True,
|
||||
transformer_options=transformer_options
|
||||
)
|
||||
|
||||
def apply_scale_shift_norm(norm, x, scale, shift):
|
||||
return torch.addcmul(shift, norm(x), scale + 1.0)
|
||||
|
||||
def apply_gate_sum(x, out, gate):
|
||||
return torch.addcmul(x, gate, out)
|
||||
|
||||
def get_shift_scale_gate(params):
|
||||
shift, scale, gate = torch.chunk(params, 3, dim=-1)
|
||||
return tuple(x.unsqueeze(1) for x in (shift, scale, gate))
|
||||
|
||||
def get_freqs(dim, max_period=10000.0):
|
||||
return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
|
||||
|
||||
|
||||
class TimeEmbeddings(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None):
|
||||
super().__init__()
|
||||
assert model_dim % 2 == 0
|
||||
self.model_dim = model_dim
|
||||
self.max_period = max_period
|
||||
self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.activation = nn.SiLU()
|
||||
self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, timestep, dtype):
|
||||
args = torch.outer(timestep, self.freqs.to(device=timestep.device))
|
||||
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
|
||||
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
|
||||
return time_embed
|
||||
|
||||
|
||||
class TextEmbeddings(nn.Module):
|
||||
def __init__(self, text_dim, model_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(text_dim, model_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.norm = operations.LayerNorm(model_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, text_embed):
|
||||
text_embed = self.in_layer(text_embed)
|
||||
return self.norm(text_embed).type_as(text_embed)
|
||||
|
||||
|
||||
class VisualEmbeddings(nn.Module):
|
||||
def __init__(self, visual_dim, model_dim, patch_size, operation_settings=None):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, x):
|
||||
x = x.movedim(1, -1) # B C T H W -> B T H W C
|
||||
B, T, H, W, dim = x.shape
|
||||
pt, ph, pw = self.patch_size
|
||||
|
||||
x = x.view(
|
||||
B,
|
||||
T // pt, pt,
|
||||
H // ph, ph,
|
||||
W // pw, pw,
|
||||
dim,
|
||||
).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)
|
||||
|
||||
return self.in_layer(x)
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
|
||||
super().__init__()
|
||||
self.activation = nn.SiLU()
|
||||
self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, x):
|
||||
return self.out_layer(self.activation(x))
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, num_channels, head_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
assert num_channels % head_dim == 0
|
||||
self.num_heads = num_channels // head_dim
|
||||
self.head_dim = head_dim
|
||||
|
||||
operations = operation_settings.get("operations")
|
||||
self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.num_chunks = 2
|
||||
|
||||
def _compute_qk(self, x, freqs, proj_fn, norm_fn):
|
||||
result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
return apply_rope1(norm_fn(result), freqs)
|
||||
|
||||
def _forward(self, x, freqs, transformer_options={}):
|
||||
q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
|
||||
k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
|
||||
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
||||
return self.out_layer(out)
|
||||
|
||||
def _forward_chunked(self, x, freqs, transformer_options={}):
|
||||
def process_chunks(proj_fn, norm_fn):
|
||||
x_chunks = torch.chunk(x, self.num_chunks, dim=1)
|
||||
freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
|
||||
chunks = []
|
||||
for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks):
|
||||
chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
|
||||
return torch.cat(chunks, dim=1)
|
||||
|
||||
q = process_chunks(self.to_query, self.query_norm)
|
||||
k = process_chunks(self.to_key, self.key_norm)
|
||||
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
||||
return self.out_layer(out)
|
||||
|
||||
def forward(self, x, freqs, transformer_options={}):
|
||||
if x.shape[1] > 8192:
|
||||
return self._forward_chunked(x, freqs, transformer_options=transformer_options)
|
||||
else:
|
||||
return self._forward(x, freqs, transformer_options=transformer_options)
|
||||
|
||||
|
||||
class CrossAttention(SelfAttention):
|
||||
def get_qkv(self, x, context):
|
||||
q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1)
|
||||
v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1)
|
||||
return q, k, v
|
||||
|
||||
def forward(self, x, context, transformer_options={}):
|
||||
q, k, v = self.get_qkv(x, context)
|
||||
out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options)
|
||||
return self.out_layer(out)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, ff_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.activation = nn.GELU()
|
||||
self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.num_chunks = 4
|
||||
|
||||
def _forward(self, x):
|
||||
return self.out_layer(self.activation(self.in_layer(x)))
|
||||
|
||||
def _forward_chunked(self, x):
|
||||
chunks = torch.chunk(x, self.num_chunks, dim=1)
|
||||
output_chunks = []
|
||||
for chunk in chunks:
|
||||
output_chunks.append(self._forward(chunk))
|
||||
return torch.cat(output_chunks, dim=1)
|
||||
|
||||
def forward(self, x):
|
||||
if x.shape[1] > 8192:
|
||||
return self._forward_chunked(x)
|
||||
else:
|
||||
return self._forward(x)
|
||||
|
||||
|
||||
class OutLayer(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.modulation = Modulation(time_dim, model_dim, 2, operation_settings=operation_settings)
|
||||
operations = operation_settings.get("operations")
|
||||
self.norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.out_layer = operations.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, visual_embed, time_embed):
|
||||
B, T, H, W, _ = visual_embed.shape
|
||||
shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1)
|
||||
scale = scale[:, None, None, None, :]
|
||||
shift = shift[:, None, None, None, :]
|
||||
visual_embed = apply_scale_shift_norm(self.norm, visual_embed, scale, shift)
|
||||
x = self.out_layer(visual_embed)
|
||||
|
||||
out_dim = x.shape[-1] // (self.patch_size[0] * self.patch_size[1] * self.patch_size[2])
|
||||
x = x.view(
|
||||
B, T, H, W,
|
||||
out_dim,
|
||||
self.patch_size[0], self.patch_size[1], self.patch_size[2]
|
||||
)
|
||||
return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5)
|
||||
|
||||
|
||||
class TransformerEncoderBlock(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
self.text_modulation = Modulation(time_dim, model_dim, 6, operation_settings=operation_settings)
|
||||
operations = operation_settings.get("operations")
|
||||
|
||||
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||
|
||||
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
|
||||
|
||||
def forward(self, x, time_embed, freqs, transformer_options={}):
|
||||
self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
|
||||
shift, scale, gate = get_shift_scale_gate(self_attn_params)
|
||||
out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
|
||||
out = self.self_attention(out, freqs, transformer_options=transformer_options)
|
||||
x = apply_gate_sum(x, out, gate)
|
||||
|
||||
shift, scale, gate = get_shift_scale_gate(ff_params)
|
||||
out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
|
||||
out = self.feed_forward(out)
|
||||
x = apply_gate_sum(x, out, gate)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerDecoderBlock(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
self.visual_modulation = Modulation(time_dim, model_dim, 9, operation_settings=operation_settings)
|
||||
|
||||
operations = operation_settings.get("operations")
|
||||
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||
|
||||
self.cross_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.cross_attention = CrossAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||
|
||||
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
|
||||
|
||||
def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
|
||||
self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
|
||||
# self attention
|
||||
shift, scale, gate = get_shift_scale_gate(self_attn_params)
|
||||
visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
|
||||
visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
|
||||
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||
# cross attention
|
||||
shift, scale, gate = get_shift_scale_gate(cross_attn_params)
|
||||
visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
|
||||
visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options)
|
||||
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||
# feed forward
|
||||
shift, scale, gate = get_shift_scale_gate(ff_params)
|
||||
visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift)
|
||||
visual_out = self.feed_forward(visual_out)
|
||||
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||
return visual_embed
|
||||
|
||||
|
||||
class Kandinsky5(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_visual_dim=16, out_visual_dim=16, in_text_dim=3584, in_text_dim2=768, time_dim=512,
|
||||
model_dim=1792, ff_dim=7168, visual_embed_dim=132, patch_size=(1, 2, 2), num_text_blocks=2, num_visual_blocks=32,
|
||||
axes_dims=(16, 24, 24), rope_scale_factor=(1.0, 2.0, 2.0),
|
||||
dtype=None, device=None, operations=None, **kwargs
|
||||
):
|
||||
super().__init__()
|
||||
head_dim = sum(axes_dims)
|
||||
self.rope_scale_factor = rope_scale_factor
|
||||
self.in_visual_dim = in_visual_dim
|
||||
self.model_dim = model_dim
|
||||
self.patch_size = patch_size
|
||||
self.visual_embed_dim = visual_embed_dim
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.time_embeddings = TimeEmbeddings(model_dim, time_dim, operation_settings=operation_settings)
|
||||
self.text_embeddings = TextEmbeddings(in_text_dim, model_dim, operation_settings=operation_settings)
|
||||
self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim, operation_settings=operation_settings)
|
||||
self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size, operation_settings=operation_settings)
|
||||
|
||||
self.text_transformer_blocks = nn.ModuleList(
|
||||
[TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_text_blocks)]
|
||||
)
|
||||
|
||||
self.visual_transformer_blocks = nn.ModuleList(
|
||||
[TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_visual_blocks)]
|
||||
)
|
||||
|
||||
self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size, operation_settings=operation_settings)
|
||||
|
||||
self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims)
|
||||
self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim])
|
||||
|
||||
def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}):
|
||||
steps = seq_len if steps is None else steps
|
||||
seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype)
|
||||
seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0) # Shape: (1, steps, 1)
|
||||
freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2)
|
||||
return freqs
|
||||
|
||||
def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
|
||||
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
||||
|
||||
if steps_t is None:
|
||||
steps_t = t_len
|
||||
if steps_h is None:
|
||||
steps_h = h_len
|
||||
if steps_w is None:
|
||||
steps_w = w_len
|
||||
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
if rope_options is not None:
|
||||
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
|
||||
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||
|
||||
t_start += rope_options.get("shift_t", 0.0)
|
||||
h_start += rope_options.get("shift_y", 0.0)
|
||||
w_start += rope_options.get("shift_x", 0.0)
|
||||
else:
|
||||
rope_scale_factor = self.rope_scale_factor
|
||||
if self.model_dim == 4096: # pro video model uses different rope scaling at higher resolutions
|
||||
if h * w >= 14080:
|
||||
rope_scale_factor = (1.0, 3.16, 3.16)
|
||||
|
||||
t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0
|
||||
h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0
|
||||
w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0
|
||||
|
||||
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
||||
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||
|
||||
freqs = self.rope_embedder_3d(img_ids).movedim(1, 2)
|
||||
return freqs
|
||||
|
||||
def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
context = self.text_embeddings(context)
|
||||
time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y)
|
||||
|
||||
for block in self.text_transformer_blocks:
|
||||
context = block(context, time_embed, freqs_text, transformer_options=transformer_options)
|
||||
|
||||
visual_embed = self.visual_embeddings(x)
|
||||
visual_shape = visual_embed.shape[:-1]
|
||||
visual_embed = visual_embed.flatten(1, -2)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.visual_transformer_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
|
||||
visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
|
||||
else:
|
||||
visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
|
||||
|
||||
visual_embed = visual_embed.reshape(*visual_shape, -1)
|
||||
return self.out_layer(visual_embed, time_embed)
|
||||
|
||||
def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
|
||||
original_dims = x.ndim
|
||||
if original_dims == 4:
|
||||
x = x.unsqueeze(2)
|
||||
bs, c, t_len, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
if time_dim_replace is not None:
|
||||
time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size)
|
||||
x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace
|
||||
|
||||
freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||
freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||
|
||||
out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
|
||||
if original_dims == 4:
|
||||
out = out.squeeze(2)
|
||||
return out
|
||||
|
||||
def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs)
|
||||
913
comfy/ldm/lightricks/av_model.py
Normal file
913
comfy/ldm/lightricks/av_model.py
Normal file
@ -0,0 +1,913 @@
|
||||
from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.ldm.lightricks.model import (
|
||||
CrossAttention,
|
||||
FeedForward,
|
||||
AdaLayerNormSingle,
|
||||
PixArtAlphaTextProjection,
|
||||
LTXVModel,
|
||||
)
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
class CompressedTimestep:
|
||||
"""Store video timestep embeddings in compressed form using per-frame indexing."""
|
||||
__slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
|
||||
|
||||
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
|
||||
"""
|
||||
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
|
||||
patches_per_frame: Number of spatial patches per frame (height * width in latent space)
|
||||
"""
|
||||
self.batch_size, num_tokens, self.feature_dim = tensor.shape
|
||||
|
||||
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
|
||||
if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
||||
self.patches_per_frame = patches_per_frame
|
||||
self.num_frames = num_tokens // patches_per_frame
|
||||
|
||||
# Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
|
||||
# All patches in a frame are identical, so we only keep the first one
|
||||
reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
|
||||
self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim]
|
||||
else:
|
||||
# Not divisible or too small - store directly without compression
|
||||
self.patches_per_frame = 1
|
||||
self.num_frames = num_tokens
|
||||
self.data = tensor
|
||||
|
||||
def expand(self):
|
||||
"""Expand back to original tensor."""
|
||||
if self.patches_per_frame == 1:
|
||||
return self.data
|
||||
|
||||
# [batch, frames, feature_dim] -> [batch, frames, patches_per_frame, feature_dim] -> [batch, tokens, feature_dim]
|
||||
expanded = self.data.unsqueeze(2).expand(self.batch_size, self.num_frames, self.patches_per_frame, self.feature_dim)
|
||||
return expanded.reshape(self.batch_size, -1, self.feature_dim)
|
||||
|
||||
def expand_for_computation(self, scale_shift_table: torch.Tensor, batch_size: int, indices: slice = slice(None, None)):
|
||||
"""Compute ada values on compressed per-frame data, then expand spatially."""
|
||||
num_ada_params = scale_shift_table.shape[0]
|
||||
|
||||
# No compression - compute directly
|
||||
if self.patches_per_frame == 1:
|
||||
num_tokens = self.data.shape[1]
|
||||
dim_per_param = self.feature_dim // num_ada_params
|
||||
reshaped = self.data.reshape(batch_size, num_tokens, num_ada_params, dim_per_param)[:, :, indices, :]
|
||||
table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=self.data.device, dtype=self.data.dtype)
|
||||
ada_values = (table_values + reshaped).unbind(dim=2)
|
||||
return ada_values
|
||||
|
||||
# Compressed: compute on per-frame data then expand spatially
|
||||
# Reshape: [batch, frames, feature_dim] -> [batch, frames, num_ada_params, dim_per_param]
|
||||
frame_reshaped = self.data.reshape(batch_size, self.num_frames, num_ada_params, -1)[:, :, indices, :]
|
||||
table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(
|
||||
device=self.data.device, dtype=self.data.dtype
|
||||
)
|
||||
frame_ada = (table_values + frame_reshaped).unbind(dim=2)
|
||||
|
||||
# Expand each ada parameter spatially: [batch, frames, dim] -> [batch, frames, patches, dim] -> [batch, tokens, dim]
|
||||
return tuple(
|
||||
frame_val.unsqueeze(2).expand(batch_size, self.num_frames, self.patches_per_frame, -1)
|
||||
.reshape(batch_size, -1, frame_val.shape[-1])
|
||||
for frame_val in frame_ada
|
||||
)
|
||||
|
||||
class BasicAVTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
v_dim,
|
||||
a_dim,
|
||||
v_heads,
|
||||
a_heads,
|
||||
vd_head,
|
||||
ad_head,
|
||||
v_context_dim=None,
|
||||
a_context_dim=None,
|
||||
attn_precision=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attn_precision = attn_precision
|
||||
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=v_dim,
|
||||
heads=v_heads,
|
||||
dim_head=vd_head,
|
||||
context_dim=None,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.audio_attn1 = CrossAttention(
|
||||
query_dim=a_dim,
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
context_dim=None,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.attn2 = CrossAttention(
|
||||
query_dim=v_dim,
|
||||
context_dim=v_context_dim,
|
||||
heads=v_heads,
|
||||
dim_head=vd_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.audio_attn2 = CrossAttention(
|
||||
query_dim=a_dim,
|
||||
context_dim=a_context_dim,
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
# Q: Video, K,V: Audio
|
||||
self.audio_to_video_attn = CrossAttention(
|
||||
query_dim=v_dim,
|
||||
context_dim=a_dim,
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
# Q: Audio, K,V: Video
|
||||
self.video_to_audio_attn = CrossAttention(
|
||||
query_dim=a_dim,
|
||||
context_dim=v_dim,
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.ff = FeedForward(
|
||||
v_dim, dim_out=v_dim, glu=True, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.audio_ff = FeedForward(
|
||||
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
|
||||
self.audio_scale_shift_table = nn.Parameter(
|
||||
torch.empty(6, a_dim, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
|
||||
torch.empty(5, a_dim, device=device, dtype=dtype)
|
||||
)
|
||||
self.scale_shift_table_a2v_ca_video = nn.Parameter(
|
||||
torch.empty(5, v_dim, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
def get_ada_values(
|
||||
self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None)
|
||||
):
|
||||
if isinstance(timestep, CompressedTimestep):
|
||||
return timestep.expand_for_computation(scale_shift_table, batch_size, indices)
|
||||
|
||||
num_ada_params = scale_shift_table.shape[0]
|
||||
|
||||
ada_values = (
|
||||
scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype)
|
||||
+ timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :]
|
||||
).unbind(dim=2)
|
||||
return ada_values
|
||||
|
||||
def get_av_ca_ada_values(
|
||||
self,
|
||||
scale_shift_table: torch.Tensor,
|
||||
batch_size: int,
|
||||
scale_shift_timestep: torch.Tensor,
|
||||
gate_timestep: torch.Tensor,
|
||||
num_scale_shift_values: int = 4,
|
||||
):
|
||||
scale_shift_ada_values = self.get_ada_values(
|
||||
scale_shift_table[:num_scale_shift_values, :],
|
||||
batch_size,
|
||||
scale_shift_timestep,
|
||||
)
|
||||
gate_ada_values = self.get_ada_values(
|
||||
scale_shift_table[num_scale_shift_values:, :],
|
||||
batch_size,
|
||||
gate_timestep,
|
||||
)
|
||||
|
||||
return (*scale_shift_ada_values, *gate_ada_values)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tuple[torch.Tensor, torch.Tensor],
|
||||
v_context=None,
|
||||
a_context=None,
|
||||
attention_mask=None,
|
||||
v_timestep=None,
|
||||
a_timestep=None,
|
||||
v_pe=None,
|
||||
a_pe=None,
|
||||
v_cross_pe=None,
|
||||
a_cross_pe=None,
|
||||
v_cross_scale_shift_timestep=None,
|
||||
a_cross_scale_shift_timestep=None,
|
||||
v_cross_gate_timestep=None,
|
||||
a_cross_gate_timestep=None,
|
||||
transformer_options=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
run_vx = transformer_options.get("run_vx", True)
|
||||
run_ax = transformer_options.get("run_ax", True)
|
||||
|
||||
vx, ax = x
|
||||
run_ax = run_ax and ax.numel() > 0
|
||||
run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
|
||||
run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
|
||||
|
||||
if run_vx:
|
||||
vshift_msa, vscale_msa, vgate_msa = (
|
||||
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
|
||||
)
|
||||
|
||||
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
|
||||
vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
|
||||
vx += self.attn2(
|
||||
comfy.ldm.common_dit.rms_norm(vx),
|
||||
context=v_context,
|
||||
mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
del vshift_msa, vscale_msa, vgate_msa
|
||||
|
||||
if run_ax:
|
||||
ashift_msa, ascale_msa, agate_msa = (
|
||||
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
|
||||
)
|
||||
|
||||
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
|
||||
ax += (
|
||||
self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
|
||||
* agate_msa
|
||||
)
|
||||
ax += self.audio_attn2(
|
||||
comfy.ldm.common_dit.rms_norm(ax),
|
||||
context=a_context,
|
||||
mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
del ashift_msa, ascale_msa, agate_msa
|
||||
|
||||
# Audio - Video cross attention.
|
||||
if run_a2v or run_v2a:
|
||||
# norm3
|
||||
vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
|
||||
ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
|
||||
|
||||
(
|
||||
scale_ca_audio_hidden_states_a2v,
|
||||
shift_ca_audio_hidden_states_a2v,
|
||||
scale_ca_audio_hidden_states_v2a,
|
||||
shift_ca_audio_hidden_states_v2a,
|
||||
gate_out_v2a,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio,
|
||||
ax.shape[0],
|
||||
a_cross_scale_shift_timestep,
|
||||
a_cross_gate_timestep,
|
||||
)
|
||||
|
||||
(
|
||||
scale_ca_video_hidden_states_a2v,
|
||||
shift_ca_video_hidden_states_a2v,
|
||||
scale_ca_video_hidden_states_v2a,
|
||||
shift_ca_video_hidden_states_v2a,
|
||||
gate_out_a2v,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video,
|
||||
vx.shape[0],
|
||||
v_cross_scale_shift_timestep,
|
||||
v_cross_gate_timestep,
|
||||
)
|
||||
|
||||
if run_a2v:
|
||||
vx_scaled = (
|
||||
vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
|
||||
+ shift_ca_video_hidden_states_a2v
|
||||
)
|
||||
ax_scaled = (
|
||||
ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
|
||||
+ shift_ca_audio_hidden_states_a2v
|
||||
)
|
||||
vx += (
|
||||
self.audio_to_video_attn(
|
||||
vx_scaled,
|
||||
context=ax_scaled,
|
||||
pe=v_cross_pe,
|
||||
k_pe=a_cross_pe,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
* gate_out_a2v
|
||||
)
|
||||
|
||||
del gate_out_a2v
|
||||
del scale_ca_video_hidden_states_a2v,\
|
||||
shift_ca_video_hidden_states_a2v,\
|
||||
scale_ca_audio_hidden_states_a2v,\
|
||||
shift_ca_audio_hidden_states_a2v,\
|
||||
|
||||
if run_v2a:
|
||||
ax_scaled = (
|
||||
ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
|
||||
+ shift_ca_audio_hidden_states_v2a
|
||||
)
|
||||
vx_scaled = (
|
||||
vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
|
||||
+ shift_ca_video_hidden_states_v2a
|
||||
)
|
||||
ax += (
|
||||
self.video_to_audio_attn(
|
||||
ax_scaled,
|
||||
context=vx_scaled,
|
||||
pe=a_cross_pe,
|
||||
k_pe=v_cross_pe,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
* gate_out_v2a
|
||||
)
|
||||
|
||||
del gate_out_v2a
|
||||
del scale_ca_video_hidden_states_v2a,\
|
||||
shift_ca_video_hidden_states_v2a,\
|
||||
scale_ca_audio_hidden_states_v2a,\
|
||||
shift_ca_audio_hidden_states_v2a
|
||||
|
||||
if run_vx:
|
||||
vshift_mlp, vscale_mlp, vgate_mlp = (
|
||||
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
|
||||
)
|
||||
|
||||
vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
|
||||
vx += self.ff(vx_scaled) * vgate_mlp
|
||||
del vshift_mlp, vscale_mlp, vgate_mlp
|
||||
|
||||
if run_ax:
|
||||
ashift_mlp, ascale_mlp, agate_mlp = (
|
||||
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
|
||||
)
|
||||
|
||||
ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
|
||||
ax += self.audio_ff(ax_scaled) * agate_mlp
|
||||
|
||||
del ashift_mlp, ascale_mlp, agate_mlp
|
||||
|
||||
|
||||
return vx, ax
|
||||
|
||||
|
||||
class LTXAVModel(LTXVModel):
|
||||
"""LTXAV model for audio-video generation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=128,
|
||||
audio_in_channels=128,
|
||||
cross_attention_dim=4096,
|
||||
audio_cross_attention_dim=2048,
|
||||
attention_head_dim=128,
|
||||
audio_attention_head_dim=64,
|
||||
num_attention_heads=32,
|
||||
audio_num_attention_heads=32,
|
||||
caption_channels=3840,
|
||||
num_layers=48,
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
audio_positional_embedding_max_pos=[20],
|
||||
causal_temporal_positioning=False,
|
||||
vae_scale_factors=(8, 32, 32),
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier=1000.0,
|
||||
av_ca_timestep_scale_multiplier=1.0,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Store audio-specific parameters
|
||||
self.audio_in_channels = audio_in_channels
|
||||
self.audio_cross_attention_dim = audio_cross_attention_dim
|
||||
self.audio_attention_head_dim = audio_attention_head_dim
|
||||
self.audio_num_attention_heads = audio_num_attention_heads
|
||||
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
|
||||
|
||||
# Calculate audio dimensions
|
||||
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
|
||||
self.audio_out_channels = audio_in_channels
|
||||
|
||||
# Audio-specific constants
|
||||
self.num_audio_channels = 8
|
||||
self.audio_frequency_bins = 16
|
||||
|
||||
self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
|
||||
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
caption_channels=caption_channels,
|
||||
num_layers=num_layers,
|
||||
positional_embedding_theta=positional_embedding_theta,
|
||||
positional_embedding_max_pos=positional_embedding_max_pos,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
vae_scale_factors=vae_scale_factors,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _init_model_components(self, device, dtype, **kwargs):
|
||||
"""Initialize LTXAV-specific components."""
|
||||
# Audio-specific projections
|
||||
self.audio_patchify_proj = self.operations.Linear(
|
||||
self.audio_in_channels, self.audio_inner_dim, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
# Audio-specific AdaLN
|
||||
self.audio_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
use_additional_conditions=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
num_scale_shift_values = 4
|
||||
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim,
|
||||
use_additional_conditions=False,
|
||||
embedding_coefficient=num_scale_shift_values,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim,
|
||||
use_additional_conditions=False,
|
||||
embedding_coefficient=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
use_additional_conditions=False,
|
||||
embedding_coefficient=num_scale_shift_values,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
use_additional_conditions=False,
|
||||
embedding_coefficient=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
# Audio caption projection
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
"""Initialize transformer blocks for LTXAV."""
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicAVTransformerBlock(
|
||||
v_dim=self.inner_dim,
|
||||
a_dim=self.audio_inner_dim,
|
||||
v_heads=self.num_attention_heads,
|
||||
a_heads=self.audio_num_attention_heads,
|
||||
vd_head=self.attention_head_dim,
|
||||
ad_head=self.audio_attention_head_dim,
|
||||
v_context_dim=self.cross_attention_dim,
|
||||
a_context_dim=self.audio_cross_attention_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def _init_output_components(self, device, dtype):
|
||||
"""Initialize output components for LTXAV."""
|
||||
# Video output components
|
||||
super()._init_output_components(device, dtype)
|
||||
# Audio output components
|
||||
self.audio_scale_shift_table = nn.Parameter(
|
||||
torch.empty(2, self.audio_inner_dim, dtype=dtype, device=device)
|
||||
)
|
||||
self.audio_norm_out = self.operations.LayerNorm(
|
||||
self.audio_inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
|
||||
)
|
||||
self.audio_proj_out = self.operations.Linear(
|
||||
self.audio_inner_dim, self.audio_out_channels, dtype=dtype, device=device
|
||||
)
|
||||
self.a_patchifier = AudioPatchifier(1, start_end=True)
|
||||
|
||||
def separate_audio_and_video_latents(self, x, audio_length):
|
||||
"""Separate audio and video latents from combined input."""
|
||||
# vx = x[:, : self.in_channels]
|
||||
# ax = x[:, self.in_channels :]
|
||||
#
|
||||
# ax = ax.reshape(ax.shape[0], -1)
|
||||
# ax = ax[:, : audio_length * self.num_audio_channels * self.audio_frequency_bins]
|
||||
#
|
||||
# ax = ax.reshape(
|
||||
# ax.shape[0], self.num_audio_channels, audio_length, self.audio_frequency_bins
|
||||
# )
|
||||
|
||||
vx = x[0]
|
||||
ax = x[1] if len(x) > 1 else torch.zeros(
|
||||
(vx.shape[0], self.num_audio_channels, 0, self.audio_frequency_bins),
|
||||
device=vx.device, dtype=vx.dtype
|
||||
)
|
||||
return vx, ax
|
||||
|
||||
def recombine_audio_and_video_latents(self, vx, ax, target_shape=None):
|
||||
if ax.numel() == 0:
|
||||
return vx
|
||||
else:
|
||||
return [vx, ax]
|
||||
"""Recombine audio and video latents for output."""
|
||||
# if ax.device != vx.device or ax.dtype != vx.dtype:
|
||||
# logging.warning("Audio and video latents are on different devices or dtypes.")
|
||||
# ax = ax.to(device=vx.device, dtype=vx.dtype)
|
||||
# logging.warning(f"Audio audio latent moved to device: {ax.device}, dtype: {ax.dtype}")
|
||||
#
|
||||
# ax = ax.reshape(ax.shape[0], -1)
|
||||
# # pad to f x h x w of the video latents
|
||||
# divisor = vx.shape[-1] * vx.shape[-2] * vx.shape[-3]
|
||||
# if target_shape is None:
|
||||
# repetitions = math.ceil(ax.shape[-1] / divisor)
|
||||
# else:
|
||||
# repetitions = target_shape[1] - vx.shape[1]
|
||||
# padded_len = repetitions * divisor
|
||||
# ax = F.pad(ax, (0, padded_len - ax.shape[-1]))
|
||||
# ax = ax.reshape(ax.shape[0], -1, vx.shape[-3], vx.shape[-2], vx.shape[-1])
|
||||
# return torch.cat([vx, ax], dim=1)
|
||||
|
||||
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
|
||||
"""Process input for LTXAV - separate audio and video, then patchify."""
|
||||
audio_length = kwargs.get("audio_length", 0)
|
||||
# Separate audio and video latents
|
||||
vx, ax = self.separate_audio_and_video_latents(x, audio_length)
|
||||
[vx, v_pixel_coords, additional_args] = super()._process_input(
|
||||
vx, keyframe_idxs, denoise_mask, **kwargs
|
||||
)
|
||||
|
||||
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
||||
ax = self.audio_patchify_proj(ax)
|
||||
|
||||
# additional_args.update({"av_orig_shape": list(x.shape)})
|
||||
return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args
|
||||
|
||||
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
|
||||
"""Prepare timestep embeddings."""
|
||||
# TODO: some code reuse is needed here.
|
||||
grid_mask = kwargs.get("grid_mask", None)
|
||||
if grid_mask is not None:
|
||||
timestep = timestep[:, grid_mask]
|
||||
|
||||
timestep_scaled = timestep * self.timestep_scale_multiplier
|
||||
|
||||
v_timestep, v_embedded_timestep = self.adaln_single(
|
||||
timestep_scaled.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
|
||||
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
|
||||
orig_shape = kwargs.get("orig_shape")
|
||||
v_patches_per_frame = None
|
||||
if orig_shape is not None and len(orig_shape) == 5:
|
||||
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
|
||||
v_patches_per_frame = orig_shape[3] * orig_shape[4]
|
||||
|
||||
# Reshape to [batch_size, num_tokens, dim] and compress for storage
|
||||
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
|
||||
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
|
||||
|
||||
# Prepare audio timestep
|
||||
a_timestep = kwargs.get("a_timestep")
|
||||
if a_timestep is not None:
|
||||
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
|
||||
a_timestep_flat = a_timestep_scaled.flatten()
|
||||
timestep_flat = timestep_scaled.flatten()
|
||||
av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier
|
||||
|
||||
# Cross-attention timesteps - compress these too
|
||||
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
|
||||
a_timestep_flat,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
|
||||
timestep_flat,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
|
||||
timestep_flat * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
|
||||
a_timestep_flat * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
# Compress cross-attention timesteps (only video side, audio is too small to benefit)
|
||||
cross_av_timestep_ss = [
|
||||
av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
|
||||
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed
|
||||
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed
|
||||
av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
|
||||
]
|
||||
|
||||
a_timestep, a_embedded_timestep = self.audio_adaln_single(
|
||||
a_timestep_flat,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
# Audio timesteps
|
||||
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
|
||||
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
|
||||
else:
|
||||
a_timestep = timestep_scaled
|
||||
a_embedded_timestep = kwargs.get("embedded_timestep")
|
||||
cross_av_timestep_ss = []
|
||||
|
||||
return [v_timestep, a_timestep, cross_av_timestep_ss], [
|
||||
v_embedded_timestep,
|
||||
a_embedded_timestep,
|
||||
]
|
||||
|
||||
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||
vx = x[0]
|
||||
ax = x[1]
|
||||
v_context, a_context = torch.split(
|
||||
context, int(context.shape[-1] / 2), len(context.shape) - 1
|
||||
)
|
||||
|
||||
v_context, attention_mask = super()._prepare_context(
|
||||
v_context, batch_size, vx, attention_mask
|
||||
)
|
||||
if self.audio_caption_projection is not None:
|
||||
a_context = self.audio_caption_projection(a_context)
|
||||
a_context = a_context.view(batch_size, -1, ax.shape[-1])
|
||||
|
||||
return [v_context, a_context], attention_mask
|
||||
|
||||
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
|
||||
v_pixel_coords = pixel_coords[0]
|
||||
v_pe = super()._prepare_positional_embeddings(v_pixel_coords, frame_rate, x_dtype)
|
||||
|
||||
a_latent_coords = pixel_coords[1]
|
||||
a_pe = self._precompute_freqs_cis(
|
||||
a_latent_coords,
|
||||
dim=self.audio_inner_dim,
|
||||
out_dtype=x_dtype,
|
||||
max_pos=self.audio_positional_embedding_max_pos,
|
||||
use_middle_indices_grid=self.use_middle_indices_grid,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
)
|
||||
|
||||
# calculate positional embeddings for the middle of the token duration, to use in av cross attention layers.
|
||||
max_pos = max(
|
||||
self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]
|
||||
)
|
||||
v_pixel_coords = v_pixel_coords.to(torch.float32)
|
||||
v_pixel_coords[:, 0] = v_pixel_coords[:, 0] * (1.0 / frame_rate)
|
||||
av_cross_video_freq_cis = self._precompute_freqs_cis(
|
||||
v_pixel_coords[:, 0:1, :],
|
||||
dim=self.audio_cross_attention_dim,
|
||||
out_dtype=x_dtype,
|
||||
max_pos=[max_pos],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
)
|
||||
av_cross_audio_freq_cis = self._precompute_freqs_cis(
|
||||
a_latent_coords[:, 0:1, :],
|
||||
dim=self.audio_cross_attention_dim,
|
||||
out_dtype=x_dtype,
|
||||
max_pos=[max_pos],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
)
|
||||
|
||||
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
|
||||
|
||||
def _process_transformer_blocks(
|
||||
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
|
||||
):
|
||||
vx = x[0]
|
||||
ax = x[1]
|
||||
v_context = context[0]
|
||||
a_context = context[1]
|
||||
v_timestep = timestep[0]
|
||||
a_timestep = timestep[1]
|
||||
v_pe, av_cross_video_freq_cis = pe[0]
|
||||
a_pe, av_cross_audio_freq_cis = pe[1]
|
||||
|
||||
(
|
||||
av_ca_audio_scale_shift_timestep,
|
||||
av_ca_video_scale_shift_timestep,
|
||||
av_ca_a2v_gate_noise_timestep,
|
||||
av_ca_v2a_gate_noise_timestep,
|
||||
) = timestep[2]
|
||||
|
||||
"""Process transformer blocks for LTXAV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
# Process transformer blocks
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(
|
||||
args["img"],
|
||||
v_context=args["v_context"],
|
||||
a_context=args["a_context"],
|
||||
attention_mask=args["attention_mask"],
|
||||
v_timestep=args["v_timestep"],
|
||||
a_timestep=args["a_timestep"],
|
||||
v_pe=args["v_pe"],
|
||||
a_pe=args["a_pe"],
|
||||
v_cross_pe=args["v_cross_pe"],
|
||||
a_cross_pe=args["a_cross_pe"],
|
||||
v_cross_scale_shift_timestep=args["v_cross_scale_shift_timestep"],
|
||||
a_cross_scale_shift_timestep=args["a_cross_scale_shift_timestep"],
|
||||
v_cross_gate_timestep=args["v_cross_gate_timestep"],
|
||||
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
||||
transformer_options=args["transformer_options"],
|
||||
)
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)](
|
||||
{
|
||||
"img": (vx, ax),
|
||||
"v_context": v_context,
|
||||
"a_context": a_context,
|
||||
"attention_mask": attention_mask,
|
||||
"v_timestep": v_timestep,
|
||||
"a_timestep": a_timestep,
|
||||
"v_pe": v_pe,
|
||||
"a_pe": a_pe,
|
||||
"v_cross_pe": av_cross_video_freq_cis,
|
||||
"a_cross_pe": av_cross_audio_freq_cis,
|
||||
"v_cross_scale_shift_timestep": av_ca_video_scale_shift_timestep,
|
||||
"a_cross_scale_shift_timestep": av_ca_audio_scale_shift_timestep,
|
||||
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
|
||||
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
||||
"transformer_options": transformer_options,
|
||||
},
|
||||
{"original_block": block_wrap},
|
||||
)
|
||||
vx, ax = out["img"]
|
||||
else:
|
||||
vx, ax = block(
|
||||
(vx, ax),
|
||||
v_context=v_context,
|
||||
a_context=a_context,
|
||||
attention_mask=attention_mask,
|
||||
v_timestep=v_timestep,
|
||||
a_timestep=a_timestep,
|
||||
v_pe=v_pe,
|
||||
a_pe=a_pe,
|
||||
v_cross_pe=av_cross_video_freq_cis,
|
||||
a_cross_pe=av_cross_audio_freq_cis,
|
||||
v_cross_scale_shift_timestep=av_ca_video_scale_shift_timestep,
|
||||
a_cross_scale_shift_timestep=av_ca_audio_scale_shift_timestep,
|
||||
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
|
||||
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
return [vx, ax]
|
||||
|
||||
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||
vx = x[0]
|
||||
ax = x[1]
|
||||
v_embedded_timestep = embedded_timestep[0]
|
||||
a_embedded_timestep = embedded_timestep[1]
|
||||
|
||||
# Expand compressed video timestep if needed
|
||||
if isinstance(v_embedded_timestep, CompressedTimestep):
|
||||
v_embedded_timestep = v_embedded_timestep.expand()
|
||||
|
||||
vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)
|
||||
|
||||
# Process audio output
|
||||
a_scale_shift_values = (
|
||||
self.audio_scale_shift_table[None, None].to(device=a_embedded_timestep.device, dtype=a_embedded_timestep.dtype)
|
||||
+ a_embedded_timestep[:, :, None]
|
||||
)
|
||||
a_shift, a_scale = a_scale_shift_values[:, :, 0], a_scale_shift_values[:, :, 1]
|
||||
|
||||
ax = self.audio_norm_out(ax)
|
||||
ax = ax * (1 + a_scale) + a_shift
|
||||
ax = self.audio_proj_out(ax)
|
||||
|
||||
# Unpatchify audio
|
||||
ax = self.a_patchifier.unpatchify(
|
||||
ax, channels=self.num_audio_channels, freq=self.audio_frequency_bins
|
||||
)
|
||||
|
||||
# Recombine audio and video
|
||||
original_shape = kwargs.get("av_orig_shape")
|
||||
return self.recombine_audio_and_video_latents(vx, ax, original_shape)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timestep,
|
||||
context,
|
||||
attention_mask=None,
|
||||
frame_rate=25,
|
||||
transformer_options={},
|
||||
keyframe_idxs=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Forward pass for LTXAV model.
|
||||
|
||||
Args:
|
||||
x: Combined audio-video input tensor
|
||||
timestep: Tuple of (video_timestep, audio_timestep) or single timestep
|
||||
context: Context tensor (e.g., text embeddings)
|
||||
attention_mask: Attention mask tensor
|
||||
frame_rate: Frame rate for temporal processing
|
||||
transformer_options: Additional options for transformer blocks
|
||||
keyframe_idxs: Keyframe indices for temporal processing
|
||||
**kwargs: Additional keyword arguments including audio_length
|
||||
|
||||
Returns:
|
||||
Combined audio-video output tensor
|
||||
"""
|
||||
# Handle timestep format
|
||||
if isinstance(timestep, (tuple, list)) and len(timestep) == 2:
|
||||
v_timestep, a_timestep = timestep
|
||||
kwargs["a_timestep"] = a_timestep
|
||||
timestep = v_timestep
|
||||
else:
|
||||
kwargs["a_timestep"] = timestep
|
||||
|
||||
# Call parent forward method
|
||||
return super().forward(
|
||||
x,
|
||||
timestep,
|
||||
context,
|
||||
attention_mask,
|
||||
frame_rate,
|
||||
transformer_options,
|
||||
keyframe_idxs,
|
||||
**kwargs,
|
||||
)
|
||||
305
comfy/ldm/lightricks/embeddings_connector.py
Normal file
305
comfy/ldm/lightricks/embeddings_connector.py
Normal file
@ -0,0 +1,305 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
import torch
|
||||
from comfy.ldm.lightricks.model import (
|
||||
CrossAttention,
|
||||
FeedForward,
|
||||
generate_freq_grid_np,
|
||||
interleaved_freqs_cis,
|
||||
split_freqs_cis,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
|
||||
class BasicTransformerBlock1D(nn.Module):
|
||||
r"""
|
||||
A basic Transformer block.
|
||||
|
||||
Parameters:
|
||||
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
attention_bias (:
|
||||
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||
upcast_attention (`bool`, *optional*):
|
||||
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`.
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon value for normalization layers.
|
||||
qk_norm (`str`, *optional*, defaults to None):
|
||||
Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
|
||||
final_dropout (`bool` *optional*, defaults to False):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
ff_inner_dim (`int`, *optional*): Dimension of the inner feed-forward layer. If not provided, defaults to `dim * 4`.
|
||||
ff_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the feed-forward layer.
|
||||
attention_out_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the attention output layer.
|
||||
use_rope (`bool`, *optional*, defaults to `False`): Whether to use Rotary Position Embeddings (RoPE).
|
||||
ffn_dim_mult (`int`, *optional*, defaults to 4): Multiplier for the inner dimension of the feed-forward layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
context_dim=None,
|
||||
attn_precision=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Define 3 blocks. Each block has its own normalization layer.
|
||||
# 1. Self-Attn
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
context_dim=None,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dim_out=dim,
|
||||
glu=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, pe=None) -> torch.FloatTensor:
|
||||
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
|
||||
# 1. Normalization Before Self-Attention
|
||||
norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
|
||||
|
||||
norm_hidden_states = norm_hidden_states.squeeze(1)
|
||||
|
||||
# 2. Self-Attention
|
||||
attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 3. Normalization before Feed-Forward
|
||||
norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
|
||||
|
||||
# 4. Feed-forward
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Embeddings1DConnector(nn.Module):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=128,
|
||||
cross_attention_dim=2048,
|
||||
attention_head_dim=128,
|
||||
num_attention_heads=30,
|
||||
num_layers=2,
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[4096],
|
||||
causal_temporal_positioning=False,
|
||||
num_learnable_registers: Optional[int] = 128,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
split_rope=False,
|
||||
double_precision_rope=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.out_channels = in_channels
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
self.positional_embedding_theta = positional_embedding_theta
|
||||
self.positional_embedding_max_pos = positional_embedding_max_pos
|
||||
self.split_rope = split_rope
|
||||
self.double_precision_rope = double_precision_rope
|
||||
self.transformer_1d_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock1D(
|
||||
self.inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
context_dim=cross_attention_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.num_learnable_registers = num_learnable_registers
|
||||
if self.num_learnable_registers:
|
||||
self.learnable_registers = nn.Parameter(
|
||||
torch.rand(
|
||||
self.num_learnable_registers, inner_dim, dtype=dtype, device=device
|
||||
)
|
||||
* 2.0
|
||||
- 1.0
|
||||
)
|
||||
|
||||
def get_fractional_positions(self, indices_grid):
|
||||
fractional_positions = torch.stack(
|
||||
[
|
||||
indices_grid[:, i] / self.positional_embedding_max_pos[i]
|
||||
for i in range(1)
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
return fractional_positions
|
||||
|
||||
def precompute_freqs(self, indices_grid, spacing):
|
||||
source_dtype = indices_grid.dtype
|
||||
dtype = (
|
||||
torch.float32
|
||||
if source_dtype in (torch.bfloat16, torch.float16)
|
||||
else source_dtype
|
||||
)
|
||||
|
||||
fractional_positions = self.get_fractional_positions(indices_grid)
|
||||
indices = (
|
||||
generate_freq_grid_np(
|
||||
self.positional_embedding_theta,
|
||||
indices_grid.shape[1],
|
||||
self.inner_dim,
|
||||
)
|
||||
if self.double_precision_rope
|
||||
else self.generate_freq_grid(spacing, dtype, fractional_positions.device)
|
||||
).to(device=fractional_positions.device)
|
||||
|
||||
if spacing == "exp_2":
|
||||
freqs = (
|
||||
(indices * fractional_positions.unsqueeze(-1))
|
||||
.transpose(-1, -2)
|
||||
.flatten(2)
|
||||
)
|
||||
else:
|
||||
freqs = (
|
||||
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||
.transpose(-1, -2)
|
||||
.flatten(2)
|
||||
)
|
||||
return freqs
|
||||
|
||||
def generate_freq_grid(self, spacing, dtype, device):
|
||||
dim = self.inner_dim
|
||||
theta = self.positional_embedding_theta
|
||||
n_pos_dims = 1
|
||||
n_elem = 2 * n_pos_dims # 2 for cos and sin e.g. x 3 = 6
|
||||
start = 1
|
||||
end = theta
|
||||
|
||||
if spacing == "exp":
|
||||
indices = theta ** (torch.arange(0, dim, n_elem, device="cpu", dtype=torch.float32) / (dim - n_elem))
|
||||
indices = indices.to(dtype=dtype, device=device)
|
||||
elif spacing == "exp_2":
|
||||
indices = 1.0 / theta ** (torch.arange(0, dim, n_elem, device=device) / dim)
|
||||
indices = indices.to(dtype=dtype)
|
||||
elif spacing == "linear":
|
||||
indices = torch.linspace(
|
||||
start, end, dim // n_elem, device=device, dtype=dtype
|
||||
)
|
||||
elif spacing == "sqrt":
|
||||
indices = torch.linspace(
|
||||
start**2, end**2, dim // n_elem, device=device, dtype=dtype
|
||||
).sqrt()
|
||||
|
||||
indices = indices * math.pi / 2
|
||||
|
||||
return indices
|
||||
|
||||
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
|
||||
dim = self.inner_dim
|
||||
n_elem = 2 # 2 because of cos and sin
|
||||
freqs = self.precompute_freqs(indices_grid, spacing)
|
||||
if self.split_rope:
|
||||
expected_freqs = dim // 2
|
||||
current_freqs = freqs.shape[-1]
|
||||
pad_size = expected_freqs - current_freqs
|
||||
cos_freq, sin_freq = split_freqs_cis(
|
||||
freqs, pad_size, self.num_attention_heads
|
||||
)
|
||||
else:
|
||||
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
The [`Transformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
||||
Input `hidden_states`.
|
||||
indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`):
|
||||
attention_mask ( `torch.Tensor`, *optional*):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 1. Input
|
||||
|
||||
if self.num_learnable_registers:
|
||||
num_registers_duplications = math.ceil(
|
||||
max(1024, hidden_states.shape[1]) / self.num_learnable_registers
|
||||
)
|
||||
learnable_registers = torch.tile(
|
||||
self.learnable_registers.to(hidden_states), (num_registers_duplications, 1)
|
||||
)
|
||||
|
||||
hidden_states = torch.cat((hidden_states, learnable_registers[hidden_states.shape[1]:].unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)), dim=1)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = torch.zeros([1, 1, 1, hidden_states.shape[1]], dtype=attention_mask.dtype, device=attention_mask.device)
|
||||
|
||||
indices_grid = torch.arange(
|
||||
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
indices_grid = indices_grid[None, None, :]
|
||||
freqs_cis = self.precompute_freqs_cis(indices_grid)
|
||||
|
||||
# 2. Blocks
|
||||
for block_idx, block in enumerate(self.transformer_1d_blocks):
|
||||
hidden_states = block(
|
||||
hidden_states, attention_mask=attention_mask, pe=freqs_cis
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
# if self.output_scale is not None:
|
||||
# hidden_states = hidden_states / self.output_scale
|
||||
|
||||
hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
|
||||
|
||||
return hidden_states, attention_mask
|
||||
292
comfy/ldm/lightricks/latent_upsampler.py
Normal file
292
comfy/ldm/lightricks/latent_upsampler.py
Normal file
@ -0,0 +1,292 @@
|
||||
from typing import Optional, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def _rational_for_scale(scale: float) -> Tuple[int, int]:
|
||||
mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)}
|
||||
if float(scale) not in mapping:
|
||||
raise ValueError(
|
||||
f"Unsupported spatial_scale {scale}. Choose from {list(mapping.keys())}"
|
||||
)
|
||||
return mapping[float(scale)]
|
||||
|
||||
|
||||
class PixelShuffleND(nn.Module):
|
||||
def __init__(self, dims, upscale_factors=(2, 2, 2)):
|
||||
super().__init__()
|
||||
assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
|
||||
self.dims = dims
|
||||
self.upscale_factors = upscale_factors
|
||||
|
||||
def forward(self, x):
|
||||
if self.dims == 3:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||
p1=self.upscale_factors[0],
|
||||
p2=self.upscale_factors[1],
|
||||
p3=self.upscale_factors[2],
|
||||
)
|
||||
elif self.dims == 2:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (c p1 p2) h w -> b c (h p1) (w p2)",
|
||||
p1=self.upscale_factors[0],
|
||||
p2=self.upscale_factors[1],
|
||||
)
|
||||
elif self.dims == 1:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (c p1) f h w -> b c (f p1) h w",
|
||||
p1=self.upscale_factors[0],
|
||||
)
|
||||
|
||||
|
||||
class BlurDownsample(nn.Module):
|
||||
"""
|
||||
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.
|
||||
Applies only on H,W. Works for dims=2 or dims=3 (per-frame).
|
||||
"""
|
||||
|
||||
def __init__(self, dims: int, stride: int):
|
||||
super().__init__()
|
||||
assert dims in (2, 3)
|
||||
assert stride >= 1 and isinstance(stride, int)
|
||||
self.dims = dims
|
||||
self.stride = stride
|
||||
|
||||
# 5x5 separable binomial kernel [1,4,6,4,1] (outer product), normalized
|
||||
k = torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0])
|
||||
k2d = k[:, None] @ k[None, :]
|
||||
k2d = (k2d / k2d.sum()).float() # shape (5,5)
|
||||
self.register_buffer("kernel", k2d[None, None, :, :]) # (1,1,5,5)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.stride == 1:
|
||||
return x
|
||||
|
||||
def _apply_2d(x2d: torch.Tensor) -> torch.Tensor:
|
||||
# x2d: (B, C, H, W)
|
||||
B, C, H, W = x2d.shape
|
||||
weight = self.kernel.expand(C, 1, 5, 5) # depthwise
|
||||
x2d = F.conv2d(
|
||||
x2d, weight=weight, bias=None, stride=self.stride, padding=2, groups=C
|
||||
)
|
||||
return x2d
|
||||
|
||||
if self.dims == 2:
|
||||
return _apply_2d(x)
|
||||
else:
|
||||
# dims == 3: apply per-frame on H,W
|
||||
b, c, f, h, w = x.shape
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = _apply_2d(x)
|
||||
h2, w2 = x.shape[-2:]
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2)
|
||||
return x
|
||||
|
||||
|
||||
class SpatialRationalResampler(nn.Module):
|
||||
"""
|
||||
Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
|
||||
downsample by 'den' using fixed blur + stride. Operates on H,W only.
|
||||
|
||||
For dims==3, work per-frame for spatial scaling (temporal axis untouched).
|
||||
"""
|
||||
|
||||
def __init__(self, mid_channels: int, scale: float):
|
||||
super().__init__()
|
||||
self.scale = float(scale)
|
||||
self.num, self.den = _rational_for_scale(self.scale)
|
||||
self.conv = nn.Conv2d(
|
||||
mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1
|
||||
)
|
||||
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
|
||||
self.blur_down = BlurDownsample(dims=2, stride=self.den)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
b, c, f, h, w = x.shape
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = self.conv(x)
|
||||
x = self.pixel_shuffle(x)
|
||||
x = self.blur_down(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||
return x
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(
|
||||
self, channels: int, mid_channels: Optional[int] = None, dims: int = 3
|
||||
):
|
||||
super().__init__()
|
||||
if mid_channels is None:
|
||||
mid_channels = channels
|
||||
|
||||
Conv = nn.Conv2d if dims == 2 else nn.Conv3d
|
||||
|
||||
self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.norm1 = nn.GroupNorm(32, mid_channels)
|
||||
self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
|
||||
self.norm2 = nn.GroupNorm(32, channels)
|
||||
self.activation = nn.SiLU()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
residual = x
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.activation(x)
|
||||
x = self.conv2(x)
|
||||
x = self.norm2(x)
|
||||
x = self.activation(x + residual)
|
||||
return x
|
||||
|
||||
|
||||
class LatentUpsampler(nn.Module):
|
||||
"""
|
||||
Model to spatially upsample VAE latents.
|
||||
|
||||
Args:
|
||||
in_channels (`int`): Number of channels in the input latent
|
||||
mid_channels (`int`): Number of channels in the middle layers
|
||||
num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
|
||||
dims (`int`): Number of dimensions for convolutions (2 or 3)
|
||||
spatial_upsample (`bool`): Whether to spatially upsample the latent
|
||||
temporal_upsample (`bool`): Whether to temporally upsample the latent
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
mid_channels: int = 512,
|
||||
num_blocks_per_stage: int = 4,
|
||||
dims: int = 3,
|
||||
spatial_upsample: bool = True,
|
||||
temporal_upsample: bool = False,
|
||||
spatial_scale: float = 2.0,
|
||||
rational_resampler: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.mid_channels = mid_channels
|
||||
self.num_blocks_per_stage = num_blocks_per_stage
|
||||
self.dims = dims
|
||||
self.spatial_upsample = spatial_upsample
|
||||
self.temporal_upsample = temporal_upsample
|
||||
self.spatial_scale = float(spatial_scale)
|
||||
self.rational_resampler = rational_resampler
|
||||
|
||||
Conv = nn.Conv2d if dims == 2 else nn.Conv3d
|
||||
|
||||
self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.initial_norm = nn.GroupNorm(32, mid_channels)
|
||||
self.initial_activation = nn.SiLU()
|
||||
|
||||
self.res_blocks = nn.ModuleList(
|
||||
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||
)
|
||||
|
||||
if spatial_upsample and temporal_upsample:
|
||||
self.upsampler = nn.Sequential(
|
||||
nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(3),
|
||||
)
|
||||
elif spatial_upsample:
|
||||
if rational_resampler:
|
||||
self.upsampler = SpatialRationalResampler(
|
||||
mid_channels=mid_channels, scale=self.spatial_scale
|
||||
)
|
||||
else:
|
||||
self.upsampler = nn.Sequential(
|
||||
nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(2),
|
||||
)
|
||||
elif temporal_upsample:
|
||||
self.upsampler = nn.Sequential(
|
||||
nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(1),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either spatial_upsample or temporal_upsample must be True"
|
||||
)
|
||||
|
||||
self.post_upsample_res_blocks = nn.ModuleList(
|
||||
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||
)
|
||||
|
||||
self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, latent: torch.Tensor) -> torch.Tensor:
|
||||
b, c, f, h, w = latent.shape
|
||||
|
||||
if self.dims == 2:
|
||||
x = rearrange(latent, "b c f h w -> (b f) c h w")
|
||||
x = self.initial_conv(x)
|
||||
x = self.initial_norm(x)
|
||||
x = self.initial_activation(x)
|
||||
|
||||
for block in self.res_blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.upsampler(x)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||
else:
|
||||
x = self.initial_conv(latent)
|
||||
x = self.initial_norm(x)
|
||||
x = self.initial_activation(x)
|
||||
|
||||
for block in self.res_blocks:
|
||||
x = block(x)
|
||||
|
||||
if self.temporal_upsample:
|
||||
x = self.upsampler(x)
|
||||
x = x[:, :, 1:, :, :]
|
||||
else:
|
||||
if isinstance(self.upsampler, SpatialRationalResampler):
|
||||
x = self.upsampler(x)
|
||||
else:
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = self.upsampler(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
return cls(
|
||||
in_channels=config.get("in_channels", 4),
|
||||
mid_channels=config.get("mid_channels", 128),
|
||||
num_blocks_per_stage=config.get("num_blocks_per_stage", 4),
|
||||
dims=config.get("dims", 2),
|
||||
spatial_upsample=config.get("spatial_upsample", True),
|
||||
temporal_upsample=config.get("temporal_upsample", False),
|
||||
spatial_scale=config.get("spatial_scale", 2.0),
|
||||
rational_resampler=config.get("rational_resampler", False),
|
||||
)
|
||||
|
||||
def config(self):
|
||||
return {
|
||||
"_class_name": "LatentUpsampler",
|
||||
"in_channels": self.in_channels,
|
||||
"mid_channels": self.mid_channels,
|
||||
"num_blocks_per_stage": self.num_blocks_per_stage,
|
||||
"dims": self.dims,
|
||||
"spatial_upsample": self.spatial_upsample,
|
||||
"temporal_upsample": self.temporal_upsample,
|
||||
"spatial_scale": self.spatial_scale,
|
||||
"rational_resampler": self.rational_resampler,
|
||||
}
|
||||
@ -1,13 +1,47 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import comfy.ldm.modules.attention
|
||||
import comfy.ldm.common_dit
|
||||
from einops import rearrange
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
import functools
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.modules.attention
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||
|
||||
def _log_base(x, base):
|
||||
return np.log(x) / np.log(base)
|
||||
|
||||
class LTXRopeType(str, Enum):
|
||||
INTERLEAVED = "interleaved"
|
||||
SPLIT = "split"
|
||||
|
||||
KEY = "rope_type"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, kwargs, default=None):
|
||||
if default is None:
|
||||
default = cls.INTERLEAVED
|
||||
return cls(kwargs.get(cls.KEY, default))
|
||||
|
||||
|
||||
class LTXFrequenciesPrecision(str, Enum):
|
||||
FLOAT32 = "float32"
|
||||
FLOAT64 = "float64"
|
||||
|
||||
KEY = "frequencies_precision"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, kwargs, default=None):
|
||||
if default is None:
|
||||
default = cls.FLOAT32
|
||||
return cls(kwargs.get(cls.KEY, default))
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
@ -39,9 +73,7 @@ def get_timestep_embedding(
|
||||
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
||||
)
|
||||
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = torch.exp(exponent)
|
||||
@ -73,7 +105,9 @@ class TimestepEmbedding(nn.Module):
|
||||
post_act_fn: Optional[str] = None,
|
||||
cond_proj_dim=None,
|
||||
sample_proj_bias=True,
|
||||
dtype=None, device=None, operations=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -90,7 +124,9 @@ class TimestepEmbedding(nn.Module):
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)
|
||||
self.linear_2 = operations.Linear(
|
||||
time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
if post_act_fn is None:
|
||||
self.post_act = None
|
||||
@ -139,12 +175,22 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim,
|
||||
size_emb_dim,
|
||||
use_additional_conditions: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.outdim = size_emb_dim
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
@ -163,15 +209,22 @@ class AdaLayerNormSingle(nn.Module):
|
||||
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
|
||||
def __init__(
|
||||
self, embedding_dim: int, embedding_coefficient: int = 6, use_additional_conditions: bool = False, dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
|
||||
embedding_dim,
|
||||
size_emb_dim=embedding_dim // 3,
|
||||
use_additional_conditions=use_additional_conditions,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -185,6 +238,7 @@ class AdaLayerNormSingle(nn.Module):
|
||||
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
|
||||
return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
||||
|
||||
|
||||
class PixArtAlphaTextProjection(nn.Module):
|
||||
"""
|
||||
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
||||
@ -192,18 +246,24 @@ class PixArtAlphaTextProjection(nn.Module):
|
||||
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
|
||||
def __init__(
|
||||
self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
if out_features is None:
|
||||
out_features = hidden_size
|
||||
self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.linear_1 = operations.Linear(
|
||||
in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
if act_fn == "gelu_tanh":
|
||||
self.act_1 = nn.GELU(approximate="tanh")
|
||||
elif act_fn == "silu":
|
||||
self.act_1 = nn.SiLU()
|
||||
else:
|
||||
raise ValueError(f"Unknown activation function: {act_fn}")
|
||||
self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)
|
||||
self.linear_2 = operations.Linear(
|
||||
in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear_1(caption)
|
||||
@ -222,25 +282,28 @@ class GELU_approx(nn.Module):
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
|
||||
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
|
||||
project_in, nn.Dropout(dropout), operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
def apply_rotary_emb(input_tensor, freqs_cis):
|
||||
cos_freqs, sin_freqs = freqs_cis[0], freqs_cis[1]
|
||||
split_pe = freqs_cis[2] if len(freqs_cis) > 2 else False
|
||||
return (
|
||||
apply_split_rotary_emb(input_tensor, cos_freqs, sin_freqs)
|
||||
if split_pe else
|
||||
apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs)
|
||||
)
|
||||
|
||||
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
|
||||
cos_freqs = freqs_cis[0]
|
||||
sin_freqs = freqs_cis[1]
|
||||
|
||||
def apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs): # TODO: remove duplicate funcs and pick the best/fastest one
|
||||
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
|
||||
t1, t2 = t_dup.unbind(dim=-1)
|
||||
t_dup = torch.stack((-t2, t1), dim=-1)
|
||||
@ -250,9 +313,37 @@ def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and
|
||||
|
||||
return out
|
||||
|
||||
def apply_split_rotary_emb(input_tensor, cos, sin):
|
||||
needs_reshape = False
|
||||
if input_tensor.ndim != 4 and cos.ndim == 4:
|
||||
B, H, T, _ = cos.shape
|
||||
input_tensor = input_tensor.reshape(B, T, H, -1).swapaxes(1, 2)
|
||||
needs_reshape = True
|
||||
split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2)
|
||||
first_half_input = split_input[..., :1, :]
|
||||
second_half_input = split_input[..., 1:, :]
|
||||
output = split_input * cos.unsqueeze(-2)
|
||||
first_half_output = output[..., :1, :]
|
||||
second_half_output = output[..., 1:, :]
|
||||
first_half_output.addcmul_(-sin.unsqueeze(-2), second_half_input)
|
||||
second_half_output.addcmul_(sin.unsqueeze(-2), first_half_input)
|
||||
output = rearrange(output, "... d r -> ... (d r)")
|
||||
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
|
||||
def __init__(
|
||||
self,
|
||||
query_dim,
|
||||
context_dim=None,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
dropout=0.0,
|
||||
attn_precision=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = query_dim if context_dim is None else context_dim
|
||||
@ -268,9 +359,11 @@ class CrossAttention(nn.Module):
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
||||
self.to_out = nn.Sequential(
|
||||
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, mask=None, pe=None):
|
||||
def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}):
|
||||
q = self.to_q(x)
|
||||
context = x if context is None else context
|
||||
k = self.to_k(context)
|
||||
@ -281,149 +374,505 @@ class CrossAttention(nn.Module):
|
||||
|
||||
if pe is not None:
|
||||
q = apply_rotary_emb(q, pe)
|
||||
k = apply_rotary_emb(k, pe)
|
||||
k = apply_rotary_emb(k, pe if k_pe is None else k_pe)
|
||||
|
||||
if mask is None:
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
else:
|
||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
|
||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
|
||||
def __init__(
|
||||
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attn_precision = attn_precision
|
||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
context_dim=None,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
|
||||
self.attn2 = CrossAttention(
|
||||
query_dim=dim,
|
||||
context_dim=context_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
attn_precision=self.attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||
|
||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
|
||||
attn1_input = comfy.ldm.common_dit.rms_norm(x)
|
||||
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
|
||||
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
|
||||
x.addcmul_(attn1_input, gate_msa)
|
||||
del attn1_input
|
||||
|
||||
x += self.attn2(x, context=context, mask=attention_mask)
|
||||
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
||||
|
||||
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
|
||||
x += self.ff(y) * gate_mlp
|
||||
y = comfy.ldm.common_dit.rms_norm(x)
|
||||
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
|
||||
x.addcmul_(self.ff(y), gate_mlp)
|
||||
|
||||
return x
|
||||
|
||||
def get_fractional_positions(indices_grid, max_pos):
|
||||
n_pos_dims = indices_grid.shape[1]
|
||||
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
|
||||
fractional_positions = torch.stack(
|
||||
[
|
||||
indices_grid[:, i] / max_pos[i]
|
||||
for i in range(3)
|
||||
],
|
||||
dim=-1,
|
||||
[indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)],
|
||||
axis=-1,
|
||||
)
|
||||
return fractional_positions
|
||||
|
||||
|
||||
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
|
||||
dtype = torch.float32 #self.dtype
|
||||
|
||||
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||
|
||||
@functools.lru_cache(maxsize=5)
|
||||
def generate_freq_grid_np(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, _ = None):
|
||||
theta = positional_embedding_theta
|
||||
start = 1
|
||||
end = theta
|
||||
device = fractional_positions.device
|
||||
|
||||
n_elem = 2 * positional_embedding_max_pos_count
|
||||
pow_indices = np.power(
|
||||
theta,
|
||||
np.linspace(
|
||||
_log_base(start, theta),
|
||||
_log_base(end, theta),
|
||||
inner_dim // n_elem,
|
||||
dtype=np.float64,
|
||||
),
|
||||
)
|
||||
return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32)
|
||||
|
||||
def generate_freq_grid_pytorch(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, device):
|
||||
theta = positional_embedding_theta
|
||||
start = 1
|
||||
end = theta
|
||||
n_elem = 2 * positional_embedding_max_pos_count
|
||||
|
||||
indices = theta ** (
|
||||
torch.linspace(
|
||||
math.log(start, theta),
|
||||
math.log(end, theta),
|
||||
dim // 6,
|
||||
inner_dim // n_elem,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
)
|
||||
indices = indices.to(dtype=dtype)
|
||||
indices = indices.to(dtype=torch.float32)
|
||||
|
||||
indices = indices * math.pi / 2
|
||||
|
||||
return indices
|
||||
|
||||
def generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid):
|
||||
if use_middle_indices_grid:
|
||||
assert(len(indices_grid.shape) == 4 and indices_grid.shape[-1] ==2)
|
||||
indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1]
|
||||
indices_grid = (indices_grid_start + indices_grid_end) / 2.0
|
||||
elif len(indices_grid.shape) == 4:
|
||||
indices_grid = indices_grid[..., 0]
|
||||
|
||||
# Get fractional positions and compute frequency indices
|
||||
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||
indices = indices.to(device=fractional_positions.device)
|
||||
|
||||
freqs = (
|
||||
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||
.transpose(-1, -2)
|
||||
.flatten(2)
|
||||
)
|
||||
return freqs
|
||||
|
||||
def interleaved_freqs_cis(freqs, pad_size):
|
||||
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
if dim % 6 != 0:
|
||||
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
|
||||
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
|
||||
if pad_size != 0:
|
||||
cos_padding = torch.ones_like(cos_freq[:, :, : pad_size])
|
||||
sin_padding = torch.zeros_like(cos_freq[:, :, : pad_size])
|
||||
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
||||
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
||||
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
|
||||
return cos_freq, sin_freq
|
||||
|
||||
def split_freqs_cis(freqs, pad_size, num_attention_heads):
|
||||
cos_freq = freqs.cos()
|
||||
sin_freq = freqs.sin()
|
||||
|
||||
class LTXVModel(torch.nn.Module):
|
||||
def __init__(self,
|
||||
in_channels=128,
|
||||
cross_attention_dim=2048,
|
||||
attention_head_dim=64,
|
||||
num_attention_heads=32,
|
||||
if pad_size != 0:
|
||||
cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
|
||||
sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
|
||||
|
||||
caption_channels=4096,
|
||||
num_layers=28,
|
||||
cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)
|
||||
sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)
|
||||
|
||||
# Reshape freqs to be compatible with multi-head attention
|
||||
B , T, half_HD = cos_freq.shape
|
||||
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
causal_temporal_positioning=False,
|
||||
vae_scale_factors=(8, 32, 32),
|
||||
dtype=None, device=None, operations=None, **kwargs):
|
||||
cos_freq = cos_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
|
||||
sin_freq = sin_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
|
||||
|
||||
cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2)
|
||||
sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2)
|
||||
return cos_freq, sin_freq
|
||||
|
||||
class LTXBaseModel(torch.nn.Module, ABC):
|
||||
"""
|
||||
Abstract base class for LTX models (Lightricks Transformer models).
|
||||
|
||||
This class defines the common interface and shared functionality for all LTX models,
|
||||
including LTXV (video) and LTXAV (audio-video) variants.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
cross_attention_dim: int,
|
||||
attention_head_dim: int,
|
||||
num_attention_heads: int,
|
||||
caption_channels: int,
|
||||
num_layers: int,
|
||||
positional_embedding_theta: float = 10000.0,
|
||||
positional_embedding_max_pos: list = [20, 2048, 2048],
|
||||
causal_temporal_positioning: bool = False,
|
||||
vae_scale_factors: tuple = (8, 32, 32),
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier = 1000.0,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.generator = None
|
||||
self.vae_scale_factors = vae_scale_factors
|
||||
self.use_middle_indices_grid = use_middle_indices_grid
|
||||
self.dtype = dtype
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.in_channels = in_channels
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.caption_channels = caption_channels
|
||||
self.num_layers = num_layers
|
||||
self.positional_embedding_theta = positional_embedding_theta
|
||||
self.positional_embedding_max_pos = positional_embedding_max_pos
|
||||
self.split_positional_embedding = LTXRopeType.from_dict(kwargs)
|
||||
self.freq_grid_generator = (
|
||||
generate_freq_grid_np if LTXFrequenciesPrecision.from_dict(kwargs) == LTXFrequenciesPrecision.FLOAT64
|
||||
else generate_freq_grid_pytorch
|
||||
)
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
self.operations = operations
|
||||
self.timestep_scale_multiplier = timestep_scale_multiplier
|
||||
|
||||
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
# Common dimensions
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.out_channels = in_channels
|
||||
|
||||
# Initialize common components
|
||||
self._init_common_components(device, dtype)
|
||||
|
||||
# Initialize model-specific components
|
||||
self._init_model_components(device, dtype, **kwargs)
|
||||
|
||||
# Initialize transformer blocks
|
||||
self._init_transformer_blocks(device, dtype, **kwargs)
|
||||
|
||||
# Initialize output components
|
||||
self._init_output_components(device, dtype)
|
||||
|
||||
def _init_common_components(self, device, dtype):
|
||||
"""Initialize components common to all LTX models
|
||||
- patchify_proj: Linear projection for patchifying input
|
||||
- adaln_single: AdaLN layer for timestep embedding
|
||||
- caption_projection: Linear projection for caption embedding
|
||||
"""
|
||||
self.patchify_proj = self.operations.Linear(
|
||||
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
self.adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
|
||||
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||
)
|
||||
|
||||
# self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _init_model_components(self, device, dtype, **kwargs):
|
||||
"""Initialize model-specific components. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
"""Initialize transformer blocks. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _init_output_components(self, device, dtype):
|
||||
"""Initialize output components. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
|
||||
"""Process input data. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
|
||||
"""Process transformer blocks. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||
"""Process output data. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
|
||||
"""Prepare timestep embeddings."""
|
||||
grid_mask = kwargs.get("grid_mask", None)
|
||||
if grid_mask is not None:
|
||||
timestep = timestep[:, grid_mask]
|
||||
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
|
||||
|
||||
return timestep, embedded_timestep
|
||||
|
||||
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||
"""Prepare context for transformer blocks."""
|
||||
if self.caption_projection is not None:
|
||||
context = self.caption_projection(context)
|
||||
context = context.view(batch_size, -1, x.shape[-1])
|
||||
|
||||
return context, attention_mask
|
||||
|
||||
def _precompute_freqs_cis(
|
||||
self,
|
||||
indices_grid,
|
||||
dim,
|
||||
out_dtype,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=False,
|
||||
num_attention_heads=32,
|
||||
):
|
||||
split_mode = self.split_positional_embedding == LTXRopeType.SPLIT
|
||||
indices = self.freq_grid_generator(theta, indices_grid.shape[1], dim, indices_grid.device)
|
||||
freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)
|
||||
|
||||
if split_mode:
|
||||
expected_freqs = dim // 2
|
||||
current_freqs = freqs.shape[-1]
|
||||
pad_size = expected_freqs - current_freqs
|
||||
cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
|
||||
else:
|
||||
# 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only
|
||||
n_elem = 2 * indices_grid.shape[1]
|
||||
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||
return cos_freq.to(out_dtype), sin_freq.to(out_dtype), split_mode
|
||||
|
||||
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
|
||||
"""Prepare positional embeddings."""
|
||||
fractional_coords = pixel_coords.to(torch.float32)
|
||||
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
|
||||
pe = self._precompute_freqs_cis(
|
||||
fractional_coords,
|
||||
dim=self.inner_dim,
|
||||
out_dtype=x_dtype,
|
||||
max_pos=self.positional_embedding_max_pos,
|
||||
use_middle_indices_grid=self.use_middle_indices_grid,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
)
|
||||
return pe
|
||||
|
||||
def _prepare_attention_mask(self, attention_mask, x_dtype):
|
||||
"""Prepare attention mask."""
|
||||
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
||||
attention_mask = (attention_mask - 1).to(x_dtype).reshape(
|
||||
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||
) * torch.finfo(x_dtype).max
|
||||
return attention_mask
|
||||
|
||||
def forward(
|
||||
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Forward pass for LTX models.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
timestep: Timestep tensor
|
||||
context: Context tensor (e.g., text embeddings)
|
||||
attention_mask: Attention mask tensor
|
||||
frame_rate: Frame rate for temporal processing
|
||||
transformer_options: Additional options for transformer blocks
|
||||
keyframe_idxs: Keyframe indices for temporal processing
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Returns:
|
||||
Processed output tensor
|
||||
"""
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(
|
||||
comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options
|
||||
),
|
||||
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, denoise_mask=denoise_mask, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Internal forward pass for LTX models.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
timestep: Timestep tensor
|
||||
context: Context tensor (e.g., text embeddings)
|
||||
attention_mask: Attention mask tensor
|
||||
frame_rate: Frame rate for temporal processing
|
||||
transformer_options: Additional options for transformer blocks
|
||||
keyframe_idxs: Keyframe indices for temporal processing
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Returns:
|
||||
Processed output tensor
|
||||
"""
|
||||
if isinstance(x, list):
|
||||
input_dtype = x[0].dtype
|
||||
batch_size = x[0].shape[0]
|
||||
else:
|
||||
input_dtype = x.dtype
|
||||
batch_size = x.shape[0]
|
||||
# Process input
|
||||
merged_args = {**transformer_options, **kwargs}
|
||||
x, pixel_coords, additional_args = self._process_input(x, keyframe_idxs, denoise_mask, **merged_args)
|
||||
merged_args.update(additional_args)
|
||||
|
||||
# Prepare timestep and context
|
||||
timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
|
||||
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
|
||||
|
||||
# Prepare attention mask and positional embeddings
|
||||
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
|
||||
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
|
||||
|
||||
# Process transformer blocks
|
||||
x = self._process_transformer_blocks(
|
||||
x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
|
||||
)
|
||||
|
||||
# Process output
|
||||
x = self._process_output(x, embedded_timestep, keyframe_idxs, **merged_args)
|
||||
return x
|
||||
|
||||
|
||||
class LTXVModel(LTXBaseModel):
|
||||
"""LTXV model for video generation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=128,
|
||||
cross_attention_dim=2048,
|
||||
attention_head_dim=64,
|
||||
num_attention_heads=32,
|
||||
caption_channels=4096,
|
||||
num_layers=28,
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
causal_temporal_positioning=False,
|
||||
vae_scale_factors=(8, 32, 32),
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier = 1000.0,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
caption_channels=caption_channels,
|
||||
num_layers=num_layers,
|
||||
positional_embedding_theta=positional_embedding_theta,
|
||||
positional_embedding_max_pos=positional_embedding_max_pos,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
vae_scale_factors=vae_scale_factors,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _init_model_components(self, device, dtype, **kwargs):
|
||||
"""Initialize LTXV-specific components."""
|
||||
# No additional components needed for LTXV beyond base class
|
||||
pass
|
||||
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
"""Initialize transformer blocks for LTXV."""
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
self.inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
context_dim=cross_attention_dim,
|
||||
# attn_precision=attn_precision,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
self.num_attention_heads,
|
||||
self.attention_head_dim,
|
||||
context_dim=self.cross_attention_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def _init_output_components(self, device, dtype):
|
||||
"""Initialize output components for LTXV."""
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
|
||||
self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
|
||||
|
||||
self.patchifier = SymmetricPatchifier(1)
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
orig_shape = list(x.shape)
|
||||
self.norm_out = self.operations.LayerNorm(
|
||||
self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
|
||||
)
|
||||
self.proj_out = self.operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
|
||||
self.patchifier = SymmetricPatchifier(1, start_end=True)
|
||||
|
||||
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
|
||||
"""Process input for LTXV."""
|
||||
additional_args = {"orig_shape": list(x.shape)}
|
||||
x, latent_coords = self.patchifier.patchify(x)
|
||||
pixel_coords = latent_to_pixel_coords(
|
||||
latent_coords=latent_coords,
|
||||
@ -431,50 +880,36 @@ class LTXVModel(torch.nn.Module):
|
||||
causal_fix=self.causal_temporal_positioning,
|
||||
)
|
||||
|
||||
grid_mask = None
|
||||
if keyframe_idxs is not None:
|
||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
|
||||
additional_args.update({ "orig_patchified_shape": list(x.shape)})
|
||||
denoise_mask = self.patchifier.patchify(denoise_mask)[0]
|
||||
grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
|
||||
additional_args.update({"grid_mask": grid_mask})
|
||||
x = x[:, grid_mask, :]
|
||||
pixel_coords = pixel_coords[:, :, grid_mask, ...]
|
||||
|
||||
fractional_coords = pixel_coords.to(torch.float32)
|
||||
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
|
||||
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
|
||||
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
|
||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
||||
|
||||
x = self.patchify_proj(x)
|
||||
timestep = timestep * 1000.0
|
||||
|
||||
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
||||
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
|
||||
|
||||
pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
|
||||
|
||||
batch_size = x.shape[0]
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=x.dtype,
|
||||
)
|
||||
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||
embedded_timestep = embedded_timestep.view(
|
||||
batch_size, -1, embedded_timestep.shape[-1]
|
||||
)
|
||||
|
||||
# 2. Blocks
|
||||
if self.caption_projection is not None:
|
||||
batch_size = x.shape[0]
|
||||
context = self.caption_projection(context)
|
||||
context = context.view(
|
||||
batch_size, -1, x.shape[-1]
|
||||
)
|
||||
return x, pixel_coords, additional_args
|
||||
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
|
||||
"""Process transformer blocks for LTXV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(
|
||||
@ -482,19 +917,32 @@ class LTXVModel(torch.nn.Module):
|
||||
context=context,
|
||||
attention_mask=attention_mask,
|
||||
timestep=timestep,
|
||||
pe=pe
|
||||
pe=pe,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
return x
|
||||
|
||||
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||
"""Process output for LTXV."""
|
||||
# Apply scale-shift modulation
|
||||
scale_shift_values = (
|
||||
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
|
||||
)
|
||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||
|
||||
x = self.norm_out(x)
|
||||
# Modulation
|
||||
x = x * (1 + scale) + shift
|
||||
x = self.proj_out(x)
|
||||
|
||||
if keyframe_idxs is not None:
|
||||
grid_mask = kwargs["grid_mask"]
|
||||
orig_patchified_shape = kwargs["orig_patchified_shape"]
|
||||
full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)
|
||||
full_x[:, grid_mask, :] = x
|
||||
x = full_x
|
||||
# Unpatchify to restore original dimensions
|
||||
orig_shape = kwargs["orig_shape"]
|
||||
x = self.patchifier.unpatchify(
|
||||
latents=x,
|
||||
output_height=orig_shape[3],
|
||||
|
||||
@ -21,20 +21,23 @@ def latent_to_pixel_coords(
|
||||
Returns:
|
||||
Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
|
||||
"""
|
||||
shape = [1] * latent_coords.ndim
|
||||
shape[1] = -1
|
||||
pixel_coords = (
|
||||
latent_coords
|
||||
* torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
|
||||
* torch.tensor(scale_factors, device=latent_coords.device).view(*shape)
|
||||
)
|
||||
if causal_fix:
|
||||
# Fix temporal scale for first frame to 1 due to causality
|
||||
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
|
||||
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
|
||||
return pixel_coords
|
||||
|
||||
|
||||
class Patchifier(ABC):
|
||||
def __init__(self, patch_size: int):
|
||||
def __init__(self, patch_size: int, start_end: bool=False):
|
||||
super().__init__()
|
||||
self._patch_size = (1, patch_size, patch_size)
|
||||
self.start_end = start_end
|
||||
|
||||
@abstractmethod
|
||||
def patchify(
|
||||
@ -71,11 +74,23 @@ class Patchifier(ABC):
|
||||
torch.arange(0, latent_width, self._patch_size[2], device=device),
|
||||
indexing="ij",
|
||||
)
|
||||
latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
|
||||
latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
latent_coords = rearrange(
|
||||
latent_coords, "b c f h w -> b c (f h w)", b=batch_size
|
||||
latent_sample_coords_start = torch.stack(latent_sample_coords, dim=0)
|
||||
delta = torch.tensor(self._patch_size, device=latent_sample_coords_start.device, dtype=latent_sample_coords_start.dtype)[:, None, None, None]
|
||||
latent_sample_coords_end = latent_sample_coords_start + delta
|
||||
|
||||
latent_sample_coords_start = latent_sample_coords_start.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
latent_sample_coords_start = rearrange(
|
||||
latent_sample_coords_start, "b c f h w -> b c (f h w)", b=batch_size
|
||||
)
|
||||
if self.start_end:
|
||||
latent_sample_coords_end = latent_sample_coords_end.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
latent_sample_coords_end = rearrange(
|
||||
latent_sample_coords_end, "b c f h w -> b c (f h w)", b=batch_size
|
||||
)
|
||||
|
||||
latent_coords = torch.stack((latent_sample_coords_start, latent_sample_coords_end), dim=-1)
|
||||
else:
|
||||
latent_coords = latent_sample_coords_start
|
||||
return latent_coords
|
||||
|
||||
|
||||
@ -115,3 +130,61 @@ class SymmetricPatchifier(Patchifier):
|
||||
q=self._patch_size[2],
|
||||
)
|
||||
return latents
|
||||
|
||||
|
||||
class AudioPatchifier(Patchifier):
|
||||
def __init__(self, patch_size: int,
|
||||
sample_rate=16000,
|
||||
hop_length=160,
|
||||
audio_latent_downsample_factor=4,
|
||||
is_causal=True,
|
||||
start_end=False,
|
||||
shift = 0
|
||||
):
|
||||
super().__init__(patch_size, start_end=start_end)
|
||||
self.hop_length = hop_length
|
||||
self.sample_rate = sample_rate
|
||||
self.audio_latent_downsample_factor = audio_latent_downsample_factor
|
||||
self.is_causal = is_causal
|
||||
self.shift = shift
|
||||
|
||||
def copy_with_shift(self, shift):
|
||||
return AudioPatchifier(
|
||||
self.patch_size, self.sample_rate, self.hop_length, self.audio_latent_downsample_factor,
|
||||
self.is_causal, self.start_end, shift
|
||||
)
|
||||
|
||||
def _get_audio_latent_time_in_sec(self, start_latent, end_latent: int, dtype: torch.dtype, device=torch.device):
|
||||
audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)
|
||||
audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor
|
||||
if self.is_causal:
|
||||
audio_mel_frame = (audio_mel_frame + 1 - self.audio_latent_downsample_factor).clip(min=0)
|
||||
return audio_mel_frame * self.hop_length / self.sample_rate
|
||||
|
||||
|
||||
def patchify(self, audio_latents: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# audio_latents: (batch, channels, time, freq)
|
||||
b, _, t, _ = audio_latents.shape
|
||||
audio_latents = rearrange(
|
||||
audio_latents,
|
||||
"b c t f -> b t (c f)",
|
||||
)
|
||||
|
||||
audio_latents_start_timings = self._get_audio_latent_time_in_sec(self.shift, t + self.shift, torch.float32, audio_latents.device)
|
||||
audio_latents_start_timings = audio_latents_start_timings.unsqueeze(0).expand(b, -1).unsqueeze(1)
|
||||
|
||||
if self.start_end:
|
||||
audio_latents_end_timings = self._get_audio_latent_time_in_sec(self.shift + 1, t + self.shift + 1, torch.float32, audio_latents.device)
|
||||
audio_latents_end_timings = audio_latents_end_timings.unsqueeze(0).expand(b, -1).unsqueeze(1)
|
||||
|
||||
audio_latents_timings = torch.stack([audio_latents_start_timings, audio_latents_end_timings], dim=-1)
|
||||
else:
|
||||
audio_latents_timings = audio_latents_start_timings
|
||||
return audio_latents, audio_latents_timings
|
||||
|
||||
def unpatchify(self, audio_latents: torch.Tensor, channels: int, freq: int) -> torch.Tensor:
|
||||
# audio_latents: (batch, time, freq * channels)
|
||||
audio_latents = rearrange(
|
||||
audio_latents, "b t (c f) -> b c t f", c=channels, f=freq
|
||||
)
|
||||
return audio_latents
|
||||
|
||||
286
comfy/ldm/lightricks/vae/audio_vae.py
Normal file
286
comfy/ldm/lightricks/vae/audio_vae.py
Normal file
@ -0,0 +1,286 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.model_patcher
|
||||
import comfy.utils as utils
|
||||
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
|
||||
CausalityAxis,
|
||||
CausalAudioAutoencoder,
|
||||
)
|
||||
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
|
||||
|
||||
LATENT_DOWNSAMPLE_FACTOR = 4
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AudioVAEComponentConfig:
|
||||
"""Container for model component configuration extracted from metadata."""
|
||||
|
||||
autoencoder: dict
|
||||
vocoder: dict
|
||||
|
||||
@classmethod
|
||||
def from_metadata(cls, metadata: dict) -> "AudioVAEComponentConfig":
|
||||
assert metadata is not None and "config" in metadata, "Metadata is required for audio VAE"
|
||||
|
||||
raw_config = metadata["config"]
|
||||
if isinstance(raw_config, str):
|
||||
parsed_config = json.loads(raw_config)
|
||||
else:
|
||||
parsed_config = raw_config
|
||||
|
||||
audio_config = parsed_config.get("audio_vae")
|
||||
vocoder_config = parsed_config.get("vocoder")
|
||||
|
||||
assert audio_config is not None, "Audio VAE config is required for audio VAE"
|
||||
assert vocoder_config is not None, "Vocoder config is required for audio VAE"
|
||||
|
||||
return cls(autoencoder=audio_config, vocoder=vocoder_config)
|
||||
|
||||
|
||||
class ModelDeviceManager:
|
||||
"""Manages device placement and GPU residency for the composed model."""
|
||||
|
||||
def __init__(self, module: torch.nn.Module):
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
offload_device = comfy.model_management.vae_offload_device()
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)
|
||||
|
||||
def ensure_model_loaded(self) -> None:
|
||||
comfy.model_management.free_memory(
|
||||
self.patcher.model_size(),
|
||||
self.patcher.load_device,
|
||||
)
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
|
||||
def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(self.patcher.load_device)
|
||||
|
||||
@property
|
||||
def load_device(self):
|
||||
return self.patcher.load_device
|
||||
|
||||
|
||||
class AudioLatentNormalizer:
|
||||
"""Applies per-channel statistics in patch space and restores original layout."""
|
||||
|
||||
def __init__(self, patchfier: AudioPatchifier, statistics_processor: torch.nn.Module):
|
||||
self.patchifier = patchfier
|
||||
self.statistics = statistics_processor
|
||||
|
||||
def normalize(self, latents: torch.Tensor) -> torch.Tensor:
|
||||
channels = latents.shape[1]
|
||||
freq = latents.shape[3]
|
||||
patched, _ = self.patchifier.patchify(latents)
|
||||
normalized = self.statistics.normalize(patched)
|
||||
return self.patchifier.unpatchify(normalized, channels=channels, freq=freq)
|
||||
|
||||
def denormalize(self, latents: torch.Tensor) -> torch.Tensor:
|
||||
channels = latents.shape[1]
|
||||
freq = latents.shape[3]
|
||||
patched, _ = self.patchifier.patchify(latents)
|
||||
denormalized = self.statistics.un_normalize(patched)
|
||||
return self.patchifier.unpatchify(denormalized, channels=channels, freq=freq)
|
||||
|
||||
|
||||
class AudioPreprocessor:
|
||||
"""Prepares raw waveforms for the autoencoder by matching training conditions."""
|
||||
|
||||
def __init__(self, target_sample_rate: int, mel_bins: int, mel_hop_length: int, n_fft: int):
|
||||
self.target_sample_rate = target_sample_rate
|
||||
self.mel_bins = mel_bins
|
||||
self.mel_hop_length = mel_hop_length
|
||||
self.n_fft = n_fft
|
||||
|
||||
def resample(self, waveform: torch.Tensor, source_rate: int) -> torch.Tensor:
|
||||
if source_rate == self.target_sample_rate:
|
||||
return waveform
|
||||
return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate)
|
||||
|
||||
@staticmethod
|
||||
def normalize_amplitude(
|
||||
waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5
|
||||
) -> torch.Tensor:
|
||||
waveform = waveform - waveform.mean(dim=2, keepdim=True)
|
||||
peak = torch.max(torch.abs(waveform)) + eps
|
||||
scale = peak.clamp(max=max_amplitude) / peak
|
||||
return waveform * scale
|
||||
|
||||
def waveform_to_mel(
|
||||
self, waveform: torch.Tensor, waveform_sample_rate: int, device
|
||||
) -> torch.Tensor:
|
||||
waveform = self.resample(waveform, waveform_sample_rate)
|
||||
waveform = self.normalize_amplitude(waveform)
|
||||
|
||||
mel_transform = torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=self.target_sample_rate,
|
||||
n_fft=self.n_fft,
|
||||
win_length=self.n_fft,
|
||||
hop_length=self.mel_hop_length,
|
||||
f_min=0.0,
|
||||
f_max=self.target_sample_rate / 2.0,
|
||||
n_mels=self.mel_bins,
|
||||
window_fn=torch.hann_window,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
power=1.0,
|
||||
mel_scale="slaney",
|
||||
norm="slaney",
|
||||
).to(device)
|
||||
|
||||
mel = mel_transform(waveform)
|
||||
mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||
return mel.permute(0, 1, 3, 2).contiguous()
|
||||
|
||||
|
||||
class AudioVAE(torch.nn.Module):
|
||||
"""High-level Audio VAE wrapper exposing encode and decode entry points."""
|
||||
|
||||
def __init__(self, state_dict: dict, metadata: dict):
|
||||
super().__init__()
|
||||
|
||||
component_config = AudioVAEComponentConfig.from_metadata(metadata)
|
||||
|
||||
vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True)
|
||||
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
|
||||
|
||||
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
|
||||
self.vocoder = Vocoder(config=component_config.vocoder)
|
||||
|
||||
self.autoencoder.load_state_dict(vae_sd, strict=False)
|
||||
self.vocoder.load_state_dict(vocoder_sd, strict=False)
|
||||
|
||||
autoencoder_config = self.autoencoder.get_config()
|
||||
self.normalizer = AudioLatentNormalizer(
|
||||
AudioPatchifier(
|
||||
patch_size=1,
|
||||
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
||||
sample_rate=autoencoder_config["sampling_rate"],
|
||||
hop_length=autoencoder_config["mel_hop_length"],
|
||||
is_causal=autoencoder_config["is_causal"],
|
||||
),
|
||||
self.autoencoder.per_channel_statistics,
|
||||
)
|
||||
|
||||
self.preprocessor = AudioPreprocessor(
|
||||
target_sample_rate=autoencoder_config["sampling_rate"],
|
||||
mel_bins=autoencoder_config["mel_bins"],
|
||||
mel_hop_length=autoencoder_config["mel_hop_length"],
|
||||
n_fft=autoencoder_config["n_fft"],
|
||||
)
|
||||
|
||||
self.device_manager = ModelDeviceManager(self)
|
||||
|
||||
def encode(self, audio: dict) -> torch.Tensor:
|
||||
"""Encode a waveform dictionary into normalized latent tensors."""
|
||||
|
||||
waveform = audio["waveform"]
|
||||
waveform_sample_rate = audio["sample_rate"]
|
||||
input_device = waveform.device
|
||||
# Ensure that Audio VAE is loaded on the correct device.
|
||||
self.device_manager.ensure_model_loaded()
|
||||
|
||||
waveform = self.device_manager.move_to_load_device(waveform)
|
||||
expected_channels = self.autoencoder.encoder.in_channels
|
||||
if waveform.shape[1] != expected_channels:
|
||||
raise ValueError(
|
||||
f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
|
||||
)
|
||||
|
||||
mel_spec = self.preprocessor.waveform_to_mel(
|
||||
waveform, waveform_sample_rate, device=self.device_manager.load_device
|
||||
)
|
||||
|
||||
latents = self.autoencoder.encode(mel_spec)
|
||||
posterior = DiagonalGaussianDistribution(latents)
|
||||
latent_mode = posterior.mode()
|
||||
|
||||
normalized = self.normalizer.normalize(latent_mode)
|
||||
return normalized.to(input_device)
|
||||
|
||||
def decode(self, latents: torch.Tensor) -> torch.Tensor:
|
||||
"""Decode normalized latent tensors into an audio waveform."""
|
||||
original_shape = latents.shape
|
||||
|
||||
# Ensure that Audio VAE is loaded on the correct device.
|
||||
self.device_manager.ensure_model_loaded()
|
||||
|
||||
latents = self.device_manager.move_to_load_device(latents)
|
||||
latents = self.normalizer.denormalize(latents)
|
||||
|
||||
target_shape = self.target_shape_from_latents(original_shape)
|
||||
mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)
|
||||
|
||||
waveform = self.run_vocoder(mel_spec)
|
||||
return self.device_manager.move_to_load_device(waveform)
|
||||
|
||||
def target_shape_from_latents(self, latents_shape):
|
||||
batch, _, time, _ = latents_shape
|
||||
target_length = time * LATENT_DOWNSAMPLE_FACTOR
|
||||
if self.autoencoder.causality_axis != CausalityAxis.NONE:
|
||||
target_length -= LATENT_DOWNSAMPLE_FACTOR - 1
|
||||
return (
|
||||
batch,
|
||||
self.autoencoder.decoder.out_ch,
|
||||
target_length,
|
||||
self.autoencoder.mel_bins,
|
||||
)
|
||||
|
||||
def num_of_latents_from_frames(self, frames_number: int, frame_rate: int) -> int:
|
||||
return math.ceil((float(frames_number) / frame_rate) * self.latents_per_second)
|
||||
|
||||
def run_vocoder(self, mel_spec: torch.Tensor) -> torch.Tensor:
|
||||
audio_channels = self.autoencoder.decoder.out_ch
|
||||
vocoder_input = mel_spec.transpose(2, 3)
|
||||
|
||||
if audio_channels == 1:
|
||||
vocoder_input = vocoder_input.squeeze(1)
|
||||
elif audio_channels != 2:
|
||||
raise ValueError(f"Unsupported audio_channels: {audio_channels}")
|
||||
|
||||
return self.vocoder(vocoder_input)
|
||||
|
||||
@property
|
||||
def sample_rate(self) -> int:
|
||||
return int(self.autoencoder.sampling_rate)
|
||||
|
||||
@property
|
||||
def mel_hop_length(self) -> int:
|
||||
return int(self.autoencoder.mel_hop_length)
|
||||
|
||||
@property
|
||||
def mel_bins(self) -> int:
|
||||
return int(self.autoencoder.mel_bins)
|
||||
|
||||
@property
|
||||
def latent_channels(self) -> int:
|
||||
return int(self.autoencoder.decoder.z_channels)
|
||||
|
||||
@property
|
||||
def latent_frequency_bins(self) -> int:
|
||||
return int(self.mel_bins // LATENT_DOWNSAMPLE_FACTOR)
|
||||
|
||||
@property
|
||||
def latents_per_second(self) -> float:
|
||||
return self.sample_rate / self.mel_hop_length / LATENT_DOWNSAMPLE_FACTOR
|
||||
|
||||
@property
|
||||
def output_sample_rate(self) -> int:
|
||||
output_rate = getattr(self.vocoder, "output_sample_rate", None)
|
||||
if output_rate is not None:
|
||||
return int(output_rate)
|
||||
upsample_factor = getattr(self.vocoder, "upsample_factor", None)
|
||||
if upsample_factor is None:
|
||||
raise AttributeError(
|
||||
"Vocoder is missing upsample_factor; cannot infer output sample rate"
|
||||
)
|
||||
return int(self.sample_rate * upsample_factor / self.mel_hop_length)
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
return self.device_manager.patcher.model_size()
|
||||
909
comfy/ldm/lightricks/vae/causal_audio_autoencoder.py
Normal file
909
comfy/ldm/lightricks/vae/causal_audio_autoencoder.py
Normal file
@ -0,0 +1,909 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
from .pixel_norm import PixelNorm
|
||||
import comfy.ops
|
||||
import logging
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class StringConvertibleEnum(Enum):
|
||||
"""
|
||||
Base enum class that provides string-to-enum conversion functionality.
|
||||
|
||||
This mixin adds a str_to_enum() class method that handles conversion from
|
||||
strings, None, or existing enum instances with case-insensitive matching.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def str_to_enum(cls, value):
|
||||
"""
|
||||
Convert a string, enum instance, or None to the appropriate enum member.
|
||||
|
||||
Args:
|
||||
value: Can be an enum instance of this class, a string, or None
|
||||
|
||||
Returns:
|
||||
Enum member of this class
|
||||
|
||||
Raises:
|
||||
ValueError: If the value cannot be converted to a valid enum member
|
||||
"""
|
||||
# Already an enum instance of this class
|
||||
if isinstance(value, cls):
|
||||
return value
|
||||
|
||||
# None maps to NONE member if it exists
|
||||
if value is None:
|
||||
if hasattr(cls, "NONE"):
|
||||
return cls.NONE
|
||||
raise ValueError(f"{cls.__name__} does not have a NONE member to map None to")
|
||||
|
||||
# String conversion (case-insensitive)
|
||||
if isinstance(value, str):
|
||||
value_lower = value.lower()
|
||||
|
||||
# Try to match against enum values
|
||||
for member in cls:
|
||||
# Handle members with None values
|
||||
if member.value is None:
|
||||
if value_lower == "none":
|
||||
return member
|
||||
# Handle members with string values
|
||||
elif isinstance(member.value, str) and member.value.lower() == value_lower:
|
||||
return member
|
||||
|
||||
# Build helpful error message with valid values
|
||||
valid_values = []
|
||||
for member in cls:
|
||||
if member.value is None:
|
||||
valid_values.append("none")
|
||||
elif isinstance(member.value, str):
|
||||
valid_values.append(member.value)
|
||||
|
||||
raise ValueError(f"Invalid {cls.__name__} string: '{value}'. " f"Valid values are: {valid_values}")
|
||||
|
||||
raise ValueError(
|
||||
f"Cannot convert type {type(value).__name__} to {cls.__name__} enum. "
|
||||
f"Expected string, None, or {cls.__name__} instance."
|
||||
)
|
||||
|
||||
|
||||
class AttentionType(StringConvertibleEnum):
|
||||
"""Enum for specifying the attention mechanism type."""
|
||||
|
||||
VANILLA = "vanilla"
|
||||
LINEAR = "linear"
|
||||
NONE = "none"
|
||||
|
||||
|
||||
class CausalityAxis(StringConvertibleEnum):
|
||||
"""Enum for specifying the causality axis in causal convolutions."""
|
||||
|
||||
NONE = None
|
||||
WIDTH = "width"
|
||||
HEIGHT = "height"
|
||||
WIDTH_COMPATIBILITY = "width-compatibility"
|
||||
|
||||
|
||||
def Normalize(in_channels, *, num_groups=32, normtype="group"):
|
||||
if normtype == "group":
|
||||
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
elif normtype == "pixel":
|
||||
return PixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {normtype}")
|
||||
|
||||
|
||||
class CausalConv2d(nn.Module):
|
||||
"""
|
||||
A causal 2D convolution.
|
||||
|
||||
This layer ensures that the output at time `t` only depends on inputs
|
||||
at time `t` and earlier. It achieves this by applying asymmetric padding
|
||||
to the time dimension (width) before the convolution.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
# Ensure kernel_size and dilation are tuples
|
||||
kernel_size = nn.modules.utils._pair(kernel_size)
|
||||
dilation = nn.modules.utils._pair(dilation)
|
||||
|
||||
# Calculate padding dimensions
|
||||
pad_h = (kernel_size[0] - 1) * dilation[0]
|
||||
pad_w = (kernel_size[1] - 1) * dilation[1]
|
||||
|
||||
# The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom)
|
||||
match self.causality_axis:
|
||||
case CausalityAxis.NONE:
|
||||
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
|
||||
case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY:
|
||||
self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
|
||||
case CausalityAxis.HEIGHT:
|
||||
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
|
||||
case _:
|
||||
raise ValueError(f"Invalid causality_axis: {causality_axis}")
|
||||
|
||||
# The internal convolution layer uses no padding, as we handle it manually
|
||||
self.conv = ops.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# Apply causal padding before convolution
|
||||
x = F.pad(x, self.padding)
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
def make_conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=None,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
causality_axis: Optional[CausalityAxis] = None,
|
||||
):
|
||||
"""
|
||||
Create a 2D convolution layer that can be either causal or non-causal.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels
|
||||
out_channels: Number of output channels
|
||||
kernel_size: Size of the convolution kernel
|
||||
stride: Convolution stride
|
||||
padding: Padding (if None, will be calculated based on causal flag)
|
||||
dilation: Dilation rate
|
||||
groups: Number of groups for grouped convolution
|
||||
bias: Whether to use bias
|
||||
causality_axis: Dimension along which to apply causality.
|
||||
|
||||
Returns:
|
||||
Either a regular Conv2d or CausalConv2d layer
|
||||
"""
|
||||
if causality_axis is not None:
|
||||
# For causal convolution, padding is handled internally by CausalConv2d
|
||||
return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis)
|
||||
else:
|
||||
# For non-causal convolution, use symmetric padding if not specified
|
||||
if padding is None:
|
||||
if isinstance(kernel_size, int):
|
||||
padding = kernel_size // 2
|
||||
else:
|
||||
padding = tuple(k // 2 for k in kernel_size)
|
||||
return ops.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
bias,
|
||||
)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.HEIGHT):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.causality_axis = causality_axis
|
||||
if self.with_conv:
|
||||
self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
# Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n.
|
||||
# For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2].
|
||||
# The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2],
|
||||
# So the output elements rely on the following windows:
|
||||
# 0: [-,-,0]
|
||||
# 1: [-,0,0]
|
||||
# 2: [0,0,1]
|
||||
# 3: [0,1,1]
|
||||
# 4: [1,1,2]
|
||||
# 5: [1,2,2]
|
||||
# Notice that the first and second elements in the output rely only on the first element in the input,
|
||||
# while all other elements rely on two elements in the input.
|
||||
# So we can drop the first element to undo the padding (rather than the last element).
|
||||
# This is a no-op for non-causal convolutions.
|
||||
match self.causality_axis:
|
||||
case CausalityAxis.NONE:
|
||||
pass # x remains unchanged
|
||||
case CausalityAxis.HEIGHT:
|
||||
x = x[:, :, 1:, :]
|
||||
case CausalityAxis.WIDTH:
|
||||
x = x[:, :, :, 1:]
|
||||
case CausalityAxis.WIDTH_COMPATIBILITY:
|
||||
pass # x remains unchanged
|
||||
case _:
|
||||
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
"""
|
||||
A downsampling layer that can use either a strided convolution
|
||||
or average pooling. Supports standard and causal padding for the
|
||||
convolutional mode.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.WIDTH):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
if self.causality_axis != CausalityAxis.NONE and not self.with_conv:
|
||||
raise ValueError("causality is only supported when `with_conv=True`.")
|
||||
|
||||
if self.with_conv:
|
||||
# Do time downsampling here
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
# (pad_left, pad_right, pad_top, pad_bottom)
|
||||
match self.causality_axis:
|
||||
case CausalityAxis.NONE:
|
||||
pad = (0, 1, 0, 1)
|
||||
case CausalityAxis.WIDTH:
|
||||
pad = (2, 0, 0, 1)
|
||||
case CausalityAxis.HEIGHT:
|
||||
pad = (0, 1, 2, 0)
|
||||
case CausalityAxis.WIDTH_COMPATIBILITY:
|
||||
pad = (1, 0, 0, 1)
|
||||
case _:
|
||||
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
# This branch is only taken if with_conv=False, which implies causality_axis is NONE.
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout,
|
||||
temb_channels=512,
|
||||
norm_type="group",
|
||||
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||
):
|
||||
super().__init__()
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
if self.causality_axis != CausalityAxis.NONE and norm_type == "group":
|
||||
raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize(in_channels, normtype=norm_type)
|
||||
self.non_linearity = nn.SiLU()
|
||||
self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = ops.Linear(temb_channels, out_channels)
|
||||
self.norm2 = Normalize(out_channels, normtype=norm_type)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = self.non_linearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = self.non_linearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels, norm_type="group"):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels, normtype=norm_type)
|
||||
self.q = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h * w).contiguous()
|
||||
q = q.permute(0, 2, 1).contiguous() # b,hw,c
|
||||
k = k.reshape(b, c, h * w).contiguous() # b,c,hw
|
||||
w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w_ = w_ * (int(c) ** (-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h * w).contiguous()
|
||||
w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
|
||||
h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b, c, h, w).contiguous()
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
def make_attn(in_channels, attn_type="vanilla", norm_type="group"):
|
||||
# Convert string to enum if needed
|
||||
attn_type = AttentionType.str_to_enum(attn_type)
|
||||
|
||||
if attn_type != AttentionType.NONE:
|
||||
logging.info(f"making attention of type '{attn_type.value}' with {in_channels} in_channels")
|
||||
else:
|
||||
logging.info(f"making identity attention with {in_channels} in_channels")
|
||||
|
||||
match attn_type:
|
||||
case AttentionType.VANILLA:
|
||||
return AttnBlock(in_channels, norm_type=norm_type)
|
||||
case AttentionType.NONE:
|
||||
return nn.Identity(in_channels)
|
||||
case AttentionType.LINEAR:
|
||||
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
|
||||
case _:
|
||||
raise ValueError(f"Unknown attention type: {attn_type}")
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
double_z=True,
|
||||
attn_type="vanilla",
|
||||
mid_block_add_attention=True,
|
||||
norm_type="group",
|
||||
causality_axis=CausalityAxis.WIDTH.value,
|
||||
**ignore_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.z_channels = z_channels
|
||||
self.double_z = double_z
|
||||
self.norm_type = norm_type
|
||||
# Convert string to enum if needed (for config loading)
|
||||
causality_axis = CausalityAxis.str_to_enum(causality_axis)
|
||||
self.attn_type = AttentionType.str_to_enum(attn_type)
|
||||
|
||||
# downsampling
|
||||
self.conv_in = make_conv2d(
|
||||
in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
self.non_linearity = nn.SiLU()
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
|
||||
for _ in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type))
|
||||
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
if mid_block_add_attention:
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)
|
||||
else:
|
||||
self.mid.attn_1 = nn.Identity()
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in, normtype=self.norm_type)
|
||||
self.conv_out = make_conv2d(
|
||||
block_in,
|
||||
2 * z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the encoder.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape [batch, channels, time, n_mels]
|
||||
|
||||
Returns:
|
||||
Encoded latent representation
|
||||
"""
|
||||
feature_maps = [self.conv_in(x)]
|
||||
|
||||
# Process each resolution level (from high to low resolution)
|
||||
for resolution_level in range(self.num_resolutions):
|
||||
# Apply residual blocks at current resolution level
|
||||
for block_idx in range(self.num_res_blocks):
|
||||
# Apply ResNet block with optional timestep embedding
|
||||
current_features = self.down[resolution_level].block[block_idx](feature_maps[-1], temb=None)
|
||||
|
||||
# Apply attention if configured for this resolution level
|
||||
if len(self.down[resolution_level].attn) > 0:
|
||||
current_features = self.down[resolution_level].attn[block_idx](current_features)
|
||||
|
||||
# Store processed features
|
||||
feature_maps.append(current_features)
|
||||
|
||||
# Downsample spatial dimensions (except at the final resolution level)
|
||||
if resolution_level != self.num_resolutions - 1:
|
||||
downsampled_features = self.down[resolution_level].downsample(feature_maps[-1])
|
||||
feature_maps.append(downsampled_features)
|
||||
|
||||
# === MIDDLE PROCESSING PHASE ===
|
||||
# Take the lowest resolution features for middle processing
|
||||
bottleneck_features = feature_maps[-1]
|
||||
|
||||
# Apply first middle ResNet block
|
||||
bottleneck_features = self.mid.block_1(bottleneck_features, temb=None)
|
||||
|
||||
# Apply middle attention block
|
||||
bottleneck_features = self.mid.attn_1(bottleneck_features)
|
||||
|
||||
# Apply second middle ResNet block
|
||||
bottleneck_features = self.mid.block_2(bottleneck_features, temb=None)
|
||||
|
||||
# === OUTPUT PHASE ===
|
||||
# Normalize the bottleneck features
|
||||
output_features = self.norm_out(bottleneck_features)
|
||||
|
||||
# Apply non-linearity (SiLU activation)
|
||||
output_features = self.non_linearity(output_features)
|
||||
|
||||
# Final convolution to produce latent representation
|
||||
# [batch, channels, time, n_mels] -> [batch, 2 * z_channels if double_z else z_channels, time, n_mels]
|
||||
return self.conv_out(output_features)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
tanh_out=False,
|
||||
attn_type="vanilla",
|
||||
mid_block_add_attention=True,
|
||||
norm_type="group",
|
||||
causality_axis=CausalityAxis.WIDTH.value,
|
||||
**ignorekwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.out_ch = out_ch
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
self.norm_type = norm_type
|
||||
self.z_channels = z_channels
|
||||
# Convert string to enum if needed (for config loading)
|
||||
causality_axis = CausalityAxis.str_to_enum(causality_axis)
|
||||
self.attn_type = AttentionType.str_to_enum(attn_type)
|
||||
|
||||
# compute block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = make_conv2d(z_channels, block_in, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||
|
||||
self.non_linearity = nn.SiLU()
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
if mid_block_add_attention:
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)
|
||||
else:
|
||||
self.mid.attn_1 = nn.Identity()
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for _ in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in, normtype=self.norm_type)
|
||||
self.conv_out = make_conv2d(block_in, out_ch, kernel_size=3, stride=1, causality_axis=causality_axis)
|
||||
|
||||
def _adjust_output_shape(self, decoded_output, target_shape):
|
||||
"""
|
||||
Adjust output shape to match target dimensions for variable-length audio.
|
||||
|
||||
This function handles the common case where decoded audio spectrograms need to be
|
||||
resized to match a specific target shape.
|
||||
|
||||
Args:
|
||||
decoded_output: Tensor of shape (batch, channels, time, frequency)
|
||||
target_shape: Target shape tuple (batch, channels, time, frequency)
|
||||
|
||||
Returns:
|
||||
Tensor adjusted to match target_shape exactly
|
||||
"""
|
||||
# Current output shape: (batch, channels, time, frequency)
|
||||
_, _, current_time, current_freq = decoded_output.shape
|
||||
_, target_channels, target_time, target_freq = target_shape
|
||||
|
||||
# Step 1: Crop first to avoid exceeding target dimensions
|
||||
decoded_output = decoded_output[
|
||||
:, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
|
||||
]
|
||||
|
||||
# Step 2: Calculate padding needed for time and frequency dimensions
|
||||
time_padding_needed = target_time - decoded_output.shape[2]
|
||||
freq_padding_needed = target_freq - decoded_output.shape[3]
|
||||
|
||||
# Step 3: Apply padding if needed
|
||||
if time_padding_needed > 0 or freq_padding_needed > 0:
|
||||
# PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
|
||||
# For audio: pad_left/right = frequency, pad_top/bottom = time
|
||||
padding = (
|
||||
0,
|
||||
max(freq_padding_needed, 0), # frequency padding (left, right)
|
||||
0,
|
||||
max(time_padding_needed, 0), # time padding (top, bottom)
|
||||
)
|
||||
decoded_output = F.pad(decoded_output, padding)
|
||||
|
||||
# Step 4: Final safety crop to ensure exact target shape
|
||||
decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
|
||||
|
||||
return decoded_output
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
"ch": self.ch,
|
||||
"out_ch": self.out_ch,
|
||||
"ch_mult": self.ch_mult,
|
||||
"num_res_blocks": self.num_res_blocks,
|
||||
"in_channels": self.in_channels,
|
||||
"resolution": self.resolution,
|
||||
"z_channels": self.z_channels,
|
||||
}
|
||||
|
||||
def forward(self, latent_features, target_shape=None):
|
||||
"""
|
||||
Decode latent features back to audio spectrograms.
|
||||
|
||||
Args:
|
||||
latent_features: Encoded latent representation of shape (batch, channels, height, width)
|
||||
target_shape: Optional target output shape (batch, channels, time, frequency)
|
||||
If provided, output will be cropped/padded to match this shape
|
||||
|
||||
Returns:
|
||||
Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
|
||||
"""
|
||||
assert target_shape is not None, "Target shape is required for CausalAudioAutoencoder Decoder"
|
||||
|
||||
# Transform latent features to decoder's internal feature dimension
|
||||
hidden_features = self.conv_in(latent_features)
|
||||
|
||||
# Middle processing
|
||||
hidden_features = self.mid.block_1(hidden_features, temb=None)
|
||||
hidden_features = self.mid.attn_1(hidden_features)
|
||||
hidden_features = self.mid.block_2(hidden_features, temb=None)
|
||||
|
||||
# Upsampling
|
||||
# Progressively increase spatial resolution from lowest to highest
|
||||
for resolution_level in reversed(range(self.num_resolutions)):
|
||||
# Apply residual blocks at current resolution level
|
||||
for block_index in range(self.num_res_blocks + 1):
|
||||
hidden_features = self.up[resolution_level].block[block_index](hidden_features, temb=None)
|
||||
|
||||
if len(self.up[resolution_level].attn) > 0:
|
||||
hidden_features = self.up[resolution_level].attn[block_index](hidden_features)
|
||||
|
||||
if resolution_level != 0:
|
||||
hidden_features = self.up[resolution_level].upsample(hidden_features)
|
||||
|
||||
# Output
|
||||
if self.give_pre_end:
|
||||
# Return intermediate features before final processing (for debugging/analysis)
|
||||
decoded_output = hidden_features
|
||||
else:
|
||||
# Standard output path: normalize, activate, and convert to output channels
|
||||
# Final normalization layer
|
||||
hidden_features = self.norm_out(hidden_features)
|
||||
|
||||
# Apply SiLU (Swish) activation function
|
||||
hidden_features = self.non_linearity(hidden_features)
|
||||
|
||||
# Final convolution to map to output channels (typically 2 for stereo audio)
|
||||
decoded_output = self.conv_out(hidden_features)
|
||||
|
||||
# Optional tanh activation to bound output values to [-1, 1] range
|
||||
if self.tanh_out:
|
||||
decoded_output = torch.tanh(decoded_output)
|
||||
|
||||
# Adjust shape for audio data
|
||||
if target_shape is not None:
|
||||
decoded_output = self._adjust_output_shape(decoded_output, target_shape)
|
||||
|
||||
return decoded_output
|
||||
|
||||
|
||||
class processor(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.register_buffer("std-of-means", torch.empty(128))
|
||||
self.register_buffer("mean-of-means", torch.empty(128))
|
||||
|
||||
def un_normalize(self, x):
|
||||
return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x)
|
||||
|
||||
def normalize(self, x):
|
||||
return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x)
|
||||
|
||||
|
||||
class CausalAudioAutoencoder(nn.Module):
|
||||
def __init__(self, config=None):
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = self._guess_config()
|
||||
|
||||
# Extract encoder and decoder configs from the new format
|
||||
model_config = config.get("model", {}).get("params", {})
|
||||
variables_config = config.get("variables", {})
|
||||
|
||||
self.sampling_rate = variables_config.get(
|
||||
"sampling_rate",
|
||||
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
|
||||
)
|
||||
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
|
||||
decoder_config = model_config.get("decoder", encoder_config)
|
||||
|
||||
# Load mel spectrogram parameters
|
||||
self.mel_bins = encoder_config.get("mel_bins", 64)
|
||||
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
|
||||
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
|
||||
|
||||
# Store causality configuration at VAE level (not just in encoder internals)
|
||||
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
|
||||
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
|
||||
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
|
||||
|
||||
self.encoder = Encoder(**encoder_config)
|
||||
self.decoder = Decoder(**decoder_config)
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
def _guess_config(self):
|
||||
encoder_config = {
|
||||
# Required parameters - based on ltx-video-av-1679000 model metadata
|
||||
"ch": 128,
|
||||
"out_ch": 8,
|
||||
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": [], # Based on metadata: empty list, no attention
|
||||
"dropout": 0.0,
|
||||
"resamp_with_conv": True,
|
||||
"in_channels": 2, # stereo
|
||||
"resolution": 256,
|
||||
"z_channels": 8,
|
||||
"double_z": True,
|
||||
"attn_type": "vanilla",
|
||||
"mid_block_add_attention": False, # Based on metadata: false
|
||||
"norm_type": "pixel",
|
||||
"causality_axis": "height", # Based on metadata
|
||||
"mel_bins": 64, # Based on metadata: mel_bins = 64
|
||||
}
|
||||
|
||||
decoder_config = {
|
||||
# Inherits encoder config, can override specific params
|
||||
**encoder_config,
|
||||
"out_ch": 2, # Stereo audio output (2 channels)
|
||||
"give_pre_end": False,
|
||||
"tanh_out": False,
|
||||
}
|
||||
|
||||
config = {
|
||||
"_class_name": "CausalAudioAutoencoder",
|
||||
"sampling_rate": 16000,
|
||||
"model": {
|
||||
"params": {
|
||||
"encoder": encoder_config,
|
||||
"decoder": decoder_config,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
"sampling_rate": self.sampling_rate,
|
||||
"mel_bins": self.mel_bins,
|
||||
"mel_hop_length": self.mel_hop_length,
|
||||
"n_fft": self.n_fft,
|
||||
"causality_axis": self.causality_axis.value,
|
||||
"is_causal": self.is_causal,
|
||||
}
|
||||
|
||||
def encode(self, x):
|
||||
return self.encoder(x)
|
||||
|
||||
def decode(self, x, target_shape=None):
|
||||
return self.decoder(x, target_shape=target_shape)
|
||||
213
comfy/ldm/lightricks/vocoders/vocoder.py
Normal file
213
comfy/ldm/lightricks/vocoders/vocoder.py
Normal file
@ -0,0 +1,213 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
import comfy.ops
|
||||
import numpy as np
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super(ResBlock1, self).__init__()
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c1(xt)
|
||||
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
|
||||
class ResBlock2(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
||||
super(ResBlock2, self).__init__()
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c in self.convs:
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
|
||||
class Vocoder(torch.nn.Module):
|
||||
"""
|
||||
Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config=None):
|
||||
super(Vocoder, self).__init__()
|
||||
|
||||
if config is None:
|
||||
config = self.get_default_config()
|
||||
|
||||
resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
|
||||
upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2])
|
||||
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4])
|
||||
resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
|
||||
upsample_initial_channel = config.get("upsample_initial_channel", 1024)
|
||||
stereo = config.get("stereo", True)
|
||||
resblock = config.get("resblock", "1")
|
||||
|
||||
self.output_sample_rate = config.get("output_sample_rate")
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
in_channels = 128 if stereo else 64
|
||||
self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
|
||||
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
ops.ConvTranspose1d(
|
||||
upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2,
|
||||
)
|
||||
)
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock_class(ch, k, d))
|
||||
|
||||
out_channels = 2 if stereo else 1
|
||||
self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3)
|
||||
|
||||
self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])
|
||||
|
||||
def get_default_config(self):
|
||||
"""Generate default configuration for the vocoder."""
|
||||
|
||||
config = {
|
||||
"resblock_kernel_sizes": [3, 7, 11],
|
||||
"upsample_rates": [6, 5, 2, 2, 2],
|
||||
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
|
||||
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"upsample_initial_channel": 1024,
|
||||
"stereo": True,
|
||||
"resblock": "1",
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass of the vocoder.
|
||||
|
||||
Args:
|
||||
x: Input spectrogram tensor. Can be:
|
||||
- 3D: (batch_size, channels, time_steps) for mono
|
||||
- 4D: (batch_size, 2, channels, time_steps) for stereo
|
||||
|
||||
Returns:
|
||||
Audio tensor of shape (batch_size, out_channels, audio_length)
|
||||
"""
|
||||
if x.dim() == 4: # stereo
|
||||
assert x.shape[1] == 2, "Input must have 2 channels for stereo"
|
||||
x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
|
||||
x = self.conv_pre(x)
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
x = F.leaky_relu(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
return x
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user