add {uint4, uint2, int2} => {fp16, bf16} conversion (#1966)

This commit is contained in:
Lain
2024-12-03 11:03:43 -08:00
committed by GitHub
parent b0e09d7cd3
commit 80243e0b8c
2 changed files with 810 additions and 12 deletions

View File

@ -3766,6 +3766,280 @@ public:
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Array<cutlass::half_t, N> <= Array<cutlass::int2b_t, N>
///
/// Signed 2-bit integers are expanded to FP16 with integer bit tricks followed by one packed
/// FMA per output register. convert() hands the 16-, 8- and 4-element packed types declared
/// below to detail::VectorizedConverter.
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::half_t, cutlass::int2b_t, N, Round> {
  using result_type = Array<cutlass::half_t, N>;
  using source_type = Array<cutlass::int2b_t, N>;
  static FloatRoundStyle const round_style = Round;

private:
  using source_type_packed_16 = Array<cutlass::int2b_t, 16>;
  using result_type_packed_16 = Array<cutlass::half_t, 16>;
  using source_type_packed_8 = Array<cutlass::int2b_t, 8>;
  using result_type_packed_8 = Array<cutlass::half_t, 8>;
  using source_type_packed_4 = Array<cutlass::int2b_t, 4>;
  using result_type_packed_4 = Array<cutlass::half_t, 4>;

  // Not referenced inside this struct; presumably looked up by the friend
  // detail::VectorizedConverter -- confirm before removing.
  using ScalarConverter = NumericConverter<cutlass::half_t, cutlass::int2b_t, Round>;

  /// Views a 4-element source as a uint8_t and zero-extends it to a 32-bit value.
  CUTLASS_DEVICE
  static uint32_t to_reg(source_type_packed_4 const& source) {
    uint8_t const &bits = reinterpret_cast<uint8_t const &>(source);
    return static_cast<uint32_t>(bits);
  }

  /// Views an 8-element source as a uint16_t and zero-extends it to a 32-bit value.
  CUTLASS_DEVICE
  static uint32_t to_reg(source_type_packed_8 const& source) {
    uint16_t const &bits = reinterpret_cast<uint16_t const &>(source);
    return static_cast<uint32_t>(bits);
  }

  /// Views a 16-element source directly as a 32-bit value.
  CUTLASS_DEVICE
  static uint32_t to_reg(source_type_packed_16 const& source) {
    return reinterpret_cast<uint32_t const &>(source);
  }

  /// Converts 4, 8 or 16 packed int2 values to FP16.
  ///
  /// Each 32-bit output register carries two FP16 lanes. With x the signed value of a field:
  ///   1. prmt copies one source byte into the low byte of both 16-bit lanes
  ///      (register = {0x00, byte, 0x00, byte}). Even registers read byte k of the source,
  ///      odd registers read byte k of (source >> 4), so a register pair covers four
  ///      consecutive 2-bit fields.
  ///   2. lop3 computes (reg & 0x000C0003) ^ 0x64086402. The low lane keeps bits 0-1 and the
  ///      high lane keeps bits 2-3; the XOR flips the top bit of each field (x -> x + 2) and
  ///      installs the FP16 pattern 0x6400 (1024.0), giving
  ///      {1024 + 4 * (x_hi + 2), 1024 + (x_lo + 2)}.
  ///   3. hfma2 with scale {1/4, 1} and bias {-258, -1026} leaves {x_hi, x_lo}.
  template <typename PackedResultType, typename PackedSrcType>
  CUTLASS_DEVICE
  static PackedResultType packed_convert(PackedSrcType const &source) {
    static_assert((platform::is_same<PackedSrcType, source_type_packed_4>::value &&
                   platform::is_same<PackedResultType, result_type_packed_4>::value) ||
                  (platform::is_same<PackedSrcType, source_type_packed_8>::value &&
                   platform::is_same<PackedResultType, result_type_packed_8>::value) ||
                  (platform::is_same<PackedSrcType, source_type_packed_16>::value &&
                   platform::is_same<PackedResultType, result_type_packed_16>::value),
                  "Invalid PackedSrcType/PackedResultType must be 4, 8 or 16 to use private convert dispatch.");

    // One 32-bit register per pair of FP16 outputs.
    using RegArray = cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, sizeof(PackedResultType)>;
    RegArray lanes;

    uint32_t const packed_bits = to_reg(source);
    uint32_t const packed_bits_shr4 = packed_bits >> 4;

    // Selector k replicates byte k of the first prmt operand into bytes 0 and 2 of the result;
    // bytes 1 and 3 come from the second operand, which is the immediate 0.
    // Inline asm is used instead of the __byte_perm intrinsic to avoid the documented (& 0x7)
    // on the index.
    uint32_t const byte_selectors[4] = {0x4040, 0x4141, 0x4242, 0x4343};
    static_assert(RegArray::kElements <= 8, "Too many inputs for I2 -> FP16 vector converter");

    CUTLASS_PRAGMA_UNROLL
    for (int pair = 0; pair < RegArray::kElements / 2; ++pair) {
      asm volatile(
          "{ prmt.b32 %0, %1, %2, %3; }\n"
          : "=r"(lanes[2 * pair])
          : "r"(packed_bits), "n"(0), "r"(byte_selectors[pair]));
      asm volatile(
          "{ prmt.b32 %0, %1, %2, %3; }\n"
          : "=r"(lanes[2 * pair + 1])
          : "r"(packed_bits_shr4), "n"(0), "r"(byte_selectors[pair]));
    }

    // A single lop3 performs the AND and the XOR: lane = (lane & and_mask) ^ xor_mask.
    static constexpr uint32_t and_mask = 0x000C0003;
    static constexpr uint32_t xor_mask = 0x64086402;
    static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;

    CUTLASS_PRAGMA_UNROLL
    for (int idx = 0; idx < RegArray::kElements; ++idx) {
      asm volatile(
          "{ lop3.b32 %0, %0, %1, %2, %3; }\n"
          : "+r"(lanes[idx])
          : "n"(and_mask), "n"(xor_mask), "n"(immLut));
    }

    // FP16 pairs written as {high lane, low lane}.
    static constexpr uint32_t hfma_scale_rep = 0x34003C00;  // {1/4, 1}
    static constexpr uint32_t hfma_bias_rep = 0xDC08E402;   // {-258, -1026}
    __half2 const &scale = reinterpret_cast<__half2 const &>(hfma_scale_rep);
    __half2 const &bias = reinterpret_cast<__half2 const &>(hfma_bias_rep);

    // Rescale the high lane and remove the magic offsets to recover the int2 values as FP16.
    CUTLASS_PRAGMA_UNROLL
    for (int idx = 0; idx < RegArray::kElements; ++idx) {
      __half2 &lane_pair = reinterpret_cast<__half2 &>(lanes[idx]);
      lane_pair = __hfma2(lane_pair, scale, bias);
    }

    return reinterpret_cast<PackedResultType&>(lanes);
  }

  friend class detail::VectorizedConverter;

public:
  /// Converts `source` by delegating to detail::VectorizedConverter with this
  /// specialization's 16-, 8- and 4-element packed types.
  CUTLASS_DEVICE
  static result_type convert(source_type const &source) {
    using SelfConverter = NumericArrayConverter<typename result_type::Element, typename source_type::Element, N, Round>;
    result_type converted;
    detail::VectorizedConverter::convert<SelfConverter,
                                         result_type_packed_16, source_type_packed_16,
                                         result_type_packed_8, source_type_packed_8,
                                         result_type_packed_4, source_type_packed_4>(converted, source);
    return converted;
  }

  /// Function-call form; forwards to convert().
  CUTLASS_DEVICE
  result_type operator()(source_type const &s) const {
    return convert(s);
  }
};
/// Partial specialization for Array<cutlass::half_t, N> <= Array<cutlass::uint2b_t, N>
///
/// Unsigned 2-bit integers are expanded to FP16 with integer bit tricks followed by one packed
/// FMA per output register. convert() hands the 16-, 8- and 4-element packed types declared
/// below to detail::VectorizedConverter.
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::half_t, cutlass::uint2b_t, N, Round> {
  using result_type = Array<cutlass::half_t, N>;
  using source_type = Array<cutlass::uint2b_t, N>;
  static FloatRoundStyle const round_style = Round;

private:
  using result_type_packed_16 = Array<cutlass::half_t, 16>;
  using result_type_packed_8 = Array<cutlass::half_t, 8>;
  using result_type_packed_4 = Array<cutlass::half_t, 4>;
  using source_type_packed_16 = Array<cutlass::uint2b_t, 16>;
  using source_type_packed_8 = Array<cutlass::uint2b_t, 8>;
  using source_type_packed_4 = Array<cutlass::uint2b_t, 4>;

  // Not referenced inside this struct; presumably looked up by the friend
  // detail::VectorizedConverter -- confirm before removing.
  using ScalarConverter = NumericConverter<cutlass::half_t, cutlass::uint2b_t, Round>;

  /// Views a 4-element source as a uint8_t and zero-extends it to a 32-bit value.
  CUTLASS_DEVICE
  static uint32_t to_reg(source_type_packed_4 const& source) {
    return static_cast<uint32_t>(
        reinterpret_cast<const uint8_t&>(source));
  }

  /// Views an 8-element source as a uint16_t and zero-extends it to a 32-bit value.
  CUTLASS_DEVICE
  static uint32_t to_reg(source_type_packed_8 const& source) {
    return static_cast<uint32_t>(
        reinterpret_cast<const uint16_t&>(source));
  }

  /// Views a 16-element source directly as a 32-bit value.
  CUTLASS_DEVICE
  static uint32_t to_reg(source_type_packed_16 const& source) {
    return reinterpret_cast<const uint32_t&>(source);
  }

  /// Converts 4, 8 or 16 packed uint2 values to FP16.
  ///
  /// Each 32-bit output register carries two FP16 lanes. With u the value of a 2-bit field:
  ///   1. prmt copies one source byte into the low byte of both 16-bit lanes.
  ///   2. lop3 computes (reg & and_mask) ^ xor_mask, giving {1024 + 4 * u_hi, 1024 + u_lo}.
  ///   3. hfma2 with scale {1/4, 1} and bias {-256, -1024} leaves {u_hi, u_lo}.
  template <typename PackedResultType, typename PackedSrcType>
  CUTLASS_DEVICE
  static PackedResultType packed_convert(PackedSrcType const &source) {
    static_assert((platform::is_same<PackedSrcType, source_type_packed_4>::value &&
                   platform::is_same<PackedResultType, result_type_packed_4>::value) ||
                  (platform::is_same<PackedSrcType, source_type_packed_8>::value &&
                   platform::is_same<PackedResultType, result_type_packed_8>::value) ||
                  (platform::is_same<PackedSrcType, source_type_packed_16>::value &&
                   platform::is_same<PackedResultType, result_type_packed_16>::value),
                  "Invalid PackedSrcType/PackedResultType must be 4, 8 or 16 to use private convert dispatch.");

    // Hold output FP16s in reg. We need 1 reg for every 2 elements
    using RegArray = cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, sizeof(PackedResultType)>;
    RegArray r;

    // View the input as reg
    uint32_t src_reg = to_reg(source);
    uint32_t src_reg_shifted = src_reg >> 4;

    // Below constructs the following temporary:
    // f1f0   = {0x00, u3u2u1u0,     0x00, u3u2u1u0}
    // f3f2   = {0x00, u5u4u3u2,     0x00, u5u4u3u2}
    // f5f4   = {0x00, u7u6u5u4,     0x00, u7u6u5u4}
    // f7f6   = {0x00, u9u8u7u6,     0x00, u9u8u7u6}
    // f9f8   = {0x00, u11u10u9u8,   0x00, u11u10u9u8}
    // f11f10 = {0x00, u13u12u11u10, 0x00, u13u12u11u10}
    // f13f12 = {0x00, u15u14u13u12, 0x00, u15u14u13u12}
    // f15f14 = {0x00, 0000u15u14,   0x00, 0000u15u14}
    // We use inline asm instead of __byte_perm intrinsic since we don't want the documented (& 0x7) on the index. NVCC
    // might be able to optimize it out since the index is a constexpr, but we choose to be safe about it here.
    uint32_t const prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343};
    static_assert(RegArray::kElements <= 8, "Too many inputs for U2 -> FP16 vector converter");

    // Even registers read byte k of the source, odd registers read byte k of (source >> 4).
    CUTLASS_PRAGMA_UNROLL
    for (int ii = 0; ii < RegArray::kElements; ii += 2) {
      asm volatile(
          "{ prmt.b32 %0, %1, %2, %3; }\n"
          : "=r"(r[ii])
          : "r"(src_reg), "n"(0), "r"(prmt_indices[ii / 2]));
      asm volatile(
          "{ prmt.b32 %0, %1, %2, %3; }\n"
          : "=r"(r[ii + 1])
          : "r"(src_reg_shifted), "n"(0), "r"(prmt_indices[ii / 2]));
    }

    // The below AND + XOR does the following:
    // Keeps bits 0-1 in the low lane and bits 2-3 in the high lane, then sets the exponent bits
    // of the FP16 magic number 0x6400 (1024.0). The kept bits never overlap the bits set in
    // xor_mask, so the XOR only inserts the magic number. We will be constructing
    //   1024 + x (low lane), 1024 + 4 * x (high lane)
    // We use lop3 so that we can use 1 instruction for AND and XOR.
    static constexpr uint32_t xor_mask = 0x64006400;
    static constexpr uint32_t and_mask = 0x000C0003;
    static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;

    // For each operand, computes:
    // r[i] = (r[i] & and_mask) ^ xor_mask
    CUTLASS_PRAGMA_UNROLL
    for (int ii = 0; ii < RegArray::kElements; ++ii) {
      asm volatile(
          "{ lop3.b32 %0, %0, %1, %2, %3; }\n"
          : "+r"(r[ii])
          : "n"(and_mask), "n"(xor_mask), "n"(immLut));
    }

    // {-256, -1024}
    static constexpr uint32_t hfma_bias_rep = 0xDC00E400;
    // {1/4, 1}
    static constexpr uint32_t hfma_scale_rep = 0x34003C00;

    // Scale and subtract the FP16s to get the original uint2 number as FP16.
    CUTLASS_PRAGMA_UNROLL
    for (int ii = 0; ii < RegArray::kElements; ++ii) {
      __half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
      fp16x2_val = __hfma2(fp16x2_val,
                           reinterpret_cast<const __half2&>(hfma_scale_rep),
                           reinterpret_cast<const __half2&>(hfma_bias_rep));
    }

    return reinterpret_cast<PackedResultType&>(r);
  }

  friend class detail::VectorizedConverter;

public:
  /// Converts `source` by delegating to detail::VectorizedConverter with this
  /// specialization's 16-, 8- and 4-element packed types.
  CUTLASS_DEVICE
  static result_type convert(source_type const &source) {
    result_type result;
    using ConverterType = NumericArrayConverter<typename result_type::Element, typename source_type::Element, N, Round>;
    detail::VectorizedConverter::convert<ConverterType,
                                         result_type_packed_16, source_type_packed_16,
                                         result_type_packed_8, source_type_packed_8,
                                         result_type_packed_4, source_type_packed_4>(result, source);
    return result;
  }

  /// Function-call form; forwards to convert().
  CUTLASS_DEVICE
  result_type operator()(source_type const &s) const {
    return convert(s);
  }
};
/// Partial specialization for Array<cutlass::half_t, N> <= Array<cutlass::int4b_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::half_t, cutlass::int4b_t, N, Round> {
@ -3830,13 +4104,11 @@ private:
// We use inline asm instead of __byte_perm intrinsic since we don't want the documented (& 0x7) on the index. NVCC
// might be able to optimize it out since the index is a constexpr, but we choose to be safe about it here.
uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343};
static_assert(RegArray::kElements <= 4, "Too many inputs for F16 -> I4 vector converter");
static_assert(RegArray::kElements <= 4, "Too many inputs for I4 ->F16 vector converter");
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile(
"{\n"
" prmt.b32 %0, %1, %2, %3;\n"
"}\n"
"{ prmt.b32 %0, %1, %2, %3; }\n"
: "=r"(r[ii])
: "r"(src_reg), "n"(0), "r"(prmt_indices[ii]));
}
@ -3891,6 +4163,133 @@ private:
return reinterpret_cast<PackedResultType&>(r);
}
friend class detail::VectorizedConverter;
public:
/// Converts `source` by delegating to detail::VectorizedConverter with the 8-, 4- and
/// 2-element packed result/source types of the enclosing converter.
CUTLASS_DEVICE
static result_type convert(source_type const &source) {
  using SelfConverter = NumericArrayConverter<typename result_type::Element, typename source_type::Element, N, Round>;
  result_type converted;
  detail::VectorizedConverter::convert<SelfConverter,
                                       result_type_packed_8, source_type_packed_8,
                                       result_type_packed_4, source_type_packed_4,
                                       result_type_packed_2, source_type_packed_2>(converted, source);
  return converted;
}
/// Function-call form of the converter; forwards to the static convert().
CUTLASS_DEVICE
result_type operator()(source_type const &s) const {
  result_type converted = convert(s);
  return converted;
}
};
/// Partial specialization for Array<cutlass::half_t, N> <= Array<cutlass::uint4b_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::half_t, cutlass::uint4b_t, N, Round> {
using result_type = Array<cutlass::half_t, N>;
using source_type = Array<cutlass::uint4b_t, N>;
static FloatRoundStyle const round_style = Round;
private:
using result_type_packed_8 = Array<cutlass::half_t, 8>;
using result_type_packed_4 = Array<cutlass::half_t, 4>;
using result_type_packed_2 = Array<cutlass::half_t, 2>;
using source_type_packed_8 = Array<cutlass::uint4b_t, 8>;
using source_type_packed_4 = Array<cutlass::uint4b_t, 4>;
using source_type_packed_2 = Array<cutlass::uint4b_t, 2>;
using ScalarConverter = NumericConverter<cutlass::half_t, cutlass::uint4b_t, Round>;
/// Views a 2-element uint4 source as a uint8_t and zero-extends it to a 32-bit value.
CUTLASS_DEVICE
static uint32_t to_reg(source_type_packed_2 const& source) {
  uint8_t const &bits = reinterpret_cast<uint8_t const &>(source);
  return static_cast<uint32_t>(bits);
}
/// Views a 4-element uint4 source as a uint16_t and zero-extends it to a 32-bit value.
CUTLASS_DEVICE
static uint32_t to_reg(source_type_packed_4 const& source) {
  uint16_t const &bits = reinterpret_cast<uint16_t const &>(source);
  return static_cast<uint32_t>(bits);
}
/// Views an 8-element uint4 source directly as a 32-bit value.
CUTLASS_DEVICE
static uint32_t to_reg(source_type_packed_8 const& source) {
  uint32_t const &bits = reinterpret_cast<uint32_t const &>(source);
  return bits;
}
// Core uint4 -> FP16 conversion. Integer bit tricks build a known FP16 number in each 16-bit
// lane, then one packed FMA per register rescales the high lane and removes the magic offsets.
//
// With u the value of a 4-bit field, per output register:
//   1. prmt replicates source byte k into the low byte of both lanes.
//   2. lop3 computes (reg & 0x00F0000F) | 0x64006400, i.e. {1024 + 16 * u_hi, 1024 + u_lo}.
//   3. hfma2 with scale {1/16, 1} and bias {-64, -1024} leaves {u_hi, u_lo}.
template <typename PackedResultType, typename PackedSrcType>
CUTLASS_DEVICE
static PackedResultType packed_convert(PackedSrcType const &source) {
  static_assert((platform::is_same<PackedSrcType, source_type_packed_2>::value &&
                 platform::is_same<PackedResultType, result_type_packed_2>::value) ||
                (platform::is_same<PackedSrcType, source_type_packed_4>::value &&
                 platform::is_same<PackedResultType, result_type_packed_4>::value) ||
                (platform::is_same<PackedSrcType, source_type_packed_8>::value &&
                 platform::is_same<PackedResultType, result_type_packed_8>::value),
                "Invalid PackedSrcType/PackedResultType must be 2, 4 or 8 to use private convert dispatch.");

  // One 32-bit register per pair of FP16 outputs.
  using RegArray = cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, sizeof(PackedResultType)>;
  RegArray lanes;

  uint32_t const packed_bits = to_reg(source);

  // Selector k yields {0x00, byte k, 0x00, byte k}: byte k of the source lands in the low byte
  // of both lanes, and the remaining bytes come from the immediate 0 operand.
  uint32_t const byte_selectors[4] = {0x4040, 0x4141, 0x4242, 0x4343};
  static_assert(RegArray::kElements <= 4, "Too many inputs for u4 -> f16 vector converter");

  CUTLASS_PRAGMA_UNROLL
  for (int idx = 0; idx < RegArray::kElements; ++idx) {
    asm volatile(
        "{ prmt.b32 %0, %1, %2, %3; }\n"
        : "=r"(lanes[idx])
        : "r"(packed_bits), "n"(0), "r"(byte_selectors[idx]));
  }

  // lop3 fuses the AND and the OR: lane = (lane & and_mask) | or_mask.
  // The low lane keeps the low nibble, the high lane keeps the high nibble in place (x16),
  // and or_mask sets the FP16 pattern 0x6400 (1024.0) in both lanes.
  static constexpr uint32_t and_mask = 0x00F0000F;
  static constexpr uint32_t or_mask = 0x64006400;
  static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;

  // FP16 pairs written as {high lane, low lane}:
  //   high lane: (1024 + 16 * u) * 1/16 - 64 = u
  //   low lane:  (1024 + u) * 1 - 1024       = u
  static constexpr uint32_t hfma_scale = 0x2C003C00;  // {1 / 16, 1}
  static constexpr uint32_t hfma_bias = 0xD400E400;   // {-64, -1024}

  CUTLASS_PRAGMA_UNROLL
  for (int idx = 0; idx < RegArray::kElements; ++idx) {
    asm volatile(
        "{\n"
        " lop3.b32 %0, %0, %1, %2, %3;\n"
        "}\n"
        : "+r"(lanes[idx])
        : "n"(and_mask), "n"(or_mask), "n"(immLut));

    __half2 &lane_pair = reinterpret_cast<__half2 &>(lanes[idx]);
    lane_pair = __hfma2(lane_pair,
                        reinterpret_cast<__half2 const &>(hfma_scale),
                        reinterpret_cast<__half2 const &>(hfma_bias));
  }

  return reinterpret_cast<PackedResultType&>(lanes);
}
friend class detail::VectorizedConverter;
public:
@ -4108,6 +4507,260 @@ public:
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Array<cutlass::bfloat16_t, N> <= Array<cutlass::int2b_t, N>
///
/// Signed 2-bit integers are expanded to BF16 with integer bit tricks followed by one packed
/// subtraction per output register. convert() hands the 16-, 8- and 4-element packed types
/// declared below to detail::VectorizedConverter.
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::bfloat16_t, cutlass::int2b_t, N, Round> {
  using result_type = Array<cutlass::bfloat16_t, N>;
  using source_type = Array<cutlass::int2b_t, N>;
  static FloatRoundStyle const round_style = Round;

private:
  using source_type_packed_16 = Array<cutlass::int2b_t, 16>;
  using result_type_packed_16 = Array<cutlass::bfloat16_t, 16>;
  using source_type_packed_8 = Array<cutlass::int2b_t, 8>;
  using result_type_packed_8 = Array<cutlass::bfloat16_t, 8>;
  using source_type_packed_4 = Array<cutlass::int2b_t, 4>;
  using result_type_packed_4 = Array<cutlass::bfloat16_t, 4>;

  // Not referenced inside this struct; presumably looked up by the friend
  // detail::VectorizedConverter -- confirm before removing.
  using ScalarConverter = NumericConverter<cutlass::bfloat16_t, cutlass::int2b_t, Round>;

  /// Views a 4-element source as a uint8_t and zero-extends it to a 32-bit value.
  CUTLASS_DEVICE
  static uint32_t to_reg(source_type_packed_4 const& source) {
    uint8_t const &bits = reinterpret_cast<uint8_t const &>(source);
    return static_cast<uint32_t>(bits);
  }

  /// Views an 8-element source as a uint16_t and zero-extends it to a 32-bit value.
  CUTLASS_DEVICE
  static uint32_t to_reg(source_type_packed_8 const& source) {
    uint16_t const &bits = reinterpret_cast<uint16_t const &>(source);
    return static_cast<uint32_t>(bits);
  }

  /// Views a 16-element source directly as a 32-bit value.
  CUTLASS_DEVICE
  static uint32_t to_reg(source_type_packed_16 const& source) {
    return reinterpret_cast<uint32_t const &>(source);
  }

  /// Converts 4, 8 or 16 packed int2 values to BF16.
  ///
  /// Each 32-bit output register carries two BF16 lanes. With x the signed value of a field:
  ///   1. prmt places byte k of one shifted copy of the source in the low lane and byte k of
  ///      the copy shifted two bits further in the high lane, so the wanted 2-bit field sits
  ///      in bits 0-1 of each lane.
  ///   2. lop3 computes (reg & 0x00030003) ^ 0x43024302. The AND keeps only bits 0-1 of each
  ///      lane; the XOR flips the field's top bit (x -> x + 2) and installs the BF16 pattern
  ///      0x4300 (128.0), giving 128 + (x + 2) in both lanes.
  ///   3. Subtracting BF16 {130, 130} leaves x, in the range -2..1.
  template <typename PackedResultType, typename PackedSrcType>
  CUTLASS_DEVICE
  static PackedResultType packed_convert(PackedSrcType const &source) {
    static_assert((platform::is_same<PackedSrcType, source_type_packed_4>::value &&
                   platform::is_same<PackedResultType, result_type_packed_4>::value) ||
                  (platform::is_same<PackedSrcType, source_type_packed_8>::value &&
                   platform::is_same<PackedResultType, result_type_packed_8>::value) ||
                  (platform::is_same<PackedSrcType, source_type_packed_16>::value &&
                   platform::is_same<PackedResultType, result_type_packed_16>::value),
                  "Invalid PackedSrcType/PackedResultType must be 4, 8 or 16 to use private convert dispatch.");

    // One 32-bit register per pair of BF16 outputs.
    using RegArray = cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, sizeof(PackedResultType)>;
    RegArray lanes;

    // Four views of the input, each shifted so a different 2-bit field is at bits 0-1 of a byte.
    uint32_t packed_bits = to_reg(source);
    uint32_t packed_bits_shr2 = packed_bits >> 2;
    uint32_t packed_bits_shr4 = packed_bits >> 4;
    uint32_t packed_bits_shr6 = packed_bits >> 6;

    // Selector k puts byte k of the first operand in result byte 0 and byte k of the second
    // operand in result byte 2. Whatever lands in bytes 1 and 3 is cleared by the AND below.
    uint32_t const byte_selectors[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3};
    static_assert(RegArray::kElements <= 8, "Too many inputs for I2 -> BF16 vector converter");

    // Even registers hold fields from (src, src >> 2), odd registers from (src >> 4, src >> 6).
    CUTLASS_PRAGMA_UNROLL
    for (int pair = 0; pair < RegArray::kElements / 2; ++pair) {
      asm volatile(
          "{ prmt.b32 %0, %1, %2, %3; }\n"
          : "=r"(lanes[2 * pair])
          : "r"(packed_bits), "r"(packed_bits_shr2), "r"(byte_selectors[pair]));
      asm volatile(
          "{ prmt.b32 %0, %1, %2, %3; }\n"
          : "=r"(lanes[2 * pair + 1])
          : "r"(packed_bits_shr4), "r"(packed_bits_shr6), "r"(byte_selectors[pair]));
    }

    // Two's-complement 2-bit encoding handled here:
    //   00 ->  0,  01 ->  1,  10 -> -2,  11 -> -1
    // A single lop3 performs lane = (lane & and_mask) ^ xor_mask.
    static constexpr uint32_t and_mask = 0x00030003;  // keep 2 bits per lane
    static constexpr uint32_t xor_mask = 0x43024302;  // 128.0 pattern plus top-bit flip
    static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;

    CUTLASS_PRAGMA_UNROLL
    for (int idx = 0; idx < RegArray::kElements; ++idx) {
      asm volatile(
          "{\n"
          " lop3.b32 %0, %0, %1, %2, %3;\n"
          "}\n"
          : "+r"(lanes[idx])
          : "n"(and_mask), "n"(xor_mask), "n"(immLut));
    }

    // BF16 {130, 130}; removing it maps 128 + (x + 2) back to x.
    static constexpr uint32_t bias_rep = 0x43024302;
    __nv_bfloat162 const &bias = reinterpret_cast<__nv_bfloat162 const &>(bias_rep);

    CUTLASS_PRAGMA_UNROLL
    for (int idx = 0; idx < RegArray::kElements; ++idx) {
      __nv_bfloat162 &lane_pair = reinterpret_cast<__nv_bfloat162 &>(lanes[idx]);
      lane_pair = __hsub2(lane_pair, bias);
    }

    return reinterpret_cast<PackedResultType&>(lanes);
  }

  friend class detail::VectorizedConverter;

public:
  /// Converts `source` by delegating to detail::VectorizedConverter with this
  /// specialization's 16-, 8- and 4-element packed types.
  CUTLASS_DEVICE
  static result_type convert(source_type const &source) {
    using SelfConverter = NumericArrayConverter<typename result_type::Element, typename source_type::Element, N, Round>;
    result_type converted;
    detail::VectorizedConverter::convert<SelfConverter,
                                         result_type_packed_16, source_type_packed_16,
                                         result_type_packed_8, source_type_packed_8,
                                         result_type_packed_4, source_type_packed_4>(converted, source);
    return converted;
  }

  /// Function-call form; forwards to convert().
  CUTLASS_DEVICE
  result_type operator()(source_type const &s) const {
    return convert(s);
  }
};
/// Partial specialization for Array<cutlass::bfloat16_t, N> <= Array<cutlass::uint2b_t, N>
///
/// Unsigned 2-bit integers are expanded to BF16 with integer bit tricks followed by one packed
/// subtraction per output register. convert() hands the 16-, 8- and 4-element packed types
/// declared below to detail::VectorizedConverter.
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::bfloat16_t, cutlass::uint2b_t, N, Round> {
  using result_type = Array<cutlass::bfloat16_t, N>;
  using source_type = Array<cutlass::uint2b_t, N>;
  static FloatRoundStyle const round_style = Round;

private:
  using result_type_packed_16 = Array<cutlass::bfloat16_t, 16>;
  using result_type_packed_8 = Array<cutlass::bfloat16_t, 8>;
  using result_type_packed_4 = Array<cutlass::bfloat16_t, 4>;
  using source_type_packed_16 = Array<cutlass::uint2b_t, 16>;
  using source_type_packed_8 = Array<cutlass::uint2b_t, 8>;
  using source_type_packed_4 = Array<cutlass::uint2b_t, 4>;

  // Not referenced inside this struct; presumably looked up by the friend
  // detail::VectorizedConverter -- confirm before removing.
  using ScalarConverter = NumericConverter<cutlass::bfloat16_t, cutlass::uint2b_t, Round>;

  /// Views a 4-element source as a uint8_t and zero-extends it to a 32-bit value.
  CUTLASS_DEVICE
  static uint32_t to_reg(source_type_packed_4 const& source) {
    return static_cast<uint32_t>(
        reinterpret_cast<const uint8_t&>(source));
  }

  /// Views an 8-element source as a uint16_t and zero-extends it to a 32-bit value.
  CUTLASS_DEVICE
  static uint32_t to_reg(source_type_packed_8 const& source) {
    return static_cast<uint32_t>(
        reinterpret_cast<const uint16_t&>(source));
  }

  /// Views a 16-element source directly as a 32-bit value.
  CUTLASS_DEVICE
  static uint32_t to_reg(source_type_packed_16 const& source) {
    return reinterpret_cast<const uint32_t&>(source);
  }

  /// Converts 4, 8 or 16 packed uint2 values to BF16.
  ///
  /// Each 32-bit output register carries two BF16 lanes. With u the value of a 2-bit field:
  ///   1. prmt places byte k of one shifted copy of the source in the low lane and byte k of
  ///      the copy shifted two bits further in the high lane.
  ///   2. lop3 computes (reg & 0x00030003) ^ 0x43004300, i.e. 128 + u in both lanes.
  ///   3. Subtracting BF16 {128, 128} leaves u, in the range 0..3.
  template <typename PackedResultType, typename PackedSrcType>
  CUTLASS_DEVICE
  static PackedResultType packed_convert(PackedSrcType const &source) {
    static_assert((platform::is_same<PackedSrcType, source_type_packed_4>::value &&
                   platform::is_same<PackedResultType, result_type_packed_4>::value) ||
                  (platform::is_same<PackedSrcType, source_type_packed_8>::value &&
                   platform::is_same<PackedResultType, result_type_packed_8>::value) ||
                  (platform::is_same<PackedSrcType, source_type_packed_16>::value &&
                   platform::is_same<PackedResultType, result_type_packed_16>::value),
                  "Invalid PackedSrcType/PackedResultType must be 4, 8 or 16 to use private convert dispatch.");

    // Hold output BF16s in reg. We need 1 reg for every 2 elements
    using RegArray = cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, sizeof(PackedResultType)>;
    RegArray r;

    // View the input as reg, plus shifted copies that bring each 2-bit field to bits 0-1 of a byte
    uint32_t src_reg = to_reg(source);
    uint32_t src_reg_shifted_two = src_reg >> 2;
    uint32_t src_reg_shifted_four = src_reg >> 4;
    uint32_t src_reg_shifted_six = src_reg >> 6;

    // prmt indices: selector k puts byte k of the first operand in result byte 0 and byte k of
    // the second operand in result byte 2. Bytes 1 and 3 are cleared by the AND mask below.
    uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3};
    static_assert(RegArray::kElements <= 8, "Too many inputs for U2 -> BF16 vector converter");

    // First pass: gather the bytes holding the 2-bit values (no sign handling is needed for
    // unsigned input). Even registers use (src, src >> 2), odd registers (src >> 4, src >> 6).
    CUTLASS_PRAGMA_UNROLL
    for (int ii = 0; ii < RegArray::kElements; ii += 2) {
      asm volatile(
          "{ prmt.b32 %0, %1, %2, %3; }\n"
          : "=r"(r[ii])
          : "r"(src_reg), "r"(src_reg_shifted_two), "r"(prmt_indices[ii / 2]));
      asm volatile(
          "{ prmt.b32 %0, %1, %2, %3; }\n"
          : "=r"(r[ii + 1])
          : "r"(src_reg_shifted_four), "r"(src_reg_shifted_six), "r"(prmt_indices[ii / 2]));
    }

    // r[i] = (r[i] & and_mask) ^ xor_mask builds 128 + u in each lane. The kept bits never
    // overlap the bits set in xor_mask, so the XOR only inserts the BF16 pattern 0x4300 (128.0).
    static constexpr uint32_t and_mask = 0x00030003; // Mask for 2 bits
    static constexpr uint32_t xor_mask = 0x43004300;
    static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;

    CUTLASS_PRAGMA_UNROLL
    for (int ii = 0; ii < RegArray::kElements; ++ii) {
      asm volatile(
          "{ lop3.b32 %0, %0, %1, %2, %3; }\n"
          : "+r"(r[ii])
          : "n"(and_mask), "n"(xor_mask), "n"(immLut));
    }

    // Subtracting 128 removes the magic offset and leaves the unsigned value.
    static constexpr uint32_t bias_rep = xor_mask; // {128, 128} in bfloat16
    const __nv_bfloat162& bias = reinterpret_cast<const __nv_bfloat162&>(bias_rep);

    CUTLASS_PRAGMA_UNROLL
    for (int ii = 0; ii < RegArray::kElements; ++ii) {
      __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
      bf16x2_val = __hsub2(bf16x2_val, bias);
    }

    return reinterpret_cast<PackedResultType&>(r);
  }

  friend class detail::VectorizedConverter;

