fix cuda 12.6 issues (#2066)

This commit is contained in:
Haicheng Wu
2025-01-28 17:28:29 -05:00
committed by GitHub
parent 389e493055
commit 47daa33c61

View File

@ -219,7 +219,7 @@ to_CUtensorMapDataType() {
if constexpr (is_same_v<T, double>) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; } else
if constexpr (is_same_v<T, bfloat16_t>) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else
if constexpr (is_same_v<T, tfloat32_t>) { return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; } else
#if defined(CUDA_VERSION) && CUDA_VERSION > 12060
if constexpr (is_same_v<T, float_e2m3_t>) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else
if constexpr (is_same_v<T, float_e3m2_t>) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else
if constexpr (is_same_v<T, float_e2m1_t>) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B;} else
@ -230,7 +230,7 @@ to_CUtensorMapDataType() {
if constexpr (is_same_v<T, type_erased_dynamic_float6_t>) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else
if constexpr (is_same_v<T, detail::type_erased_dynamic_float4_unpacksmem_t>) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B;} else
if constexpr (is_same_v<T, type_erased_dynamic_float4_t>) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B; } else
#endif
{ static_assert(sizeof(T) < 0, "Unknown TMA Format!"); }
}
@ -257,8 +257,10 @@ to_CUtensorMapSwizzle(SmemSwizzleBits const& t, SmemSwizzleBase const& b) {
switch (b) {
default: assert(false && "Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!");
case SmemSwizzleBase::SWIZZLE_BASE_16B: return CU_TENSOR_MAP_SWIZZLE_128B;
#if defined(CUDA_VERSION) && CUDA_VERSION > 12060
case SmemSwizzleBase::SWIZZLE_BASE_32B: return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B;
case SmemSwizzleBase::SWIZZLE_BASE_64B: return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_64B;
#endif
}
#endif