add {uint4, uint2, int2} => {fp16, bf16} conversion (#1966)
This commit is contained in:
@ -3766,6 +3766,280 @@ public:
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Partial specialization for Array<cutlass::half_t, N> <= Array<cutlass::int2b_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::half_t, cutlass::int2b_t, N, Round> {
|
||||
using result_type = Array<cutlass::half_t, N>;
|
||||
using source_type = Array<cutlass::int2b_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
using result_type_packed_16 = Array<cutlass::half_t, 16>;
|
||||
using result_type_packed_8 = Array<cutlass::half_t, 8>;
|
||||
using result_type_packed_4 = Array<cutlass::half_t, 4>;
|
||||
using source_type_packed_16 = Array<cutlass::int2b_t, 16>;
|
||||
using source_type_packed_8 = Array<cutlass::int2b_t, 8>;
|
||||
using source_type_packed_4 = Array<cutlass::int2b_t, 4>;
|
||||
|
||||
using ScalarConverter = NumericConverter<cutlass::half_t, cutlass::int2b_t, Round>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static uint32_t to_reg(source_type_packed_4 const& source) {
|
||||
return static_cast<uint32_t>(
|
||||
reinterpret_cast<const uint8_t&>(source));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static uint32_t to_reg(source_type_packed_8 const& source) {
|
||||
return static_cast<uint32_t>(
|
||||
reinterpret_cast<const uint16_t&>(source));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static uint32_t to_reg(source_type_packed_16 const& source) {
|
||||
return reinterpret_cast<const uint32_t&>(source);
|
||||
}
|
||||
|
||||
template <typename PackedResultType, typename PackedSrcType>
|
||||
CUTLASS_DEVICE
|
||||
static PackedResultType packed_convert(PackedSrcType const &source) {
|
||||
|
||||
static_assert((platform::is_same<PackedSrcType, source_type_packed_4>::value &&
|
||||
platform::is_same<PackedResultType, result_type_packed_4>::value) ||
|
||||
(platform::is_same<PackedSrcType, source_type_packed_8>::value &&
|
||||
platform::is_same<PackedResultType, result_type_packed_8>::value) ||
|
||||
(platform::is_same<PackedSrcType, source_type_packed_16>::value &&
|
||||
platform::is_same<PackedResultType, result_type_packed_16>::value),
|
||||
"Invalid PackedSrcType/PackedResultType must be 4, 8 or 16 to use private convert dispatch.");
|
||||
|
||||
// Hold output FP16s in reg. We need 1 reg for every 2 elements
|
||||
using RegArray = cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
// View the input as reg
|
||||
uint32_t src_reg = to_reg(source);
|
||||
uint32_t src_reg_shifted = src_reg >> 4;
|
||||
|
||||
// Below constructs the following temporary:
|
||||
// f1f0 = {0x00, i3i2i1i0, 0x00, i3i2i1i0}
|
||||
// f3f2 = {0x00, i5i4i3i2, 0x00, i5i4i3i2}
|
||||
// f5f4 = {0x00, i7i6i5i4, 0x00, i7i6i5i4}
|
||||
// f7f6 = {0x00, i9i8i7i6, 0x00, i9i8i7i6}
|
||||
// f9f8 = {0x00, i11i10i9i8, 0x00, i11i10i9i8}
|
||||
// f11f10 = {0x00, i13i12i11i10, 0x00, i13i12i11i10}
|
||||
// f13f12 = {0x00, i15i14i13i12, 0x00, i15i14i13i12}
|
||||
// f15f14 = {0x00, 0000i15i14, 0x00, 0000i15i14}
|
||||
// We use inline asm instead of __byte_perm intrinsic since we don't want the documented (& 0x7) on the index. NVCC
|
||||
// might be able to optimize it out since the index is a constexpr, but we choose to be safe about it here.
|
||||
uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343};
|
||||
static_assert(RegArray::kElements <= 8, "Too many inputs for I2 -> FP16 vector converter");
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ii += 2) {
|
||||
asm volatile(
|
||||
"{ prmt.b32 %0, %1, %2, %3; }\n"
|
||||
: "=r"(r[ii])
|
||||
: "r"(src_reg), "n"(0), "r"(prmt_indices[ii / 2]));
|
||||
|
||||
asm volatile(
|
||||
"{ prmt.b32 %0, %1, %2, %3; }\n"
|
||||
: "=r"(r[ii + 1])
|
||||
: "r"(src_reg_shifted), "n"(0), "r"(prmt_indices[ii / 2]));
|
||||
}
|
||||
|
||||
// The below XOR does the following:
|
||||
// Sets the exponent bits of the FP16 to the correct value for the FP16 magic_num. We will be constructing
|
||||
// 1024 + x + 2, 1024 + 4 * (x + 2)
|
||||
// We use lop3 so that we can use 1 instruction for AND and XOR.
|
||||
// static constexpr uint32_t xor_mask[2] = { 0x64086402, 0x64806420};
|
||||
// static constexpr uint32_t and_mask[2] = { 0x000C0003, 0x00C00030};
|
||||
static constexpr uint32_t xor_mask = 0x64086402;
|
||||
static constexpr uint32_t and_mask = 0x000C0003;
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
|
||||
// For each operand, computes:
|
||||
// r[i] = (r[i] & and_mask[i / 2]) ^ xor_mask[i / 2]
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{ lop3.b32 %0, %0, %1, %2, %3; }\n"
|
||||
: "+r"(r[ii])
|
||||
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
|
||||
}
|
||||
|
||||
// {-258, -1026}
|
||||
static constexpr uint32_t hfma_bias_rep = 0xDC08E402;
|
||||
// {1/4, 1}
|
||||
static constexpr uint32_t hfma_scale_rep = 0x34003C00;
|
||||
|
||||
// Scale and subtract the FP16s to get the original int4 number as FP16.
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
|
||||
fp16x2_val = __hfma2(fp16x2_val,
|
||||
reinterpret_cast<const half2&>(hfma_scale_rep),
|
||||
reinterpret_cast<const half2&>(hfma_bias_rep));
|
||||
}
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
}
|
||||
|
||||
friend class detail::VectorizedConverter;
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const &source) {
|
||||
result_type result;
|
||||
using ConverterType = NumericArrayConverter<typename result_type::Element, typename source_type::Element, N, Round>;
|
||||
detail::VectorizedConverter::convert<ConverterType,
|
||||
result_type_packed_16, source_type_packed_16,
|
||||
result_type_packed_8, source_type_packed_8,
|
||||
result_type_packed_4, source_type_packed_4>(result, source);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization for Array<cutlass::half_t, N> <= Array<cutlass::uint2b_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::half_t, cutlass::uint2b_t, N, Round> {
|
||||
using result_type = Array<cutlass::half_t, N>;
|
||||
using source_type = Array<cutlass::uint2b_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
using result_type_packed_16 = Array<cutlass::half_t, 16>;
|
||||
using result_type_packed_8 = Array<cutlass::half_t, 8>;
|
||||
using result_type_packed_4 = Array<cutlass::half_t, 4>;
|
||||
using source_type_packed_16 = Array<cutlass::uint2b_t, 16>;
|
||||
using source_type_packed_8 = Array<cutlass::uint2b_t, 8>;
|
||||
using source_type_packed_4 = Array<cutlass::uint2b_t, 4>;
|
||||
|
||||
using ScalarConverter = NumericConverter<cutlass::half_t, cutlass::uint2b_t, Round>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static uint32_t to_reg(source_type_packed_4 const& source) {
|
||||
return static_cast<uint32_t>(
|
||||
reinterpret_cast<const uint8_t&>(source));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static uint32_t to_reg(source_type_packed_8 const& source) {
|
||||
return static_cast<uint32_t>(
|
||||
reinterpret_cast<const uint16_t&>(source));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static uint32_t to_reg(source_type_packed_16 const& source) {
|
||||
return reinterpret_cast<const uint32_t&>(source);
|
||||
}
|
||||
|
||||
template <typename PackedResultType, typename PackedSrcType>
|
||||
CUTLASS_DEVICE
|
||||
static PackedResultType packed_convert(PackedSrcType const &source) {
|
||||
|
||||
static_assert((platform::is_same<PackedSrcType, source_type_packed_4>::value &&
|
||||
platform::is_same<PackedResultType, result_type_packed_4>::value) ||
|
||||
(platform::is_same<PackedSrcType, source_type_packed_8>::value &&
|
||||
platform::is_same<PackedResultType, result_type_packed_8>::value) ||
|
||||
(platform::is_same<PackedSrcType, source_type_packed_16>::value &&
|
||||
platform::is_same<PackedResultType, result_type_packed_16>::value),
|
||||
"Invalid PackedSrcType/PackedResultType must be 4, 8 or 16 to use private convert dispatch.");
|
||||
|
||||
// Hold output FP16s in reg. We need 1 reg for every 2 elements
|
||||
using RegArray = cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
// View the input as reg
|
||||
uint32_t src_reg = to_reg(source);
|
||||
uint32_t src_reg_shifted = src_reg >> 4;
|
||||
|
||||
// Below constructs the following temporary:
|
||||
// f1f0 = {0x00, u3u2u1u0, 0x00, u3u2u1u0}
|
||||
// f3f2 = {0x00, u5u4u3u2, 0x00, u5u4u3u2}
|
||||
// f5f4 = {0x00, u7u6u5u4, 0x00, u7u6u5u4}
|
||||
// f7f6 = {0x00, u9u8u7u6, 0x00, u9u8u7u6}
|
||||
// f9f8 = {0x00, u11u10u9u8, 0x00, u11u10u9u8}
|
||||
// f11f10 = {0x00, u13u12u11u10, 0x00, u13u12u11u10}
|
||||
// f13f12 = {0x00, u15u14u13u12, 0x00, u15u14u13u12}
|
||||
// f15f14 = {0x00, 0000u15u14, 0x00, 0000u15u14}
|
||||
// We use inline asm instead of __byte_perm intrinsic since we don't want the documented (& 0x7) on the index. NVCC
|
||||
// might be able to optimize it out since the index is a constexpr, but we choose to be safe about it here.
|
||||
uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343};
|
||||
static_assert(RegArray::kElements <= 8, "Too many inputs for I2 -> FP16 vector converter");
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ii += 2) {
|
||||
asm volatile(
|
||||
"{ prmt.b32 %0, %1, %2, %3; }\n"
|
||||
: "=r"(r[ii])
|
||||
: "r"(src_reg), "n"(0), "r"(prmt_indices[ii / 2]));
|
||||
|
||||
asm volatile(
|
||||
"{ prmt.b32 %0, %1, %2, %3; }\n"
|
||||
: "=r"(r[ii + 1])
|
||||
: "r"(src_reg_shifted), "n"(0), "r"(prmt_indices[ii / 2]));
|
||||
}
|
||||
|
||||
// The below XOR does the following:
|
||||
// Sets the exponent bits of the FP16 to the correct value for the FP16 magic_num. We will be constructing
|
||||
// 1024 + x, 1024 + 4 * x
|
||||
// We use lop3 so that we can use 1 instruction for AND and OR.
|
||||
static constexpr uint32_t xor_mask = 0x64006400;
|
||||
static constexpr uint32_t and_mask = 0x000C0003;
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
|
||||
// For each operand, computes:
|
||||
// r[i] = (r[i] & and_mask[i / 2]) ^ xor_mask[i / 2]
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{ lop3.b32 %0, %0, %1, %2, %3; }\n"
|
||||
: "+r"(r[ii])
|
||||
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
|
||||
}
|
||||
|
||||
// {-256, -1024}
|
||||
static constexpr uint32_t hfma_bias_rep = 0xDC00E400;
|
||||
// {1/4, 1}
|
||||
static constexpr uint32_t hfma_scale_rep = 0x34003C00;
|
||||
|
||||
// Scale and subtract the FP16s to get the original int4 number as FP16.
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
|
||||
fp16x2_val = __hfma2(fp16x2_val,
|
||||
reinterpret_cast<const half2&>(hfma_scale_rep),
|
||||
reinterpret_cast<const half2&>(hfma_bias_rep));
|
||||
}
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
}
|
||||
|
||||
friend class detail::VectorizedConverter;
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const &source) {
|
||||
result_type result;
|
||||
using ConverterType = NumericArrayConverter<typename result_type::Element, typename source_type::Element, N, Round>;
|
||||
detail::VectorizedConverter::convert<ConverterType,
|
||||
result_type_packed_16, source_type_packed_16,
|
||||
result_type_packed_8, source_type_packed_8,
|
||||
result_type_packed_4, source_type_packed_4>(result, source);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization for Array<cutlass::half_t, N> <= Array<cutlass::int4b_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::half_t, cutlass::int4b_t, N, Round> {
|
||||
@ -3830,13 +4104,11 @@ private:
|
||||
// We use inline asm instead of __byte_perm intrinsic since we don't want the documented (& 0x7) on the index. NVCC
|
||||
// might be able to optimize it out since the index is a constexpr, but we choose to be safe about it here.
|
||||
uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343};
|
||||
static_assert(RegArray::kElements <= 4, "Too many inputs for F16 -> I4 vector converter");
|
||||
static_assert(RegArray::kElements <= 4, "Too many inputs for I4 ->F16 vector converter");
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" prmt.b32 %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
"{ prmt.b32 %0, %1, %2, %3; }\n"
|
||||
: "=r"(r[ii])
|
||||
: "r"(src_reg), "n"(0), "r"(prmt_indices[ii]));
|
||||
}
|
||||
@ -3891,6 +4163,133 @@ private:
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
}
|
||||
|
||||
friend class detail::VectorizedConverter;
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const &source) {
|
||||
result_type result;
|
||||
using ConverterType = NumericArrayConverter<typename result_type::Element, typename source_type::Element, N, Round>;
|
||||
detail::VectorizedConverter::convert<ConverterType,
|
||||
result_type_packed_8, source_type_packed_8,
|
||||
result_type_packed_4, source_type_packed_4,
|
||||
result_type_packed_2, source_type_packed_2>(result, source);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization for Array<cutlass::half_t, N> <= Array<cutlass::uint4b_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::half_t, cutlass::uint4b_t, N, Round> {
|
||||
using result_type = Array<cutlass::half_t, N>;
|
||||
using source_type = Array<cutlass::uint4b_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
using result_type_packed_8 = Array<cutlass::half_t, 8>;
|
||||
using result_type_packed_4 = Array<cutlass::half_t, 4>;
|
||||
using result_type_packed_2 = Array<cutlass::half_t, 2>;
|
||||
using source_type_packed_8 = Array<cutlass::uint4b_t, 8>;
|
||||
using source_type_packed_4 = Array<cutlass::uint4b_t, 4>;
|
||||
using source_type_packed_2 = Array<cutlass::uint4b_t, 2>;
|
||||
|
||||
using ScalarConverter = NumericConverter<cutlass::half_t, cutlass::uint4b_t, Round>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static uint32_t to_reg(source_type_packed_2 const& source) {
|
||||
return static_cast<uint32_t>(
|
||||
reinterpret_cast<const uint8_t&>(source));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static uint32_t to_reg(source_type_packed_4 const& source) {
|
||||
return static_cast<uint32_t>(
|
||||
reinterpret_cast<const uint16_t&>(source));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static uint32_t to_reg(source_type_packed_8 const& source) {
|
||||
return reinterpret_cast<const uint32_t&>(source);
|
||||
}
|
||||
|
||||
// The core converter uses bit tricks to construct a known FP16 number, then does a
|
||||
// subtraction in FP16 for the final result.
|
||||
template <typename PackedResultType, typename PackedSrcType>
|
||||
CUTLASS_DEVICE
|
||||
static PackedResultType packed_convert(PackedSrcType const &source) {
|
||||
|
||||
static_assert((platform::is_same<PackedSrcType, source_type_packed_2>::value &&
|
||||
platform::is_same<PackedResultType, result_type_packed_2>::value) ||
|
||||
(platform::is_same<PackedSrcType, source_type_packed_4>::value &&
|
||||
platform::is_same<PackedResultType, result_type_packed_4>::value) ||
|
||||
(platform::is_same<PackedSrcType, source_type_packed_8>::value &&
|
||||
platform::is_same<PackedResultType, result_type_packed_8>::value),
|
||||
"Invalid PackedSrcType/PackedResultType must be 2, 4 or 8 to use private convert dispatch.");
|
||||
|
||||
// Hold output FP16s in reg. We need 1 reg for every 2 elements
|
||||
using RegArray = cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
// View the input as reg
|
||||
uint32_t src_reg = to_reg(source);
|
||||
// Below constructs the following temporary:
|
||||
// fp16s_01 = {0x00, u4_01, 0x00, u4_01}
|
||||
// fp16s_23 = {0x00, u4_23, 0x00, u4_23}
|
||||
// fp16s_45 = {0x00, u4_45, 0x00, u4_45}
|
||||
// fp16s_67 = {0x00, u4_67, 0x00, u4_67}
|
||||
uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343};
|
||||
static_assert(RegArray::kElements <= 4, "Too many inputs for u4 -> f16 vector converter");
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{ prmt.b32 %0, %1, %2, %3; }\n"
|
||||
: "=r"(r[ii])
|
||||
: "r"(src_reg), "n"(0), "r"(prmt_indices[ii]));
|
||||
}
|
||||
|
||||
// The below XOR does the following:
|
||||
// Sets the exponent bits of the FP16 to the correct value for the FP16 magic_num. We will be constructing
|
||||
// 1024 + x, then using hsub2 to subtract 1024 from that
|
||||
static constexpr uint32_t or_mask = 0x64006400;
|
||||
static constexpr uint32_t and_mask = 0x00F0000F;
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
|
||||
|
||||
// For each operand, computes:
|
||||
// r[i] = (r[i] & and_mask) | or_mask
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii])
|
||||
: "n"(and_mask), "n"(or_mask), "n"(immLut));
|
||||
|
||||
// We will issue 2 hfmas that do the following:
|
||||
// For the high FP16:
|
||||
// Divide by 16 {packed as a operand} to get:
|
||||
// 64 + x
|
||||
// Subtract 64 {packed as c operand} to get x
|
||||
// For the low FP16:
|
||||
// we subtract 1024 {packed as c operand} to get x
|
||||
|
||||
static constexpr uint32_t hfma_bias = 0xD400E400; // {-64, -1024}
|
||||
static constexpr uint32_t hfma_scale = 0x2C003C00; // {1 / 16, 1}
|
||||
|
||||
{
|
||||
__half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
|
||||
fp16x2_val = __hfma2(fp16x2_val, reinterpret_cast<const __half2&>(hfma_scale), reinterpret_cast<const __half2&>(hfma_bias));
|
||||
}
|
||||
}
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
}
|
||||
|
||||
friend class detail::VectorizedConverter;
|
||||
|
||||
public:
|
||||
@ -4108,6 +4507,260 @@ public:
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Partial specialization for Array<cutlass::bfloat16_t, N> <= Array<cutlass::int2b_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::bfloat16_t, cutlass::int2b_t, N, Round> {
|
||||
using result_type = Array<cutlass::bfloat16_t, N>;
|
||||
using source_type = Array<cutlass::int2b_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
using result_type_packed_16 = Array<cutlass::bfloat16_t, 16>;
|
||||
using result_type_packed_8 = Array<cutlass::bfloat16_t, 8>;
|
||||
using result_type_packed_4 = Array<cutlass::bfloat16_t, 4>;
|
||||
using source_type_packed_16 = Array<cutlass::int2b_t, 16>;
|
||||
using source_type_packed_8 = Array<cutlass::int2b_t, 8>;
|
||||
using source_type_packed_4 = Array<cutlass::int2b_t, 4>;
|
||||
|
||||
using ScalarConverter = NumericConverter<cutlass::bfloat16_t, cutlass::int2b_t, Round>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static uint32_t to_reg(source_type_packed_4 const& source) {
|
||||
return static_cast<uint32_t>(
|
||||
reinterpret_cast<const uint8_t&>(source));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static uint32_t to_reg(source_type_packed_8 const& source) {
|
||||
return static_cast<uint32_t>(
|
||||
reinterpret_cast<const uint16_t&>(source));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static uint32_t to_reg(source_type_packed_16 const& source) {
|
||||
return reinterpret_cast<const uint32_t&>(source);
|
||||
}
|
||||
|
||||
template <typename PackedResultType, typename PackedSrcType>
|
||||
CUTLASS_DEVICE
|
||||
static PackedResultType packed_convert(PackedSrcType const &source) {
|
||||
|
||||
static_assert((platform::is_same<PackedSrcType, source_type_packed_4>::value &&
|
||||
platform::is_same<PackedResultType, result_type_packed_4>::value) ||
|
||||
(platform::is_same<PackedSrcType, source_type_packed_8>::value &&
|
||||
platform::is_same<PackedResultType, result_type_packed_8>::value) ||
|
||||
(platform::is_same<PackedSrcType, source_type_packed_16>::value &&
|
||||
platform::is_same<PackedResultType, result_type_packed_16>::value),
|
||||
"Invalid PackedSrcType/PackedResultType must be 4, 8 or 16 to use private convert dispatch.");
|
||||
|
||||
using RegArray = cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
// View the input as reg
|
||||
uint32_t src_reg = to_reg(source);
|
||||
uint32_t src_reg_shifted_two = src_reg >> 2;
|
||||
uint32_t src_reg_shifted_four = src_reg >> 4;
|
||||
uint32_t src_reg_shifted_six = src_reg >> 6;
|
||||
|
||||
// Modified prmt indices for signed 2-bit values
|
||||
uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3};
|
||||
|
||||
static_assert(RegArray::kElements <= 8, "Too many inputs for I2 -> BF16 vector converter");
|
||||
|
||||
// First pass: extract and sign extend the 2-bit values
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ii += 2) {
|
||||
asm volatile(
|
||||
"{ prmt.b32 %0, %1, %2, %3; }\n"
|
||||
: "=r"(r[ii])
|
||||
: "r"(src_reg), "r"(src_reg_shifted_two), "r"(prmt_indices[ii / 2]));
|
||||
|
||||
asm volatile(
|
||||
"{ prmt.b32 %0, %1, %2, %3; }\n"
|
||||
: "=r"(r[ii + 1])
|
||||
: "r"(src_reg_shifted_four), "r"(src_reg_shifted_six), "r"(prmt_indices[ii / 2]));
|
||||
}
|
||||
|
||||
// For signed 2-bit integers:
|
||||
// 00 -> 0 (0)
|
||||
// 01 -> 1 (1)
|
||||
// 10 -> -2 (2 with sign extension)
|
||||
// 11 -> -1 (3 with sign extension)
|
||||
//static constexpr uint32_t sign_mask = 0x00020002; // Mask to check sign bit
|
||||
static constexpr uint32_t and_mask = 0x00030003; // Mask for 2 bits
|
||||
|
||||
// Modified for signed range (-2 to 1)
|
||||
// We'll construct numbers in the form 128 + (x + 2) and then subtract 130
|
||||
// to get back to our original range
|
||||
static constexpr uint32_t xor_mask = 0x43024302;
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii])
|
||||
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
|
||||
}
|
||||
|
||||
// Bias represents 130 in bfloat16 format
|
||||
// Subtracting 130 brings us back to our signed range (-2 to 1)
|
||||
static constexpr uint32_t bias_rep = 0x43024302; // {130, 130} in bfloat16
|
||||
const __nv_bfloat162& bias = reinterpret_cast<const __nv_bfloat162&>(bias_rep);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
__nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
|
||||
bf16x2_val = __hsub2(bf16x2_val, bias);
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
}
|
||||
|
||||
friend class detail::VectorizedConverter;
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const &source) {
|
||||
result_type result;
|
||||
using ConverterType = NumericArrayConverter<typename result_type::Element, typename source_type::Element, N, Round>;
|
||||
detail::VectorizedConverter::convert<ConverterType,
|
||||
result_type_packed_16, source_type_packed_16,
|
||||
result_type_packed_8, source_type_packed_8,
|
||||
result_type_packed_4, source_type_packed_4>(result, source);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization for Array<cutlass::bfloat16_t, N> <= Array<cutlass::uint2b_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::bfloat16_t, cutlass::uint2b_t, N, Round> {
|
||||
using result_type = Array<cutlass::bfloat16_t, N>;
|
||||
using source_type = Array<cutlass::uint2b_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
using result_type_packed_16 = Array<cutlass::bfloat16_t, 16>;
|
||||
using result_type_packed_8 = Array<cutlass::bfloat16_t, 8>;
|
||||
using result_type_packed_4 = Array<cutlass::bfloat16_t, 4>;
|
||||
using source_type_packed_16 = Array<cutlass::uint2b_t, 16>;
|
||||
using source_type_packed_8 = Array<cutlass::uint2b_t, 8>;
|
||||
using source_type_packed_4 = Array<cutlass::uint2b_t, 4>;
|
||||
|
||||
using ScalarConverter = NumericConverter<cutlass::bfloat16_t, cutlass::uint2b_t, Round>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static uint32_t to_reg(source_type_packed_4 const& source) {
|
||||
return static_cast<uint32_t>(
|
||||
reinterpret_cast<const uint8_t&>(source));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static uint32_t to_reg(source_type_packed_8 const& source) {
|
||||
return static_cast<uint32_t>(
|
||||
reinterpret_cast<const uint16_t&>(source));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static uint32_t to_reg(source_type_packed_16 const& source) {
|
||||
return reinterpret_cast<const uint32_t&>(source);
|
||||
}
|
||||
|
||||
template <typename PackedResultType, typename PackedSrcType>
|
||||
CUTLASS_DEVICE
|
||||
static PackedResultType packed_convert(PackedSrcType const &source) {
|
||||
|
||||
static_assert((platform::is_same<PackedSrcType, source_type_packed_4>::value &&
|
||||
platform::is_same<PackedResultType, result_type_packed_4>::value) ||
|
||||
(platform::is_same<PackedSrcType, source_type_packed_8>::value &&
|
||||
platform::is_same<PackedResultType, result_type_packed_8>::value) ||
|
||||
(platform::is_same<PackedSrcType, source_type_packed_16>::value &&
|
||||
platform::is_same<PackedResultType, result_type_packed_16>::value),
|
||||
"Invalid PackedSrcType/PackedResultType must be 4, 8 or 16 to use private convert dispatch.");
|
||||
|
||||
using RegArray = cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
// View the input as reg
|
||||
uint32_t src_reg = to_reg(source);
|
||||
uint32_t src_reg_shifted_two = src_reg >> 2;
|
||||
uint32_t src_reg_shifted_four = src_reg >> 4;
|
||||
uint32_t src_reg_shifted_six = src_reg >> 6;
|
||||
|
||||
// Modified prmt indices for signed 2-bit values
|
||||
uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3};
|
||||
|
||||
static_assert(RegArray::kElements <= 8, "Too many inputs for U2 -> BF16 vector converter");
|
||||
|
||||
// First pass: extract and sign extend the 2-bit values
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ii += 2) {
|
||||
asm volatile(
|
||||
"{ prmt.b32 %0, %1, %2, %3; }\n"
|
||||
: "=r"(r[ii])
|
||||
: "r"(src_reg), "r"(src_reg_shifted_two), "r"(prmt_indices[ii / 2]));
|
||||
|
||||
asm volatile(
|
||||
"{ prmt.b32 %0, %1, %2, %3; }\n"
|
||||
: "=r"(r[ii + 1])
|
||||
: "r"(src_reg_shifted_four), "r"(src_reg_shifted_six), "r"(prmt_indices[ii / 2]));
|
||||
}
|
||||
|
||||
static constexpr uint32_t and_mask = 0x00030003; // Mask for 2 bits
|
||||
static constexpr uint32_t xor_mask = 0x43004300;
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{ lop3.b32 %0, %0, %1, %2, %3; }"
|
||||
: "+r"(r[ii])
|
||||
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
|
||||
}
|
||||
|
||||
static constexpr uint32_t bias_rep = xor_mask; // {128, 128} in bfloat16
|
||||
const __nv_bfloat162& bias = reinterpret_cast<const __nv_bfloat162&>(bias_rep);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
__nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
|
||||
bf16x2_val = __hsub2(bf16x2_val, bias);
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
}
|
||||
|
||||
friend class detail::VectorizedConverter;
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const &source) {
|
||||
result_type result;
|
||||
using ConverterType = NumericArrayConverter<typename result_type::Element, typename source_type::Element, N, Round>;
|
||||
detail::VectorizedConverter::convert<ConverterType,
|
||||
result_type_packed_16, source_type_packed_16,
|
||||
result_type_packed_8, source_type_packed_8,
|
||||
result_type_packed_4, source_type_packed_4>(result, source);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization for Array<cutlass::bfloat16_t, N> <= Array<cutlass::int4b_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::bfloat16_t, cutlass::int4b_t, N, Round> {
|
||||
@ -4171,9 +4824,7 @@ private:
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" prmt.b32 %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
"{ prmt.b32 %0, %1, %2, %3; }\n"
|
||||
: "=r"(r[ii])
|
||||
: "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii]));
|
||||
}
|
||||
@ -4185,6 +4836,133 @@ private:
|
||||
static constexpr uint32_t and_mask = 0x000F000F;
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
|
||||
// For each operand, computes:
|
||||
// r[i] = (r[i] & and_mask) ^ xor_mask
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{ lop3.b32 %0, %0, %1, %2, %3; }\n"
|
||||
: "+r"(r[ii])
|
||||
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
|
||||
}
|
||||
|
||||
// We will issue 2 bfmas that do the following:
|
||||
// high BF16:
|
||||
// hi_bf16 - 136, lo_bf16 - 136
|
||||
|
||||
// This is the BF16 {136, 136} represented as an integer.
|
||||
static constexpr uint32_t bias_rep = 0x43084308;
|
||||
const __nv_bfloat162& bias = reinterpret_cast<const __nv_bfloat162&>(bias_rep);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
__nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
|
||||
bf16x2_val = __hsub2(bf16x2_val, bias);
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
}
|
||||
|
||||
friend class detail::VectorizedConverter;
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const &source) {
|
||||
result_type result;
|
||||
using ConverterType = NumericArrayConverter<typename result_type::Element, typename source_type::Element, N, Round>;
|
||||
detail::VectorizedConverter::convert<ConverterType,
|
||||
result_type_packed_8, source_type_packed_8,
|
||||
result_type_packed_4, source_type_packed_4,
|
||||
result_type_packed_2, source_type_packed_2>(result, source);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Partial specialization for Array<cutlass::bfloat16_t, N> <= Array<cutlass::uint4b_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::bfloat16_t, cutlass::uint4b_t, N, Round> {
|
||||
using result_type = Array<cutlass::bfloat16_t, N>;
|
||||
using source_type = Array<cutlass::uint4b_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
using result_type_packed_8 = Array<cutlass::bfloat16_t, 8>;
|
||||
using result_type_packed_4 = Array<cutlass::bfloat16_t, 4>;
|
||||
using result_type_packed_2 = Array<cutlass::bfloat16_t, 2>;
|
||||
using source_type_packed_8 = Array<cutlass::uint4b_t, 8>;
|
||||
using source_type_packed_4 = Array<cutlass::uint4b_t, 4>;
|
||||
using source_type_packed_2 = Array<cutlass::uint4b_t, 2>;
|
||||
|
||||
using ScalarConverter = NumericConverter<cutlass::bfloat16_t, cutlass::uint4b_t, Round>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static uint32_t to_reg(source_type_packed_2 const& source) {
|
||||
return static_cast<uint32_t>(
|
||||
reinterpret_cast<const uint8_t&>(source));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static uint32_t to_reg(source_type_packed_4 const& source) {
|
||||
return static_cast<uint32_t>(
|
||||
reinterpret_cast<const uint16_t&>(source));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static uint32_t to_reg(source_type_packed_8 const& source) {
|
||||
return reinterpret_cast<const uint32_t&>(source);
|
||||
}
|
||||
|
||||
// The core converter uses bit tricks to construct a known FP16 number, then does a
|
||||
// subtraction in FP16 for the final result.
|
||||
template <typename PackedResultType, typename PackedSrcType>
|
||||
CUTLASS_DEVICE
|
||||
static PackedResultType packed_convert(PackedSrcType const &source) {
|
||||
|
||||
static_assert((platform::is_same<PackedSrcType, source_type_packed_2>::value &&
|
||||
platform::is_same<PackedResultType, result_type_packed_2>::value) ||
|
||||
(platform::is_same<PackedSrcType, source_type_packed_4>::value &&
|
||||
platform::is_same<PackedResultType, result_type_packed_4>::value) ||
|
||||
(platform::is_same<PackedSrcType, source_type_packed_8>::value &&
|
||||
platform::is_same<PackedResultType, result_type_packed_8>::value),
|
||||
"Invalid PackedSrcType/PackedResultType must be 2, 4 or 8 to use private convert dispatch.");
|
||||
|
||||
// Hold output FP16s in reg. We need 1 reg for every 2 elements
|
||||
using RegArray = cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
// View the input as reg
|
||||
uint32_t src_reg = to_reg(source);
|
||||
uint32_t src_reg_shifted = src_reg >> 4;
|
||||
|
||||
// Below constructs the following temporary:
|
||||
// fp16s_01 = {0x00, u4_21, 0x00, u4_10}
|
||||
// fp16s_23 = {0x00, u4_43, 0x00, u4_32}
|
||||
// fp16s_45 = {0x00, u4_65, 0x00, u4_54}
|
||||
// fp16s_67 = {0x000, u4_7, 0x00, u4_76}
|
||||
static constexpr uint32_t prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3};
|
||||
static_assert(RegArray::kElements <= 4, "Too many inputs for BF16 -> I4 vector converter");
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" prmt.b32 %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "=r"(r[ii])
|
||||
: "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii]));
|
||||
}
|
||||
|
||||
static constexpr uint32_t xor_mask = 0x43004300;
|
||||
static constexpr uint32_t and_mask = 0x000F000F;
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
|
||||
// For each operand, computes:
|
||||
// r[i] = (r[i] & and_mask) ^ xor_mask
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
@ -4199,16 +4977,15 @@ private:
|
||||
|
||||
// We will issue 2 bfmas that do the following:
|
||||
// high BF16:
|
||||
// hi_bf16 - 136, lo_bf16 - 136
|
||||
// hi_bf16 - 128, lo_bf16 - 128
|
||||
|
||||
// This is the BF16 {136, 136} represented as an integer.
|
||||
static constexpr uint32_t bias_rep = 0x43084308;
|
||||
const __nv_bfloat162& bias = reinterpret_cast<const __nv_bfloat162&>(bias_rep);
|
||||
// This is the BF16 {128, 128} represented as an integer.
|
||||
static constexpr uint32_t bias = xor_mask;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
__nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
|
||||
bf16x2_val = __hsub2(bf16x2_val, bias);
|
||||
bf16x2_val = __hsub2(bf16x2_val, reinterpret_cast<const __nv_bfloat162&>(bias));
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
|
||||
Reference in New Issue
Block a user