public:
  /// Converts `source` by delegating to detail::VectorizedConverter with this
  /// specialization's 16-, 8- and 4-element packed types.
  CUTLASS_DEVICE
  static result_type convert(source_type const &source) {
    result_type result;
    using ConverterType = NumericArrayConverter<typename result_type::Element, typename source_type::Element, N, Round>;
    detail::VectorizedConverter::convert<ConverterType,
                                         result_type_packed_16, source_type_packed_16,
                                         result_type_packed_8, source_type_packed_8,
                                         result_type_packed_4, source_type_packed_4>(result, source);
    return result;
  }

  /// Function-call form; forwards to convert().
  CUTLASS_DEVICE
  result_type operator()(source_type const &s) const {
    return convert(s);
  }
};
/// Partial specialization for Array<cutlass::bfloat16_t, N> <= Array<cutlass::int4b_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::bfloat16_t, cutlass::int4b_t, N, Round> {
@ -4171,9 +4824,7 @@ private:
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile(
"{\n"
" prmt.b32 %0, %1, %2, %3;\n"
"}\n"
"{ prmt.b32 %0, %1, %2, %3; }\n"
: "=r"(r[ii])
: "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii]));
}
@ -4185,6 +4836,133 @@ private:
static constexpr uint32_t and_mask = 0x000F000F;
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
// For each operand, computes:
// r[i] = (r[i] & and_mask) ^ xor_mask
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile(
"{ lop3.b32 %0, %0, %1, %2, %3; }\n"
: "+r"(r[ii])
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
}
// We will issue 2 bfmas that do the following:
// high BF16:
// hi_bf16 - 136, lo_bf16 - 136
// This is the BF16 {136, 136} represented as an integer.
static constexpr uint32_t bias_rep = 0x43084308;
const __nv_bfloat162& bias = reinterpret_cast<const __nv_bfloat162&>(bias_rep);
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
__nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
bf16x2_val = __hsub2(bf16x2_val, bias);
}
return reinterpret_cast<PackedResultType&>(r);
}
friend class detail::VectorizedConverter;
public:
/// Converts `source` by delegating to detail::VectorizedConverter with the 8-, 4- and
/// 2-element packed result/source types of the enclosing converter.
CUTLASS_DEVICE
static result_type convert(source_type const &source) {
  using SelfConverter = NumericArrayConverter<typename result_type::Element, typename source_type::Element, N, Round>;
  result_type converted;
  detail::VectorizedConverter::convert<SelfConverter,
                                       result_type_packed_8, source_type_packed_8,
                                       result_type_packed_4, source_type_packed_4,
                                       result_type_packed_2, source_type_packed_2>(converted, source);
  return converted;
}
/// Function-call form of the converter; forwards to the static convert().
CUTLASS_DEVICE
result_type operator()(source_type const &s) const {
  result_type converted = convert(s);
  return converted;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Array<cutlass::bfloat16_t, N> <= Array<cutlass::uint4b_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::bfloat16_t, cutlass::uint4b_t, N, Round> {
using result_type = Array<cutlass::bfloat16_t, N>;
using source_type = Array<cutlass::uint4b_t, N>;
static FloatRoundStyle const round_style = Round;
private:
using result_type_packed_8 = Array<cutlass::bfloat16_t, 8>;
using result_type_packed_4 = Array<cutlass::bfloat16_t, 4>;
using result_type_packed_2 = Array<cutlass::bfloat16_t, 2>;
using source_type_packed_8 = Array<cutlass::uint4b_t, 8>;
using source_type_packed_4 = Array<cutlass::uint4b_t, 4>;
using source_type_packed_2 = Array<cutlass::uint4b_t, 2>;
using ScalarConverter = NumericConverter<cutlass::bfloat16_t, cutlass::uint4b_t, Round>;
/// Views a 2-element uint4 source as a uint8_t and zero-extends it to a 32-bit value.
CUTLASS_DEVICE
static uint32_t to_reg(source_type_packed_2 const& source) {
  uint8_t const &bits = reinterpret_cast<uint8_t const &>(source);
  return static_cast<uint32_t>(bits);
}
/// Views a 4-element uint4 source as a uint16_t and zero-extends it to a 32-bit value.
CUTLASS_DEVICE
static uint32_t to_reg(source_type_packed_4 const& source) {
  uint16_t const &bits = reinterpret_cast<uint16_t const &>(source);
  return static_cast<uint32_t>(bits);
}
/// Views an 8-element uint4 source directly as a 32-bit value.
CUTLASS_DEVICE
static uint32_t to_reg(source_type_packed_8 const& source) {
  uint32_t const &bits = reinterpret_cast<uint32_t const &>(source);
  return bits;
}
// The core converter uses bit tricks to construct a known BF16 number, then does a
// subtraction in BF16 for the final result.
template <typename PackedResultType, typename PackedSrcType>
CUTLASS_DEVICE
static PackedResultType packed_convert(PackedSrcType const &source) {
static_assert((platform::is_same<PackedSrcType, source_type_packed_2>::value &&
platform::is_same<PackedResultType, result_type_packed_2>::value) ||
(platform::is_same<PackedSrcType, source_type_packed_4>::value &&
platform::is_same<PackedResultType, result_type_packed_4>::value) ||
(platform::is_same<PackedSrcType, source_type_packed_8>::value &&
platform::is_same<PackedResultType, result_type_packed_8>::value),
"Invalid PackedSrcType/PackedResultType must be 2, 4 or 8 to use private convert dispatch.");
// Hold output FP16s in reg. We need 1 reg for every 2 elements
using RegArray = cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, sizeof(PackedResultType)>;
RegArray r;
// View the input as reg
uint32_t src_reg = to_reg(source);
uint32_t src_reg_shifted = src_reg >> 4;
// Below constructs the following temporary:
// fp16s_01 = {0x00, u4_21, 0x00, u4_10}
// fp16s_23 = {0x00, u4_43, 0x00, u4_32}
// fp16s_45 = {0x00, u4_65, 0x00, u4_54}
// fp16s_67 = {0x000, u4_7, 0x00, u4_76}
static constexpr uint32_t prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3};
static_assert(RegArray::kElements <= 4, "Too many inputs for BF16 -> I4 vector converter");
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile(
"{\n"
" prmt.b32 %0, %1, %2, %3;\n"
"}\n"
: "=r"(r[ii])
: "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii]));
}
static constexpr uint32_t xor_mask = 0x43004300;
static constexpr uint32_t and_mask = 0x000F000F;
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
// For each operand, computes:
// r[i] = (r[i] & and_mask) ^ xor_mask
CUTLASS_PRAGMA_UNROLL
@ -4199,16 +4977,15 @@ private:
// We will issue 2 bfmas that do the following:
// high BF16:
// hi_bf16 - 136, lo_bf16 - 136
// hi_bf16 - 128, lo_bf16 - 128
// This is the BF16 {136, 136} represented as an integer.
static constexpr uint32_t bias_rep = 0x43084308;
const __nv_bfloat162& bias = reinterpret_cast<const __nv_bfloat162&>(bias_rep);
// This is the BF16 {128, 128} represented as an integer.
static constexpr uint32_t bias = xor_mask;
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
__nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
bf16x2_val = __hsub2(bf16x2_val, bias);
bf16x2_val = __hsub2(bf16x2_val, reinterpret_cast<const __nv_bfloat162&>(bias));
}
return reinterpret_cast<PackedResultType&>(r);