[Feature] add quick all reduce (#19744)
Signed-off-by: ilmarkov <imarkov@redhat.com> Signed-off-by: Haoyang Li <Haoyang.Li@amd.com> Co-authored-by: ilmarkov <imarkov@redhat.com>
This commit is contained in:
@ -725,6 +725,24 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
|
||||
custom_ar.impl("open_mem_handle", torch::kCPU, &open_mem_handle);
|
||||
|
||||
custom_ar.def("free_shared_buffer", &free_shared_buffer);
|
||||
#ifdef USE_ROCM
|
||||
// Quick Reduce all-reduce kernels
|
||||
custom_ar.def(
|
||||
"qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
|
||||
"cast_bf2half) -> ()");
|
||||
custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce);
|
||||
|
||||
custom_ar.def("init_custom_qr", &init_custom_qr);
|
||||
custom_ar.def("qr_destroy", &qr_destroy);
|
||||
|
||||
custom_ar.def("qr_get_handle", &qr_get_handle);
|
||||
|
||||
custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()");
|
||||
custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles);
|
||||
|
||||
// Max input size in bytes
|
||||
custom_ar.def("qr_max_size", &qr_max_size);
|
||||
#endif
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||
|
||||
Reference in New Issue
Block a user