fix cuda 12.6 issues (#2066)
This commit is contained in:
@ -219,7 +219,7 @@ to_CUtensorMapDataType() {
|
||||
if constexpr (is_same_v<T, double>) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; } else
|
||||
if constexpr (is_same_v<T, bfloat16_t>) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else
|
||||
if constexpr (is_same_v<T, tfloat32_t>) { return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; } else
|
||||
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION > 12060
|
||||
if constexpr (is_same_v<T, float_e2m3_t>) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else
|
||||
if constexpr (is_same_v<T, float_e3m2_t>) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else
|
||||
if constexpr (is_same_v<T, float_e2m1_t>) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B;} else
|
||||
@ -230,7 +230,7 @@ to_CUtensorMapDataType() {
|
||||
if constexpr (is_same_v<T, type_erased_dynamic_float6_t>) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else
|
||||
if constexpr (is_same_v<T, detail::type_erased_dynamic_float4_unpacksmem_t>) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B;} else
|
||||
if constexpr (is_same_v<T, type_erased_dynamic_float4_t>) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B; } else
|
||||
|
||||
#endif
|
||||
{ static_assert(sizeof(T) < 0, "Unknown TMA Format!"); }
|
||||
}
|
||||
|
||||
@ -257,8 +257,10 @@ to_CUtensorMapSwizzle(SmemSwizzleBits const& t, SmemSwizzleBase const& b) {
|
||||
switch (b) {
|
||||
default: assert(false && "Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!");
|
||||
case SmemSwizzleBase::SWIZZLE_BASE_16B: return CU_TENSOR_MAP_SWIZZLE_128B;
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION > 12060
|
||||
case SmemSwizzleBase::SWIZZLE_BASE_32B: return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B;
|
||||
case SmemSwizzleBase::SWIZZLE_BASE_64B: return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_64B;
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
Reference in New Issue
Block a user