add {uint4, uint2, int2} => {fp16, bf16} conversion (#1966)
This commit is contained in:
@ -644,11 +644,26 @@ struct GetName {
|
||||
static constexpr char name[] = "UNSUPPORTED";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GetName<cutlass::int2b_t> {
|
||||
static constexpr char name[] = "int2b_t";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GetName<cutlass::uint2b_t> {
|
||||
static constexpr char name[] = "uint2b_t";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GetName<cutlass::int4b_t> {
|
||||
static constexpr char name[] = "int4b_t";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GetName<cutlass::uint4b_t> {
|
||||
static constexpr char name[] = "uint4b_t";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GetName<uint8_t> {
|
||||
static constexpr char name[] = "uint8_t";
|
||||
@ -709,9 +724,15 @@ using VectorConvertTypes = ::testing::Types<
|
||||
ResultSourcePair<cutlass::bfloat16_t, uint8_t>,
|
||||
ResultSourcePair<cutlass::bfloat16_t, int8_t>,
|
||||
|
||||
ResultSourcePair<cutlass::half_t, cutlass::int2b_t>,
|
||||
ResultSourcePair<cutlass::bfloat16_t, cutlass::int2b_t>,
|
||||
ResultSourcePair<cutlass::half_t, cutlass::uint2b_t>,
|
||||
ResultSourcePair<cutlass::bfloat16_t, cutlass::uint2b_t>,
|
||||
ResultSourcePair<cutlass::float_e4m3_t, cutlass::int4b_t>,
|
||||
ResultSourcePair<cutlass::half_t, cutlass::int4b_t>,
|
||||
ResultSourcePair<cutlass::bfloat16_t, cutlass::int4b_t>,
|
||||
ResultSourcePair<cutlass::half_t, cutlass::uint4b_t>,
|
||||
ResultSourcePair<cutlass::bfloat16_t, cutlass::uint4b_t>,
|
||||
ResultSourcePair<float, cutlass::int4b_t>
|
||||
>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user