add {uint4, uint2, int2} => {fp16, bf16} conversion (#1966)

This commit is contained in:
Lain
2024-12-03 11:03:43 -08:00
committed by GitHub
parent b0e09d7cd3
commit 80243e0b8c
2 changed files with 810 additions and 12 deletions

View File

@ -644,11 +644,26 @@ struct GetName {
static constexpr char name[] = "UNSUPPORTED";
};
template <>
struct GetName<cutlass::int2b_t> {
static constexpr char name[] = "int2b_t";
};
template <>
struct GetName<cutlass::uint2b_t> {
static constexpr char name[] = "uint2b_t";
};
template <>
struct GetName<cutlass::int4b_t> {
static constexpr char name[] = "int4b_t";
};
template <>
struct GetName<cutlass::uint4b_t> {
static constexpr char name[] = "uint4b_t";
};
template <>
struct GetName<uint8_t> {
static constexpr char name[] = "uint8_t";
@ -709,9 +724,15 @@ using VectorConvertTypes = ::testing::Types<
ResultSourcePair<cutlass::bfloat16_t, uint8_t>,
ResultSourcePair<cutlass::bfloat16_t, int8_t>,
ResultSourcePair<cutlass::half_t, cutlass::int2b_t>,
ResultSourcePair<cutlass::bfloat16_t, cutlass::int2b_t>,
ResultSourcePair<cutlass::half_t, cutlass::uint2b_t>,
ResultSourcePair<cutlass::bfloat16_t, cutlass::uint2b_t>,
ResultSourcePair<cutlass::float_e4m3_t, cutlass::int4b_t>,
ResultSourcePair<cutlass::half_t, cutlass::int4b_t>,
ResultSourcePair<cutlass::bfloat16_t, cutlass::int4b_t>,
ResultSourcePair<cutlass::half_t, cutlass::uint4b_t>,
ResultSourcePair<cutlass::bfloat16_t, cutlass::uint4b_t>,
ResultSourcePair<float, cutlass::int4b_t>
>;