[NVIDIA] Support SiluMul + NVFP4 quant fusion (#23671)
Signed-off-by: jindih <jindih@nvidia.com> Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> Co-authored-by: jindih <jindih@nvidia.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Luka Govedic <lgovedic@redhat.com>
This commit is contained in:
@ -19,6 +19,13 @@
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
#define VLLM_DISPATCH_CASE_HALF_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_HALF_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_HALF_TYPES(__VA_ARGS__))
|
||||
|
||||
// ROCm devices might use either fn or fnuz, so set up dispatch table for both.
|
||||
// A host-based check at runtime will create a preferred FP8 type for ROCm
|
||||
// such that the correct kernel is dispatched.
|
||||
@ -45,6 +52,15 @@
|
||||
#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
|
||||
|
||||
#define AT_DISPATCH_BYTE_CASE(enum_type, ...) \
|
||||
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, byte_t, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_CASE_BYTE_TYPES(...) \
|
||||
AT_DISPATCH_BYTE_CASE(at::ScalarType::Byte, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_BYTE_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_BYTE_TYPES(__VA_ARGS__))
|
||||
|
||||
#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user