Commit Graph

314 Commits

SHA1 Message Date
c9d5b6d4a8 Replace FlashAttention with xformers (#70) 2023-05-05 02:01:08 -07:00
436e523bf1 Refactor attention kernels (#53) 2023-05-03 13:40:13 -07:00
a96d63c21d Add support for GPT-NeoX (Pythia) (#50) 2023-04-28 00:32:10 -07:00
e3cec88aa5 Memcpy kernel for flash attention (#29) 2023-04-10 18:22:49 -07:00
    * optimize
    * add benchmark
    * add assert
    * add test
b9926f7f66 Support block size 32 (#35) 2023-04-09 23:07:18 -07:00
c267b1a02c Add query stride to multi_query_cached_kv_attention & Add kernel benchmark script (#27) 2023-04-08 13:36:09 -07:00
    * Add query stride to multi_query_cached_kv_attention
    * Add kernel benchmark script
0f40557af6 Implement block copy kernel to optimize beam search (#32) 2023-04-07 17:45:07 -07:00
21b3671bbc Basic attention kernel that supports cached KV + (multi-)prompts (#24) 2023-04-04 20:34:46 -07:00
897cb2ae28 Optimize data movement (#20) 2023-04-02 00:30:17 -07:00
09e9245478 Add custom kernel for RMS normalization (#16) 2023-04-01 00:51:22 +08:00
88c0268a18 Implement custom kernel for LLaMA rotary embedding (#14) 2023-03-30 11:04:21 -07:00
a1b3de86cd Refactor the test code for attention kernels (#13) 2023-03-29 18:59:27 -07:00
3e9f991d6a Use FlashAttention for multi_query_kv_attention (#4) 2023-03-01 21:13:08 -08:00
0deacbce6e Implement single_query_cached_kv_attention kernel (#3) 2023-03-01 15:02:19 -08:00