[Bugfix] Fix cutlass dispatch for fp8/int8 to properly invoke M<=16 c… (#16751)

Signed-off-by: Ther-LF <2639852836@qq.com>
This commit is contained in:
TherLF
2025-04-28 10:38:42 +08:00
committed by GitHub
parent d1aeea7553
commit c12df53b60
2 changed files with 2 additions and 2 deletions

View File

@ -336,7 +336,7 @@ inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
if (mp2 <= 16) {
// M in [1, 16]

View File

@ -321,7 +321,7 @@ inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
if (mp2 <= 16) {
// M in [1, 16]