v3.8.0 update (#2082)
* 3.8 update * fix Markus' name --------- Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
@ -222,8 +222,8 @@ cutlass_add_cutlass_library(
|
||||
# files split for parallel compilation
|
||||
src/reference/gemm_int4.cu
|
||||
|
||||
src/reference/block_scaled_gemm_fp4a_vs16.cu
|
||||
src/reference/block_scaled_gemm_fp4a_vs32.cu
|
||||
src/reference/block_scaled_gemm_fp4a_vs16.cu
|
||||
src/reference/block_scaled_gemm_fp4a_vs32.cu
|
||||
src/reference/block_scaled_gemm_mixed8bitsa.cu
|
||||
src/reference/gemm_f4_f4_f32.cu
|
||||
src/reference/gemm_f4_f6_f32.cu
|
||||
|
||||
@ -43,7 +43,8 @@
|
||||
computational overhead
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#ifndef CUTLASS_LIBRARY_LIBRARY_H
|
||||
#define CUTLASS_LIBRARY_LIBRARY_H
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -103,7 +104,7 @@ public:
|
||||
void *device_workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) const = 0;
|
||||
|
||||
// Originally designed for metadata, but should be useful for FP8/6/4 too.
|
||||
// Originally designed for metadata, but should be useful for FP8/6/4 too.
|
||||
virtual Status initialize_with_profiler_workspace(
|
||||
void const *configuration,
|
||||
void *host_workspace,
|
||||
@ -282,6 +283,11 @@ struct GemmUniversalConfiguration {
|
||||
int device_count{1};
|
||||
};
|
||||
|
||||
enum class Sm90MixedInputWiderOperand {
|
||||
A = 0,
|
||||
B = 1
|
||||
};
|
||||
|
||||
struct GemmUniversalArguments {
|
||||
// NOTE: these are replicated for 3.0 interfaces
|
||||
gemm::GemmCoord problem_size{};
|
||||
@ -317,6 +323,18 @@ struct GemmUniversalArguments {
|
||||
int swizzle_size{1};
|
||||
int split_k_slices{1};
|
||||
|
||||
// For mixed input dtype kernels
|
||||
bool is_mixed_dtype{false};
|
||||
Sm90MixedInputWiderOperand wider_operand{Sm90MixedInputWiderOperand::B};
|
||||
bool generate_scale_and_zero{false};
|
||||
bool generate_dequantized_AB{false};
|
||||
bool *dequantized_AB_ready{nullptr}; // Carry the info back to gemm_operation_profiler.cu
|
||||
void *Scale{nullptr}; // Scale tensor
|
||||
void *Zero{nullptr}; // Zero tensor
|
||||
void *dequantized_AB{nullptr}; // Dequantized A or B tensor for verification
|
||||
void *encoded_AB{nullptr}; // Encoded A or B in int4 x fp8 or shuffle
|
||||
void *packed_Scale{nullptr}; // Packed scale for int4 * fp8
|
||||
|
||||
int device_index{0};
|
||||
|
||||
bool use_pdl{false};
|
||||
@ -472,12 +490,16 @@ struct GemmPlanarComplexArrayArguments {
|
||||
|
||||
struct GemmGroupedConfiguration {
|
||||
int problem_count{0};
|
||||
int threadblock_count{0};
|
||||
// GemmGroupedConfiguration is passed to initialize(), which
|
||||
// is responsible for allocating the device-side stride storage.
|
||||
int64_t* lda;
|
||||
int64_t* ldb;
|
||||
int64_t* ldc;
|
||||
};
|
||||
|
||||
struct GemmGroupedArguments {
|
||||
|
||||
gemm::GemmCoord *problem_sizes{nullptr};
|
||||
int problem_count{};
|
||||
gemm::GemmCoord* problem_sizes{nullptr};
|
||||
|
||||
void * ptr_A{nullptr};
|
||||
void * ptr_B{nullptr};
|
||||
@ -493,6 +515,18 @@ struct GemmGroupedArguments {
|
||||
void const *beta{nullptr};
|
||||
ScalarPointerMode pointer_mode{};
|
||||
bool use_pdl{false};
|
||||
|
||||
gemm::GemmCoord cluster_shape{};
|
||||
gemm::GemmCoord cluster_shape_fallback{};
|
||||
|
||||
// these should really be in the configuration but staying consistent with GEMM
|
||||
int sm_count{0};
|
||||
// The user is responsible for allocating storage for problem sizes.
|
||||
// Since GemmGroupedArguments is used by both the 2.x and 3.x APIs, we
|
||||
// unfortunately need to have both options in this struct, and the
|
||||
// underlying operation uses the one it needs.
|
||||
cute::Shape<int, int, int>* problem_sizes_3x;
|
||||
cute::Shape<int, int, int>* problem_sizes_3x_host;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -880,3 +914,5 @@ struct ReductionArguments {
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif
|
||||
|
||||
@ -142,7 +142,7 @@ enum class Provider {
|
||||
/// Enumeration indicating the kind of operation
|
||||
enum class OperationKind {
|
||||
kGemm,
|
||||
kBlockScaledGemm,
|
||||
kBlockScaledGemm,
|
||||
kRankK,
|
||||
kRank2K,
|
||||
kTrmm,
|
||||
@ -152,6 +152,7 @@ enum class OperationKind {
|
||||
kEqGemm,
|
||||
kSparseGemm,
|
||||
kReduction,
|
||||
kGroupedGemm,
|
||||
kInvalid
|
||||
};
|
||||
|
||||
@ -270,7 +271,6 @@ enum class RuntimeDatatype {
|
||||
kStatic,
|
||||
kE4M3,
|
||||
kE5M2,
|
||||
|
||||
kE3M2,
|
||||
kE2M3,
|
||||
kE2M1,
|
||||
|
||||
@ -34,7 +34,8 @@
|
||||
\brief Utilities accompanying the CUTLASS library for interacting with Library types.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#ifndef CUTLASS_LIBRARY_UTIL_H
|
||||
#define CUTLASS_LIBRARY_UTIL_H
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
@ -213,6 +214,63 @@ bool cast_from_double(std::vector<uint8_t> &bytes, NumericTypeID type, double sr
|
||||
|
||||
NumericTypeID dynamic_datatype_to_id(RuntimeDatatype type);
|
||||
|
||||
#define CUDA_CHECK(call) \
|
||||
do { \
|
||||
cudaError_t err = (call); \
|
||||
if (err != cudaSuccess) { \
|
||||
std::cerr << "CUDA Error: " << cudaGetErrorString(err) << " in " << __func__ << " at " \
|
||||
<< __FILE__ << ":" << __LINE__ << std::endl; \
|
||||
return Status::kInvalid; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// RAII CUDA buffer container
|
||||
class CudaBuffer {
|
||||
public:
|
||||
CudaBuffer() : size_(0), d_ptr_(nullptr) {}
|
||||
|
||||
explicit CudaBuffer(size_t size) : size_(size), d_ptr_(nullptr) {
|
||||
cudaError_t err = cudaMalloc(&d_ptr_, size_);
|
||||
if (err != cudaSuccess) {
|
||||
throw std::runtime_error("cudaMalloc failed: " + std::string(cudaGetErrorString(err)));
|
||||
}
|
||||
}
|
||||
|
||||
~CudaBuffer() {
|
||||
if (d_ptr_) {
|
||||
cudaFree(d_ptr_);
|
||||
}
|
||||
}
|
||||
|
||||
CudaBuffer(CudaBuffer const&) = delete;
|
||||
CudaBuffer& operator=(CudaBuffer const&) = delete;
|
||||
|
||||
CudaBuffer(CudaBuffer&& other) noexcept : size_(other.size_), d_ptr_(other.d_ptr_) {
|
||||
other.d_ptr_ = nullptr;
|
||||
other.size_ = 0;
|
||||
}
|
||||
|
||||
CudaBuffer& operator=(CudaBuffer&& other) noexcept {
|
||||
if (this != &other) {
|
||||
if (d_ptr_) {
|
||||
cudaFree(d_ptr_);
|
||||
}
|
||||
d_ptr_ = other.d_ptr_;
|
||||
size_ = other.size_;
|
||||
other.d_ptr_ = nullptr;
|
||||
other.size_ = 0;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
void* data() const noexcept { return d_ptr_; }
|
||||
size_t size() const noexcept { return size_; }
|
||||
|
||||
private:
|
||||
size_t size_;
|
||||
void* d_ptr_;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
@ -220,3 +278,4 @@ NumericTypeID dynamic_datatype_to_id(RuntimeDatatype type);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif
|
||||
|
||||
@ -73,7 +73,7 @@ public:
|
||||
using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp;
|
||||
|
||||
using Sm100BlkScaledConfig = typename CollectiveMainloop::Sm100BlkScaledConfig;
|
||||
|
||||
|
||||
static constexpr bool epilogue_scalefactor_generation = not cute::is_same_v<typename ThreadEpilogueOp::ElementBlockScaleFactor, void>;
|
||||
static constexpr int32_t SFD_VectorSize = epilogue_scalefactor_generation ? ThreadEpilogueOp::SFVecSize : SFVecSize;
|
||||
using ElementSFD = cute::conditional_t<epilogue_scalefactor_generation, typename ThreadEpilogueOp::ElementBlockScaleFactor, void>;
|
||||
|
||||
@ -1201,25 +1201,30 @@ public:
|
||||
GemmOperationBase<Operator_>(name) {
|
||||
|
||||
this->description_.gemm_kind = GemmKind::kGrouped;
|
||||
this->description_.kind = OperationKind::kGroupedGemm;
|
||||
this->threadblock_count = Operator::sufficient();
|
||||
}
|
||||
|
||||
private:
|
||||
int threadblock_count;
|
||||
|
||||
protected:
|
||||
|
||||
/// Constructs the arguments structure given the configuration and arguments
|
||||
static Status construct_arguments_(
|
||||
Status construct_arguments_(
|
||||
OperatorArguments &op_args,
|
||||
GemmGroupedConfiguration const *config) {
|
||||
GemmGroupedConfiguration const *config) const {
|
||||
|
||||
op_args.problem_count = config->problem_count;
|
||||
op_args.threadblock_count = config->threadblock_count;
|
||||
op_args.threadblock_count = threadblock_count;
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Constructs the arguments structure given the configuration and arguments
|
||||
static Status update_arguments_(
|
||||
Status update_arguments_(
|
||||
OperatorArguments &op_args,
|
||||
GemmGroupedArguments const *arguments) {
|
||||
GemmGroupedArguments const *arguments) const {
|
||||
|
||||
if (arguments->pointer_mode == ScalarPointerMode::kHost) {
|
||||
|
||||
@ -1243,6 +1248,8 @@ protected:
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
op_args.threadblock_count = threadblock_count;
|
||||
op_args.problem_count = arguments->problem_count;
|
||||
op_args.problem_sizes = arguments->problem_sizes;
|
||||
|
||||
op_args.ptr_A = static_cast<ElementA **>(arguments->ptr_A);
|
||||
|
||||
@ -36,9 +36,17 @@
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/detail/collective.hpp"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/array_subbyte.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "library_internal.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/mixed_dtype_utils.hpp"
|
||||
#include "cutlass/util/device_memory.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include <unordered_map>
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -65,7 +73,7 @@ public:
|
||||
using ElementAccumulator = typename Operator::ElementAccumulator;
|
||||
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
|
||||
|
||||
private:
|
||||
protected:
|
||||
GemmDescription description_;
|
||||
|
||||
public:
|
||||
@ -178,7 +186,23 @@ public:
|
||||
|
||||
/// Constructor
|
||||
GemmUniversal3xOperation(char const *name = "unknown_gemm"):
|
||||
GemmOperation3xBase<Operator_>(name, GemmKind::kUniversal) {}
|
||||
GemmOperation3xBase<Operator_>(name, GemmKind::kUniversal) {
|
||||
if constexpr (Operator::ArchTag::kMinComputeCapability == 90) {
|
||||
dim3 cluster_dims(
|
||||
cute::size<0>(typename Operator::GemmKernel::ClusterShape{}),
|
||||
cute::size<1>(typename Operator::GemmKernel::ClusterShape{}),
|
||||
cute::size<2>(typename Operator::GemmKernel::ClusterShape{}));
|
||||
uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock;
|
||||
void const* kernel_ptr = (void*)(device_kernel<typename Operator::GemmKernel>);
|
||||
max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters(
|
||||
cluster_dims,
|
||||
threads_per_block,
|
||||
kernel_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
int max_active_clusters{};
|
||||
|
||||
protected:
|
||||
|
||||
@ -227,10 +251,119 @@ protected:
|
||||
}
|
||||
};
|
||||
|
||||
/// Constructs the arguments structure given the configuration and arguments
|
||||
static Status update_arguments_(
|
||||
template<template<int, class, class> class Policy, int Stages, class ClusterShape, class KernelSchedule>
|
||||
static constexpr bool is_mixed_dtype_mainloop_(Policy<Stages, ClusterShape, KernelSchedule> policy) {
|
||||
return (cute::is_same_v<Policy<Stages, ClusterShape, KernelSchedule>,
|
||||
cutlass::gemm::MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput<Stages, ClusterShape, KernelSchedule>>);
|
||||
}
|
||||
|
||||
template <class DispatchPolicy>
|
||||
static constexpr bool is_mixed_dtype_mainloop_(DispatchPolicy) {
|
||||
return false;
|
||||
}
|
||||
|
||||
template <
|
||||
typename ElementWide,
|
||||
typename ElementNarrow,
|
||||
typename ElementScaleMainloop,
|
||||
class ActualStrideAB,
|
||||
Sm90MixedInputWiderOperand wider_operand,
|
||||
bool is_n4w8,
|
||||
typename ElementScale,
|
||||
typename ElementZero,
|
||||
class Layout_SZ>
|
||||
static void dequantize_encode_(
|
||||
OperatorArguments &operator_args,
|
||||
GemmUniversalArguments const *arguments) {
|
||||
GemmUniversalArguments const *arguments,
|
||||
cudaStream_t stream,
|
||||
const int &problem_mn,
|
||||
const int &problem_k,
|
||||
const int &options_l,
|
||||
const int &options_g,
|
||||
ElementScale *ptr_S,
|
||||
ElementZero *ptr_Z,
|
||||
const size_t &SZ_size,
|
||||
Layout_SZ layout_SZ
|
||||
) {
|
||||
|
||||
auto shape_AB = cute::make_shape(problem_mn, problem_k, options_l);
|
||||
auto stride_AB = cutlass::make_cute_packed_stride(ActualStrideAB{}, shape_AB);
|
||||
auto layout_AB = cute::make_layout(shape_AB, stride_AB);
|
||||
auto *ptr_dequantized_AB = static_cast<ElementWide *>(arguments->dequantized_AB);
|
||||
const ElementNarrow *ptr_AB = nullptr;
|
||||
if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) {
|
||||
ptr_AB = static_cast<const ElementNarrow *>(arguments->B);
|
||||
}
|
||||
else {
|
||||
ptr_AB = static_cast<const ElementNarrow *>(arguments->A);
|
||||
}
|
||||
dequantize(ptr_dequantized_AB, ptr_AB, layout_AB, ptr_S, ptr_Z, layout_SZ, options_g, stream);
|
||||
if constexpr(is_n4w8) {
|
||||
size_t AB_size = cute::size(layout_AB);
|
||||
cutlass::int4b_t *encoded_AB = static_cast<cutlass::int4b_t *>(arguments->encoded_AB);
|
||||
unified_encode_int4b(ptr_AB, encoded_AB, AB_size);
|
||||
if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) {
|
||||
operator_args.mainloop.ptr_B = static_cast<ElementNarrow const *>(encoded_AB);
|
||||
}
|
||||
else {
|
||||
operator_args.mainloop.ptr_A = static_cast<ElementNarrow const *>(encoded_AB);
|
||||
}
|
||||
ElementScaleMainloop *ptr_packed_Scale = static_cast<ElementScaleMainloop *>(arguments->packed_Scale);
|
||||
pack_scale_fp8(ptr_S, ptr_packed_Scale, SZ_size);
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename ElementAB,
|
||||
class ActualStrideAB,
|
||||
class LayoutAB_Reordered,
|
||||
class LayoutAtomQuant,
|
||||
Sm90MixedInputWiderOperand wider_operand>
|
||||
static void handle_shuffle_tensor_(
|
||||
OperatorArguments &operator_args,
|
||||
GemmUniversalArguments const *arguments,
|
||||
const int &problem_mn,
|
||||
const int &problem_k,
|
||||
const int &options_l) {
|
||||
|
||||
auto shape_AB = cute::make_shape(problem_mn, problem_k, options_l);
|
||||
auto stride_AB = cutlass::make_cute_packed_stride(ActualStrideAB{}, shape_AB);
|
||||
auto layout_AB = cute::make_layout(shape_AB, stride_AB);
|
||||
LayoutAB_Reordered layout_AB_reordered = cute::tile_to_shape(LayoutAtomQuant{}, shape_AB);
|
||||
if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) {
|
||||
operator_args.mainloop.dB = layout_AB_reordered;
|
||||
}
|
||||
else {
|
||||
operator_args.mainloop.dA = layout_AB_reordered;
|
||||
}
|
||||
if (arguments->generate_dequantized_AB) {
|
||||
size_t AB_size = cute::size(layout_AB);
|
||||
ElementAB *AB_reordered = cutlass::device_memory::allocate<ElementAB>(AB_size);
|
||||
const ElementAB *AB_src = nullptr;
|
||||
if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) {
|
||||
AB_src = static_cast<const ElementAB *>(operator_args.mainloop.ptr_B);
|
||||
}
|
||||
else {
|
||||
AB_src = static_cast<const ElementAB *>(operator_args.mainloop.ptr_A);
|
||||
}
|
||||
reorder_tensor(AB_src, layout_AB, AB_reordered, layout_AB_reordered);
|
||||
ElementAB *AB_dst = static_cast<ElementAB *>(arguments->encoded_AB);
|
||||
cutlass::device_memory::copy_device_to_device(AB_dst, AB_reordered, AB_size);
|
||||
cutlass::device_memory::free(AB_reordered);
|
||||
if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) {
|
||||
operator_args.mainloop.ptr_B = AB_dst;
|
||||
}
|
||||
else {
|
||||
operator_args.mainloop.ptr_A = AB_dst;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Constructs the arguments structure given the configuration and arguments
|
||||
Status update_arguments_(
|
||||
OperatorArguments& operator_args,
|
||||
GemmUniversalArguments const* arguments,
|
||||
cudaStream_t stream = nullptr) const {
|
||||
Status status = Status::kSuccess;
|
||||
|
||||
status = UpdateFusionArgs<decltype(operator_args.epilogue.thread)>::update_(
|
||||
@ -286,24 +419,173 @@ protected:
|
||||
operator_args.epilogue.ptr_C = static_cast<ElementC const *>(arguments->C);
|
||||
operator_args.epilogue.ptr_D = static_cast<ElementD *>(arguments->D);
|
||||
|
||||
operator_args.mainloop.dA = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideA>(
|
||||
// Stride{A,B} is a Layout if and only if:
|
||||
// (1) This is a mixed dtype kernel, and
|
||||
// (2) This mixed dtype kernel is using shuffling, and
|
||||
// (3) sizeof(narrow_type) == 4 or 8 bits, and
|
||||
// (4) sizeof(wide_type) == 16 bits.
|
||||
// If A/B has the narrow data type, Stride{A/B} will be a Layout
|
||||
constexpr bool is_StrideA_Layout = cute::is_layout<typename CollectiveMainloop::StrideA>::value;
|
||||
constexpr bool is_StrideB_Layout = cute::is_layout<typename CollectiveMainloop::StrideB>::value;
|
||||
static_assert(!(is_StrideA_Layout && is_StrideB_Layout), "Incorrect kernel configuration: StrideA and StrideB are both cute::Layout");
|
||||
if constexpr(!is_StrideA_Layout) {
|
||||
operator_args.mainloop.dA = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideA>(
|
||||
arguments->lda, arguments->batch_stride_A);
|
||||
operator_args.mainloop.dB = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideB>(
|
||||
}
|
||||
if constexpr(!is_StrideB_Layout) {
|
||||
operator_args.mainloop.dB = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideB>(
|
||||
arguments->ldb, arguments->batch_stride_B);
|
||||
}
|
||||
operator_args.epilogue.dC = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideC>(
|
||||
arguments->ldc, arguments->batch_stride_C);
|
||||
operator_args.epilogue.dD = operator_args.epilogue.dC;
|
||||
|
||||
using MainloopPolicy = typename CollectiveMainloop::DispatchPolicy;
|
||||
if constexpr(is_mixed_dtype_mainloop_(MainloopPolicy{})) {
|
||||
int problem_m = arguments->problem_size.m();
|
||||
int problem_n = arguments->problem_size.n();
|
||||
int problem_k = arguments->problem_size.k();
|
||||
int options_l = arguments->batch_count;
|
||||
|
||||
constexpr Sm90MixedInputWiderOperand wider_operand =
|
||||
(cutlass::sizeof_bits<ElementA>::value > cutlass::sizeof_bits<ElementB>::value) ?
|
||||
Sm90MixedInputWiderOperand::A : Sm90MixedInputWiderOperand::B;
|
||||
using ElementWide = std::conditional_t<wider_operand == Sm90MixedInputWiderOperand::A, ElementA, ElementB>;
|
||||
using ElementNarrow = std::conditional_t<wider_operand == Sm90MixedInputWiderOperand::A, ElementB, ElementA>;
|
||||
|
||||
constexpr bool has_scale = !std::is_same_v<typename CollectiveMainloop::ElementScale, void>;
|
||||
constexpr bool has_zero = !std::is_same_v<typename CollectiveMainloop::ElementZero, void>;
|
||||
if constexpr(has_scale) {
|
||||
int options_g = problem_k;
|
||||
int scale_k = (problem_k + options_g - 1) / options_g;
|
||||
|
||||
constexpr bool is_A4B8 = (
|
||||
cutlass::is_same_v<ElementA, cutlass::int4b_t> &&
|
||||
(cutlass::is_same_v<ElementB, cutlass::float_e4m3_t> ||
|
||||
cutlass::is_same_v<ElementB, cutlass::float_e5m2_t>));
|
||||
constexpr bool is_A8B4 = (
|
||||
cutlass::is_same_v<ElementB, cutlass::int4b_t> &&
|
||||
(cutlass::is_same_v<ElementA, cutlass::float_e4m3_t> ||
|
||||
cutlass::is_same_v<ElementA, cutlass::float_e5m2_t>));
|
||||
constexpr bool is_int4_x_fp8 = is_A4B8 || is_A8B4;
|
||||
|
||||
// In int4 * fp8, ElementScale is a cutlass::Array, need to take out it's real element
|
||||
using ElementScaleMainloop = typename CollectiveMainloop::ElementScale;
|
||||
using ElementScale = typename UnderlyingElement<typename CollectiveMainloop::ElementScale>::type;
|
||||
using StrideS = typename CollectiveMainloop::StrideScale;
|
||||
// In ScaleOnly mode, we have allocated the same size of memory for arguments->Z and arguments->S
|
||||
using ElementZero = std::conditional_t<
|
||||
has_zero,
|
||||
typename CollectiveMainloop::ElementZero,
|
||||
ElementScale
|
||||
>;
|
||||
const int SZ_1st_dim = (wider_operand == Sm90MixedInputWiderOperand::A) ? problem_n : problem_m;
|
||||
const size_t SZ_size = static_cast<size_t>(SZ_1st_dim * scale_k * options_l);
|
||||
auto shape_SZ = cute::make_shape(SZ_1st_dim, scale_k, options_l);
|
||||
ElementScale *ptr_S = static_cast<ElementScale *>(arguments->Scale);
|
||||
ElementZero *ptr_Z = static_cast<ElementZero *>(arguments->Zero);
|
||||
|
||||
// 1. If arguments is initialized in profiler, S and Z needs to be allocated and filled
|
||||
if (arguments->generate_scale_and_zero) {
|
||||
// Need to fix max_dequant_val and min_dequant_val?
|
||||
const float elt_max_f = float(cutlass::platform::numeric_limits<ElementScale>::max());
|
||||
const float max_dequant_val = elt_max_f * 0.25f;
|
||||
const float min_dequant_val = 0.5f;
|
||||
const float scale_max = max_dequant_val / elt_max_f;
|
||||
const float scale_min = min_dequant_val / elt_max_f;
|
||||
uint64_t seed = 2023;
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
ptr_S, SZ_size, seed, ElementScale(scale_max), ElementScale(scale_min));
|
||||
|
||||
// In ScaleOnly mode, set Z as zero for generating dequantized A or B
|
||||
const float zero_max = has_zero ? 2.0f : 0.0f;
|
||||
const float zero_min = has_zero ? -2.0f : 0.0f;
|
||||
cutlass::reference::device::BlockFillRandomUniform(
|
||||
ptr_Z, SZ_size, seed, ElementZero(zero_max), ElementZero(zero_min));
|
||||
} // End of "if (arguments->generate_scale_and_zero)"
|
||||
|
||||
// 2. Generate the dequantized A or B for verification
|
||||
if (arguments->generate_dequantized_AB) {
|
||||
StrideS stride_SZ = cutlass::make_cute_packed_stride(StrideS{}, shape_SZ);
|
||||
auto layout_SZ = cute::make_layout(shape_SZ, stride_SZ);
|
||||
if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) {
|
||||
if constexpr(is_StrideB_Layout) {
|
||||
// The generator only generates row-major A and col-major B at the moment
|
||||
// Need a way to read out the actual layout of B later
|
||||
using ActualLayoutB = cutlass::layout::ColumnMajor;
|
||||
using ActualStrideB = cutlass::detail::TagToStrideB_t<ActualLayoutB>;
|
||||
dequantize_encode_<ElementWide, ElementNarrow, ElementScaleMainloop, ActualStrideB, wider_operand, is_A8B4>(
|
||||
operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ);
|
||||
}
|
||||
else {
|
||||
using ActualStrideB = typename CollectiveMainloop::StrideB;
|
||||
dequantize_encode_<ElementWide, ElementNarrow, ElementScaleMainloop, ActualStrideB, wider_operand, is_A8B4>(
|
||||
operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ);
|
||||
}
|
||||
}
|
||||
else {
|
||||
if constexpr(is_StrideA_Layout) {
|
||||
// The generator only generates row-major A and col-major B at the moment
|
||||
// Need a way to read out the actual layout of A later
|
||||
using ActualLayoutA = cutlass::layout::RowMajor;
|
||||
using ActualStrideA = cutlass::detail::TagToStrideA_t<ActualLayoutA>;
|
||||
dequantize_encode_<ElementWide, ElementNarrow, ElementScaleMainloop, ActualStrideA, wider_operand, is_A4B8>(
|
||||
operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ);
|
||||
}
|
||||
else {
|
||||
using ActualStrideA = typename CollectiveMainloop::StrideA;
|
||||
dequantize_encode_<ElementWide, ElementNarrow, ElementScaleMainloop, ActualStrideA, wider_operand, is_A4B8>(
|
||||
operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ);
|
||||
}
|
||||
} // End of "if constexpr(wider_operand == Sm90MixedInputWiderOperand::A)"
|
||||
arguments->dequantized_AB_ready[0] = true;
|
||||
} // End of "if (arguments->generate_dequantized_AB)"
|
||||
|
||||
// 3. Put arguments in mainloop
|
||||
if constexpr(is_int4_x_fp8) {
|
||||
operator_args.mainloop.ptr_S = static_cast<ElementScaleMainloop const*>(arguments->packed_Scale);
|
||||
}
|
||||
else {
|
||||
operator_args.mainloop.ptr_S = static_cast<ElementScale const*>(arguments->Scale);
|
||||
}
|
||||
operator_args.mainloop.dS = cutlass::make_cute_packed_stride(StrideS{}, shape_SZ);
|
||||
operator_args.mainloop.group_size = options_g;
|
||||
if constexpr(has_zero) {
|
||||
operator_args.mainloop.ptr_Z = static_cast<ElementZero const*>(arguments->Zero);
|
||||
}
|
||||
} // End of "if constexpr(has_scale)"
|
||||
|
||||
// Handle the shuffling
|
||||
using ValueShuffle = std::conditional_t<
|
||||
cutlass::sizeof_bits<ElementNarrow>::value == 4,
|
||||
cute::Layout<cute::Shape<cute::_2,cute::_4>, cute::Stride<cute::_4,cute::_1>>,
|
||||
cute::Layout<cute::Shape<cute::_2,cute::_2>, cute::Stride<cute::_2,cute::_1>>
|
||||
>;
|
||||
constexpr int NumShuffleAtoms = 1;
|
||||
using MmaAtomShape = cute::Layout<cute::Shape<cute::_1,cute::Int<NumShuffleAtoms>>>;
|
||||
using LayoutAtomQuant = decltype(compute_memory_reordering_atom<ElementWide, MmaAtomShape, ValueShuffle>());
|
||||
// The generator only generates row-major A and col-major B at the moment
|
||||
// Need a way to read out the actual layout and stride of A/B later
|
||||
if constexpr(wider_operand == Sm90MixedInputWiderOperand::A && is_StrideB_Layout) {
|
||||
using ActualLayoutB = cutlass::layout::ColumnMajor;
|
||||
using ActualStrideB = cutlass::detail::TagToStrideB_t<ActualLayoutB>;
|
||||
using LayoutB_Reordered = typename CollectiveMainloop::StrideB;
|
||||
handle_shuffle_tensor_<ElementB, ActualStrideB, LayoutB_Reordered, LayoutAtomQuant, wider_operand>(
|
||||
operator_args, arguments, problem_n, problem_k, options_l);
|
||||
}
|
||||
if constexpr(wider_operand == Sm90MixedInputWiderOperand::B && is_StrideA_Layout) {
|
||||
using ActualLayoutA = cutlass::layout::RowMajor;
|
||||
using ActualStrideA = cutlass::detail::TagToStrideA_t<ActualLayoutA>;
|
||||
using LayoutA_Reordered = typename CollectiveMainloop::StrideA;
|
||||
handle_shuffle_tensor_<ElementA, ActualStrideA, LayoutA_Reordered, LayoutAtomQuant, wider_operand>(
|
||||
operator_args, arguments, problem_m, problem_k, options_l);
|
||||
}
|
||||
} // End of "if constexpr(is_mixed_dtype_mainloop_(MainloopPolicy{}))"
|
||||
|
||||
/* Query device SM count and max active clusters to pass onto the kernel as an argument, where needed */
|
||||
operator_args.hw_info.sm_count = arguments->sm_count;
|
||||
if constexpr (Operator::ArchTag::kMinComputeCapability == 90) {
|
||||
dim3 cluster_dims(cute::size<0>(typename Operator::GemmKernel::ClusterShape{}),
|
||||
cute::size<1>(typename Operator::GemmKernel::ClusterShape{}),
|
||||
cute::size<2>(typename Operator::GemmKernel::ClusterShape{}));
|
||||
uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock;
|
||||
void const* kernel_ptr = (void*)(device_kernel<typename Operator::GemmKernel>);
|
||||
operator_args.hw_info.max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters(
|
||||
cluster_dims, threads_per_block, kernel_ptr);
|
||||
operator_args.hw_info.max_active_clusters = max_active_clusters;
|
||||
}
|
||||
if constexpr (!std::is_const_v<decltype(operator_args.scheduler.max_swizzle_size)>) {
|
||||
operator_args.scheduler.max_swizzle_size = arguments->swizzle_size;
|
||||
@ -356,7 +638,10 @@ public:
|
||||
return status;
|
||||
}
|
||||
|
||||
return Operator::can_implement(args);
|
||||
Status can_impl = Operator::can_implement(args);
|
||||
|
||||
//return Operator::can_implement(args);
|
||||
return can_impl;
|
||||
}
|
||||
|
||||
/// Gets the host-side workspace
|
||||
@ -397,7 +682,7 @@ public:
|
||||
cudaStream_t stream = nullptr) const override {
|
||||
|
||||
OperatorArguments args;
|
||||
Status status = update_arguments_(args, static_cast<GemmUniversalArguments const *>(arguments_ptr));
|
||||
Status status = update_arguments_(args, static_cast<GemmUniversalArguments const *>(arguments_ptr), stream);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
330
tools/library/src/grouped_gemm_operation_3x.hpp
Normal file
330
tools/library/src/grouped_gemm_operation_3x.hpp
Normal file
@ -0,0 +1,330 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Defines operations for all grouped GEMM operations in CUTLASS Library.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/detail/collective.hpp"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/util.h"
|
||||
#include "gemm_operation_3x.hpp"
|
||||
#include "library_internal.h"
|
||||
#include <unordered_map>
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::library {
|
||||
|
||||
/// **** CAUTION ****
|
||||
/// Unlike other operations, initialize() must be called when
|
||||
/// certain arguments change. See initialize() for details.
|
||||
template <typename Operator_>
|
||||
class GroupedGemmUniversal3xOperation : public GemmOperation3xBase<Operator_> {
|
||||
public:
|
||||
using Operator = Operator_;
|
||||
using OperatorArguments = typename Operator::Arguments;
|
||||
using ElementA = typename Operator::ElementA;
|
||||
using LayoutA = typename Operator::LayoutA;
|
||||
using ElementB = typename Operator::ElementB;
|
||||
using LayoutB = typename Operator::LayoutB;
|
||||
using ElementC = typename Operator::ElementC;
|
||||
using LayoutC = typename Operator::LayoutC;
|
||||
using ElementD = typename Operator::ElementD;
|
||||
using LayoutD = typename Operator::LayoutD;
|
||||
using ElementAccumulator = typename Operator::ElementAccumulator;
|
||||
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
|
||||
|
||||
using CollectiveMainloop = typename Operator::CollectiveMainloop;
|
||||
using CollectiveEpilogue = typename Operator::CollectiveEpilogue;
|
||||
using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp;
|
||||
|
||||
private:
|
||||
mutable CudaBuffer strideA_device;
|
||||
mutable CudaBuffer strideB_device;
|
||||
mutable CudaBuffer strideC_device;
|
||||
mutable CudaBuffer strideD_device;
|
||||
mutable std::vector<typename Operator::GemmKernel::InternalStrideA> strideA_host;
|
||||
mutable std::vector<typename Operator::GemmKernel::InternalStrideB> strideB_host;
|
||||
mutable std::vector<typename Operator::GemmKernel::InternalStrideC> strideC_host;
|
||||
mutable std::vector<typename Operator::GemmKernel::InternalStrideD> strideD_host;
|
||||
|
||||
public:
|
||||
GroupedGemmUniversal3xOperation(char const* name = "unknown_gemm")
|
||||
: GemmOperation3xBase<Operator_>(name, GemmKind::kGrouped) {
|
||||
this->description_.kind = OperationKind::kGroupedGemm;
|
||||
if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) {
|
||||
dim3 cluster_dims(
|
||||
cute::size<0>(typename Operator::GemmKernel::ClusterShape{}),
|
||||
cute::size<1>(typename Operator::GemmKernel::ClusterShape{}),
|
||||
cute::size<2>(typename Operator::GemmKernel::ClusterShape{}));
|
||||
uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock;
|
||||
void const* kernel_ptr = (void*)(device_kernel<typename Operator::GemmKernel>);
|
||||
max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters(
|
||||
cluster_dims,
|
||||
threads_per_block,
|
||||
kernel_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
~GroupedGemmUniversal3xOperation() override = default;
|
||||
|
||||
private:
|
||||
int max_active_clusters{};
|
||||
|
||||
protected:
|
||||
template <class FusionArgs, class = void> struct UpdateFusionArgs {
|
||||
static Status update_(FusionArgs const& fusion_args, GemmGroupedArguments const& arguments) {
|
||||
// If a custom EVT is instantiated then it is the users's responsibility
|
||||
// to ensure alpha and beta are updated appropriately
|
||||
return Status::kSuccess;
|
||||
}
|
||||
};
|
||||
|
||||
template <class FusionArgs>
|
||||
struct UpdateFusionArgs<FusionArgs, cute::void_t<decltype(FusionArgs{}.alpha)>> {
|
||||
static Status update_(FusionArgs& fusion_args, GemmGroupedArguments const& arguments) {
|
||||
if (arguments.pointer_mode == ScalarPointerMode::kHost) {
|
||||
fusion_args.alpha = *static_cast<ElementCompute const*>(arguments.alpha);
|
||||
fusion_args.beta = *static_cast<ElementCompute const*>(arguments.beta);
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = nullptr;
|
||||
fusion_args.beta_ptr_array = nullptr;
|
||||
// Single alpha and beta for all groups
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
else if (arguments.pointer_mode == ScalarPointerMode::kDevice) {
|
||||
fusion_args.alpha = 0;
|
||||
fusion_args.beta = 0;
|
||||
fusion_args.alpha_ptr = static_cast<ElementCompute const*>(arguments.alpha);
|
||||
fusion_args.beta_ptr = static_cast<ElementCompute const*>(arguments.beta);
|
||||
return Status::kSuccess;
|
||||
}
|
||||
else {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Constructs the arguments structure given the configuration and arguments
|
||||
Status
|
||||
update_arguments_(OperatorArguments& operator_args, GemmGroupedArguments const* arguments) const {
|
||||
|
||||
Status status = UpdateFusionArgs<decltype(operator_args.epilogue.thread)>::update_(
|
||||
operator_args.epilogue.thread,
|
||||
*arguments);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
operator_args.mode = cutlass::gemm::GemmUniversalMode::kGrouped;
|
||||
operator_args.problem_shape = {
|
||||
arguments->problem_count,
|
||||
arguments->problem_sizes_3x,
|
||||
arguments->pointer_mode == ScalarPointerMode::kHost ? arguments->problem_sizes_3x_host
|
||||
: nullptr};
|
||||
operator_args.mainloop.ptr_A =
|
||||
static_cast<const typename Operator::ElementA**>(arguments->ptr_A);
|
||||
operator_args.mainloop.ptr_B =
|
||||
static_cast<const typename Operator::ElementB**>(arguments->ptr_B);
|
||||
operator_args.epilogue.ptr_C =
|
||||
static_cast<const typename Operator::ElementC**>(arguments->ptr_C);
|
||||
operator_args.epilogue.ptr_D = static_cast<typename Operator::ElementD**>(arguments->ptr_D);
|
||||
|
||||
operator_args.mainloop.dA =
|
||||
static_cast<typename Operator::GemmKernel::InternalStrideA*>(strideA_device.data());
|
||||
operator_args.mainloop.dB =
|
||||
static_cast<typename Operator::GemmKernel::InternalStrideB*>(strideB_device.data());
|
||||
operator_args.epilogue.dC =
|
||||
static_cast<typename Operator::GemmKernel::InternalStrideC*>(strideC_device.data());
|
||||
operator_args.epilogue.dD =
|
||||
static_cast<typename Operator::GemmKernel::InternalStrideD*>(strideD_device.data());
|
||||
|
||||
operator_args.hw_info.sm_count = arguments->sm_count;
|
||||
if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) {
|
||||
operator_args.hw_info.max_active_clusters = max_active_clusters;
|
||||
}
|
||||
|
||||
if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) {
|
||||
operator_args.hw_info.cluster_shape = dim3(
|
||||
arguments->cluster_shape.m(),
|
||||
arguments->cluster_shape.n(),
|
||||
arguments->cluster_shape.k());
|
||||
operator_args.hw_info.cluster_shape_fallback = dim3(
|
||||
arguments->cluster_shape_fallback.m(),
|
||||
arguments->cluster_shape_fallback.n(),
|
||||
arguments->cluster_shape_fallback.k());
|
||||
}
|
||||
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
public:
|
||||
/// Returns success if the operation can proceed
|
||||
Status can_implement([[maybe_unused]] void const* configuration_ptr, void const* arguments_ptr)
|
||||
const override {
|
||||
GemmGroupedArguments const* arguments = static_cast<GemmGroupedArguments const*>(arguments_ptr);
|
||||
OperatorArguments args;
|
||||
|
||||
auto status = update_arguments_(args, arguments);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = Operator::can_implement(args);
|
||||
return status;
|
||||
}
|
||||
|
||||
/// Gets the host-side workspace
|
||||
uint64_t get_host_workspace_size(void const* configuration) const override {
|
||||
return sizeof(Operator);
|
||||
}
|
||||
|
||||
/// Gets the device-side workspace
|
||||
uint64_t get_device_workspace_size(void const* configuration_ptr, void const* arguments_ptr)
|
||||
const override {
|
||||
|
||||
OperatorArguments args;
|
||||
auto status = update_arguments_(args, static_cast<GemmGroupedArguments const*>(arguments_ptr));
|
||||
if (status != Status::kSuccess) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint64_t size = Operator::get_workspace_size(args);
|
||||
return size;
|
||||
}
|
||||
|
||||
/// Initializes the workspace
|
||||
/// **** CAUTION ****
|
||||
/// Must be called when lda, ldb, ldc, or ldd change.
|
||||
/// The CUTLASS library stores the operations in a type-
|
||||
/// erased manifest. Therefore, only this class knows
|
||||
/// the type of strideA, strideB, strideC, and strideD.
|
||||
/// Since grouped GEMM needs to allocate storage for
|
||||
/// the strides on device, the concrete type of the stride
|
||||
/// must be known in order to copy in the correct memory
|
||||
/// layout on device.
|
||||
Status initialize(
|
||||
void const* configuration_ptr,
|
||||
void* host_workspace,
|
||||
void* device_workspace,
|
||||
cudaStream_t stream = nullptr) const override {
|
||||
|
||||
auto const& config = *static_cast<GemmGroupedConfiguration const*>(configuration_ptr);
|
||||
|
||||
auto num_groups = config.problem_count;
|
||||
strideA_device =
|
||||
CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideA) * num_groups);
|
||||
strideB_device =
|
||||
CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideB) * num_groups);
|
||||
strideC_device =
|
||||
CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideC) * num_groups);
|
||||
strideD_device =
|
||||
CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideD) * num_groups);
|
||||
|
||||
strideA_host.resize(num_groups);
|
||||
strideB_host.resize(num_groups);
|
||||
strideC_host.resize(num_groups);
|
||||
strideD_host.resize(num_groups);
|
||||
for (int group_idx = 0; group_idx < num_groups; group_idx++) {
|
||||
strideA_host[group_idx] =
|
||||
cute::make_int_tuple_from<typename Operator::GemmKernel::InternalStrideA>(
|
||||
config.lda[group_idx]);
|
||||
strideB_host[group_idx] =
|
||||
cute::make_int_tuple_from<typename Operator::GemmKernel::InternalStrideB>(
|
||||
config.ldb[group_idx]);
|
||||
strideC_host[group_idx] =
|
||||
cute::make_int_tuple_from<typename Operator::GemmKernel::InternalStrideC>(
|
||||
config.ldc[group_idx]);
|
||||
strideD_host[group_idx] =
|
||||
cute::make_int_tuple_from<typename Operator::GemmKernel::InternalStrideD>(
|
||||
config.ldc[group_idx]);
|
||||
}
|
||||
CUDA_CHECK(cudaMemcpy(
|
||||
strideA_device.data(),
|
||||
strideA_host.data(),
|
||||
sizeof(typename Operator::GemmKernel::InternalStrideA) * num_groups,
|
||||
cudaMemcpyHostToDevice));
|
||||
CUDA_CHECK(cudaMemcpy(
|
||||
strideB_device.data(),
|
||||
strideB_host.data(),
|
||||
sizeof(typename Operator::GemmKernel::InternalStrideB) * num_groups,
|
||||
cudaMemcpyHostToDevice));
|
||||
CUDA_CHECK(cudaMemcpy(
|
||||
strideC_device.data(),
|
||||
strideC_host.data(),
|
||||
sizeof(typename Operator::GemmKernel::InternalStrideC) * num_groups,
|
||||
cudaMemcpyHostToDevice));
|
||||
CUDA_CHECK(cudaMemcpy(
|
||||
strideD_device.data(),
|
||||
strideD_host.data(),
|
||||
sizeof(typename Operator::GemmKernel::InternalStrideD) * num_groups,
|
||||
cudaMemcpyHostToDevice));
|
||||
|
||||
Operator* op = new (host_workspace) Operator;
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// **** CAUTION ****
|
||||
/// initialize() must be called if lda, ldb, ldc, or ldd change.
|
||||
Status run(
|
||||
void const* arguments_ptr,
|
||||
void* host_workspace,
|
||||
void* device_workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) const override {
|
||||
|
||||
OperatorArguments operator_args;
|
||||
auto const& args = *static_cast<GemmGroupedArguments const*>(arguments_ptr);
|
||||
|
||||
Status status = update_arguments_(operator_args, &args);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
Operator* op = static_cast<Operator*>(host_workspace);
|
||||
// We need to call initialize() since we have to rebuild TMA desc for every new set of args
|
||||
status = op->run(operator_args, device_workspace, stream, nullptr, args.use_pdl);
|
||||
return status;
|
||||
}
|
||||
};
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::library
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -64,7 +64,6 @@ void initialize_gemm_reference_operations_f8_f6_f32(Manifest &manifest);
|
||||
void initialize_block_scaled_gemm_reference_operations_fp4a_vs16(Manifest &manifest);
|
||||
void initialize_block_scaled_gemm_reference_operations_fp4a_vs32(Manifest &manifest);
|
||||
void initialize_block_scaled_gemm_reference_operations_mixed8bitsa(Manifest &manifest);
|
||||
|
||||
void initialize_gemm_reference_operations_fp8in_fp16out(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_fp8in_bf16out(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_fp8in_fp32out(Manifest &manifest);
|
||||
@ -114,7 +113,6 @@ void initialize_reference_operations(Manifest &manifest) {
|
||||
initialize_block_scaled_gemm_reference_operations_fp4a_vs16(manifest);
|
||||
initialize_block_scaled_gemm_reference_operations_fp4a_vs32(manifest);
|
||||
initialize_block_scaled_gemm_reference_operations_mixed8bitsa(manifest);
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -37,6 +37,7 @@
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/detail/collective.hpp"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/util.h"
|
||||
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" // StructuredSparseCompressor
|
||||
#include "cutlass/transform/device/transform_universal_adapter.hpp" // TransformUniversalAdapter
|
||||
#include "cutlass/util/packed_stride.hpp" // make_cute_packed_stride
|
||||
@ -45,14 +46,6 @@
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Checks the result of a CUDA runtime call; on failure, prints the error and
// returns Status::kInvalid from the enclosing function.
//
// The argument is captured into a local exactly once: the previous version
// evaluated `cuda_error` twice (once in the comparison, once inside
// cudaGetErrorString), which re-executed the CUDA call on the error path when
// invoked as CUDA_CHECK(cudaMemcpy(...)).  The do/while(0) wrapper makes the
// macro expand to a single statement, safe inside unbraced if/else.
#define CUDA_CHECK(cuda_error)                                                                    \
  do {                                                                                            \
    cudaError_t cuda_check_result_ = (cuda_error);                                                \
    if (cuda_check_result_ != cudaSuccess) {                                                      \
      printf("cudaError %s in %s:%d\n", cudaGetErrorString(cuda_check_result_), __func__, __LINE__ ); \
      return Status::kInvalid;                                                                    \
    }                                                                                             \
  } while (0)
|
||||
|
||||
namespace cutlass::library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -330,18 +330,18 @@ static struct {
|
||||
char const *text;
|
||||
char const *pretty;
|
||||
OperationKind enumerant;
|
||||
}
|
||||
OperationKind_enumerants[] = {
|
||||
{"eq_gemm", "EqGemm", OperationKind::kEqGemm},
|
||||
} OperationKind_enumerants[] = {
|
||||
{"eq_gemm", "EqGemm", OperationKind::kEqGemm},
|
||||
{"gemm", "Gemm", OperationKind::kGemm},
|
||||
{"block_scaled_gemm", "blockScaledGemm", OperationKind::kBlockScaledGemm},
|
||||
{"rank_k", "RankK", OperationKind::kRankK},
|
||||
{"rank_2k", "Rank2K", OperationKind::kRank2K},
|
||||
{"trmm", "Trmm", OperationKind::kTrmm},
|
||||
{"symm", "Symm", OperationKind::kSymm},
|
||||
{"conv2d", "Conv2d", OperationKind::kConv2d},
|
||||
{"conv3d", "Conv3d", OperationKind::kConv3d},
|
||||
{"conv2d", "Conv2d", OperationKind::kConv2d},
|
||||
{"conv3d", "Conv3d", OperationKind::kConv3d},
|
||||
{"spgemm", "SparseGemm", OperationKind::kSparseGemm},
|
||||
{"grouped_gemm", "GroupedGemm", OperationKind::kGroupedGemm},
|
||||
};
|
||||
|
||||
/// Converts a Status enumerant to a string
|
||||
@ -504,7 +504,6 @@ NumericTypeID_enumerants[] = {
|
||||
{"fe2m1", "FE2M1", NumericTypeID::kFE2M1},
|
||||
{"fue8m0", "FUE8M0", NumericTypeID::kFUE8M0},
|
||||
{"fue4m3", "FUE4M3", NumericTypeID::kFUE4M3},
|
||||
|
||||
{"f16", "F16", NumericTypeID::kF16},
|
||||
{"bf16", "BF16", NumericTypeID::kBF16},
|
||||
{"f32", "F32", NumericTypeID::kF32},
|
||||
@ -577,7 +576,6 @@ int sizeof_bits(NumericTypeID type) {
|
||||
case NumericTypeID::kFE2M1: return 4;
|
||||
case NumericTypeID::kFUE8M0: return 8;
|
||||
case NumericTypeID::kFUE4M3: return 8;
|
||||
|
||||
case NumericTypeID::kF16: return 16;
|
||||
case NumericTypeID::kBF16: return 16;
|
||||
case NumericTypeID::kTF32: return 32;
|
||||
@ -666,7 +664,6 @@ bool is_signed_type(NumericTypeID type) {
|
||||
case NumericTypeID::kFE2M1: return true;
|
||||
case NumericTypeID::kFUE8M0: return false;
|
||||
case NumericTypeID::kFUE4M3: return false;
|
||||
|
||||
case NumericTypeID::kF16: return true;
|
||||
case NumericTypeID::kBF16: return true;
|
||||
case NumericTypeID::kTF32: return true;
|
||||
@ -707,7 +704,6 @@ bool is_float_type(NumericTypeID type) {
|
||||
case NumericTypeID::kFE2M1: return true;
|
||||
case NumericTypeID::kFUE8M0: return true;
|
||||
case NumericTypeID::kFUE4M3: return true;
|
||||
|
||||
case NumericTypeID::kF16: return true;
|
||||
case NumericTypeID::kBF16: return true;
|
||||
case NumericTypeID::kTF32: return true;
|
||||
@ -1256,7 +1252,6 @@ bool lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type, std::string c
|
||||
*reinterpret_cast<float_e5m2_t *>(bytes.data()) = static_cast<float_e5m2_t>(tmp);
|
||||
}
|
||||
break;
|
||||
|
||||
case NumericTypeID::kFE2M3:
|
||||
{
|
||||
float tmp;
|
||||
@ -1292,7 +1287,6 @@ bool lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type, std::string c
|
||||
*reinterpret_cast<float_ue4m3_t *>(bytes.data()) = static_cast<float_ue4m3_t>(tmp);
|
||||
}
|
||||
break;
|
||||
|
||||
case NumericTypeID::kF16:
|
||||
{
|
||||
float tmp;
|
||||
@ -1473,7 +1467,6 @@ std::string lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type) {
|
||||
ss << tmp;
|
||||
}
|
||||
break;
|
||||
|
||||
case NumericTypeID::kF16:
|
||||
{
|
||||
float tmp = *reinterpret_cast<half_t *>(bytes.data());
|
||||
@ -1652,7 +1645,6 @@ bool cast_from_int64(std::vector<uint8_t> &bytes, NumericTypeID type, int64_t sr
|
||||
*reinterpret_cast<float_ue4m3_t *>(bytes.data()) = static_cast<float_ue4m3_t>(float(src));
|
||||
}
|
||||
break;
|
||||
|
||||
case NumericTypeID::kF16:
|
||||
{
|
||||
*reinterpret_cast<half_t *>(bytes.data()) = static_cast<half_t>(float(src));
|
||||
@ -1789,7 +1781,6 @@ bool cast_from_uint64(std::vector<uint8_t> &bytes, NumericTypeID type, uint64_t
|
||||
*reinterpret_cast<float_ue4m3_t *>(bytes.data()) = static_cast<float_ue4m3_t>(float(src));
|
||||
}
|
||||
break;
|
||||
|
||||
case NumericTypeID::kF16:
|
||||
{
|
||||
*reinterpret_cast<half_t *>(bytes.data()) = static_cast<half_t>(float(src));
|
||||
@ -1927,7 +1918,6 @@ bool cast_from_double(std::vector<uint8_t> &bytes, NumericTypeID type, double sr
|
||||
*reinterpret_cast<float_ue4m3_t *>(bytes.data()) = static_cast<float_ue4m3_t>(float(src));
|
||||
}
|
||||
break;
|
||||
|
||||
case NumericTypeID::kF16:
|
||||
{
|
||||
*reinterpret_cast<half_t *>(bytes.data()) = static_cast<half_t>(float(src));
|
||||
|
||||
@ -46,7 +46,8 @@ set(CUTLASS_TOOLS_PROFILER_SOURCES
|
||||
src/problem_space.cpp
|
||||
src/operation_profiler.cu
|
||||
src/gemm_operation_profiler.cu
|
||||
src/block_scaled_gemm_operation_profiler.cu
|
||||
src/grouped_gemm_operation_profiler.cu
|
||||
src/block_scaled_gemm_operation_profiler.cu
|
||||
src/rank_k_operation_profiler.cu
|
||||
src/rank_2k_operation_profiler.cu
|
||||
src/trmm_operation_profiler.cu
|
||||
@ -112,6 +113,7 @@ set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_RANK_K --operation=RankK --pro
|
||||
set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_RANK_2K --operation=Rank2K --providers=cutlass --verification-providers=cublas --junit-output=test_cutlass_profiler_rank_2k --print-kernel-before-running=true)
|
||||
set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_TRMM --operation=Trmm --providers=cutlass --verification-providers=device,host --junit-output=test_cutlass_profiler_trmm --print-kernel-before-running=true)
|
||||
set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_SYMM --operation=Symm --providers=cutlass --verification-providers=cublas,host --junit-output=test_cutlass_profiler_symm --print-kernel-before-running=true)
|
||||
set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GROUPED_GEMM --operation=GroupedGemm --providers=cutlass --verification-providers=device --junit-output=test_cutlass_profiler_grouped_gemm --print-kernel-before-running=true)
|
||||
|
||||
cutlass_add_executable_tests(
|
||||
test_profiler cutlass_profiler
|
||||
@ -125,6 +127,7 @@ cutlass_add_executable_tests(
|
||||
RANK_2K
|
||||
TRMM
|
||||
SYMM
|
||||
GROUPED_GEMM
|
||||
TEST_COMMAND_OPTIONS_PREFIX
|
||||
CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_
|
||||
DISABLE_EXECUTABLE_INSTALL_RULE
|
||||
|
||||
@ -108,6 +108,8 @@ public:
|
||||
|
||||
bool use_pdl{false};
|
||||
|
||||
bool enable_sm90_mixed_dtype_shuffle_test{false};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@ -160,6 +162,13 @@ public:
|
||||
/// Buffer used for the cutlass reduction operations' host workspace
|
||||
std::vector<uint8_t> reduction_host_workspace;
|
||||
|
||||
/// For mixed input dtype kernels
|
||||
DeviceAllocation *Scale{nullptr}; // Scale tensor
|
||||
DeviceAllocation *Zero{nullptr}; // Zero tensor
|
||||
DeviceAllocation *dequantized_AB{nullptr}; // Dequantized A or B tensor for verification
|
||||
DeviceAllocation *encoded_AB{nullptr}; // Encoded A or B in int4 x fp8 or shuffle
|
||||
DeviceAllocation *packed_Scale{nullptr}; // Packed scale for int4 * fp8
|
||||
|
||||
cudaStream_t stream;
|
||||
};
|
||||
|
||||
|
||||
@ -0,0 +1,263 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/* \file
|
||||
\brief GroupedGemm Profiler
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
// CUTLASS Library includes
|
||||
#include "cutlass/library/library.h"
|
||||
|
||||
// Profiler includes
|
||||
#include "device_context.h"
|
||||
#include "operation_profiler.h"
|
||||
#include "options.h"
|
||||
#include "performance_result.h"
|
||||
#include "problem_space.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace profiler {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Abstract base class for each math function
|
||||
class GroupedGemmOperationProfiler : public OperationProfiler {
|
||||
public:
|
||||
/// Problem structure obtained from problem space
|
||||
struct GroupedGemmProblem {
|
||||
|
||||
cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGrouped};
|
||||
|
||||
std::vector<gemm::GemmCoord> problem_sizes;
|
||||
std::vector<cute::Shape<int, int, int>> problem_sizes_3x;
|
||||
|
||||
int cluster_m{1};
|
||||
int cluster_n{1};
|
||||
int cluster_k{1};
|
||||
int cluster_m_fallback{1};
|
||||
int cluster_n_fallback{1};
|
||||
int cluster_k_fallback{1};
|
||||
|
||||
std::vector<int64_t> lda{0};
|
||||
std::vector<int64_t> ldb{0};
|
||||
std::vector<int64_t> ldc{0};
|
||||
|
||||
std::vector<uint8_t> alpha;
|
||||
std::vector<uint8_t> beta;
|
||||
|
||||
/// Parses the problem
|
||||
Status parse(
|
||||
library::GemmDescription const& operation_desc,
|
||||
ProblemSpace const& problem_space,
|
||||
ProblemSpace::Problem const& problem);
|
||||
|
||||
int64_t m(int group_idx) const { return problem_sizes[group_idx].m(); };
|
||||
int64_t n(int group_idx) const { return problem_sizes[group_idx].n(); };
|
||||
int64_t k(int group_idx) const { return problem_sizes[group_idx].k(); };
|
||||
|
||||
/// Total number of bytes loaded
|
||||
int64_t bytes(library::GemmDescription const& operation_desc) const;
|
||||
|
||||
/// Total number of flops computed
|
||||
int64_t flops(library::GemmDescription const& operation_desc) const;
|
||||
|
||||
/// Initializes a performance result
|
||||
void initialize_result(
|
||||
PerformanceResult& result,
|
||||
library::GemmDescription const& operation_desc,
|
||||
ProblemSpace const& problem_space);
|
||||
};
|
||||
|
||||
// workspace contains the allocated blocks, arguments just contain the raw
|
||||
// pointers
|
||||
struct GroupedGemmWorkspace {
|
||||
|
||||
std::vector<DeviceAllocation*> A_ptr_array_device;
|
||||
std::vector<DeviceAllocation*> B_ptr_array_device;
|
||||
std::vector<DeviceAllocation*> C_ptr_array_device;
|
||||
std::vector<DeviceAllocation*> D_ptr_array_device;
|
||||
std::vector<DeviceAllocation*> reference_ptr_array_host;
|
||||
std::vector<DeviceAllocation*> A_ptr_array_host;
|
||||
std::vector<DeviceAllocation*> B_ptr_array_host;
|
||||
std::vector<DeviceAllocation*> C_ptr_array_host;
|
||||
std::vector<DeviceAllocation*> D_ptr_array_host;
|
||||
|
||||
/// Number of copies of the problem workspace which are visited sequentially during
|
||||
/// profiling to avoid camping in the last level cache.
|
||||
/// *NOT* the number of groups in the grouped GEMM
|
||||
int problem_count{1};
|
||||
|
||||
DeviceAllocation* problem_sizes_array_device{nullptr};
|
||||
DeviceAllocation* problem_sizes_3x_array_device{nullptr};
|
||||
DeviceAllocation* lda_array_device{nullptr};
|
||||
DeviceAllocation* ldb_array_device{nullptr};
|
||||
DeviceAllocation* ldc_array_device{nullptr};
|
||||
DeviceAllocation* ldd_array_device{nullptr};
|
||||
|
||||
library::GemmGroupedConfiguration configuration;
|
||||
library::GemmGroupedArguments arguments;
|
||||
|
||||
std::vector<uint8_t> host_workspace;
|
||||
DeviceAllocation device_workspace;
|
||||
};
|
||||
|
||||
private:
|
||||
void init_arguments(Options const& options) {
|
||||
gemm_workspace_.arguments.ptr_A = gemm_workspace_.A_ptr_array_device[0]->data();
|
||||
gemm_workspace_.arguments.ptr_B = gemm_workspace_.B_ptr_array_device[0]->data();
|
||||
gemm_workspace_.arguments.ptr_C = gemm_workspace_.C_ptr_array_device[0]->data();
|
||||
gemm_workspace_.arguments.ptr_D = gemm_workspace_.D_ptr_array_device[0]->data();
|
||||
gemm_workspace_.arguments.alpha = problem_.alpha.data();
|
||||
gemm_workspace_.arguments.beta = problem_.beta.data();
|
||||
gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost;
|
||||
gemm_workspace_.arguments.lda = static_cast<int64_t*>(gemm_workspace_.lda_array_device->data());
|
||||
gemm_workspace_.arguments.ldb = static_cast<int64_t*>(gemm_workspace_.ldb_array_device->data());
|
||||
gemm_workspace_.arguments.ldc = static_cast<int64_t*>(gemm_workspace_.ldc_array_device->data());
|
||||
gemm_workspace_.arguments.ldd = static_cast<int64_t*>(gemm_workspace_.ldc_array_device->data());
|
||||
gemm_workspace_.arguments.problem_sizes =
|
||||
static_cast<gemm::GemmCoord*>(gemm_workspace_.problem_sizes_array_device->data());
|
||||
gemm_workspace_.arguments.problem_sizes_3x = static_cast<cute::Shape<int, int, int>*>(
|
||||
gemm_workspace_.problem_sizes_3x_array_device->data());
|
||||
gemm_workspace_.arguments.problem_sizes_3x_host = problem_.problem_sizes_3x.data();
|
||||
gemm_workspace_.arguments.problem_count = problem_.problem_sizes.size();
|
||||
gemm_workspace_.arguments.cluster_shape = {int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)};
|
||||
gemm_workspace_.arguments.cluster_shape_fallback = {int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)};
|
||||
|
||||
/* Query device SM count to pass onto the kernel as an argument, where needed */
|
||||
gemm_workspace_.arguments.sm_count = options.device.properties[0].multiProcessorCount;
|
||||
}
|
||||
|
||||
protected:
|
||||
/// GEMM problem obtained from problem space
|
||||
GroupedGemmProblem problem_;
|
||||
|
||||
/// Device memory allocations
|
||||
GroupedGemmWorkspace gemm_workspace_;
|
||||
|
||||
public:
|
||||
GroupedGemmOperationProfiler(Options const& options);
|
||||
|
||||
virtual ~GroupedGemmOperationProfiler();
|
||||
|
||||
GroupedGemmProblem const& problem() const { return problem_; }
|
||||
|
||||
/// Prints usage statement for the math function
|
||||
virtual void print_usage(std::ostream& out) const;
|
||||
|
||||
/// Prints examples
|
||||
virtual void print_examples(std::ostream& out) const;
|
||||
|
||||
/// Extracts the problem dimensions
|
||||
virtual Status initialize_configuration(
|
||||
Options const& options,
|
||||
PerformanceReport& report,
|
||||
DeviceContext& device_context,
|
||||
library::Operation const* operation,
|
||||
ProblemSpace const& problem_space,
|
||||
ProblemSpace::Problem const& problem);
|
||||
|
||||
/// Initializes workspace
|
||||
virtual Status initialize_workspace(
|
||||
Options const& options,
|
||||
PerformanceReport& report,
|
||||
DeviceContext& device_context,
|
||||
library::Operation const* operation,
|
||||
ProblemSpace const& problem_space,
|
||||
ProblemSpace::Problem const& problem);
|
||||
|
||||
/// Verifies CUTLASS against references
|
||||
virtual bool verify_cutlass(
|
||||
Options const& options,
|
||||
PerformanceReport& report,
|
||||
DeviceContext& device_context,
|
||||
library::Operation const* operation,
|
||||
ProblemSpace const& problem_space,
|
||||
ProblemSpace::Problem const& problem);
|
||||
|
||||
/// Measures performance results
|
||||
virtual bool profile(
|
||||
Options const& options,
|
||||
PerformanceReport& report,
|
||||
DeviceContext& device_context,
|
||||
library::Operation const* operation,
|
||||
ProblemSpace const& problem_space,
|
||||
ProblemSpace::Problem const& problem);
|
||||
|
||||
protected:
|
||||
/// Initializes the performance result
|
||||
void initialize_result_(
|
||||
PerformanceResult& result,
|
||||
Options const& options,
|
||||
library::GemmDescription const& operation_desc,
|
||||
ProblemSpace const& problem_space);
|
||||
|
||||
/// Verifies CUTLASS against host and device references
|
||||
bool verify_with_reference_(
|
||||
Options const& options,
|
||||
PerformanceReport& report,
|
||||
DeviceContext& device_context,
|
||||
library::Operation const* operation,
|
||||
ProblemSpace const& problem_space,
|
||||
ProblemSpace::Problem const& problem,
|
||||
cutlass::library::NumericTypeID element_A,
|
||||
cutlass::library::NumericTypeID element_B);
|
||||
|
||||
/// Method to profile a CUTLASS Operation
|
||||
Status profile_cutlass_(
|
||||
PerformanceResult& result,
|
||||
Options const& options,
|
||||
library::Operation const* operation,
|
||||
void* arguments,
|
||||
void* host_workspace,
|
||||
void* device_workspace) override;
|
||||
|
||||
/// Initialize reduction problem dimensions and library::Operation
|
||||
bool initialize_reduction_configuration_(
|
||||
library::Operation const* operation,
|
||||
ProblemSpace::Problem const& problem);
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -241,17 +241,30 @@ protected:
|
||||
|
||||
/// Profiles the GPU kernel launched in `func` running simultaneously on all
|
||||
/// requested devices.
|
||||
Status profile_kernel_w_cuda_graphs_(
|
||||
PerformanceResult& result,
|
||||
Options const& options,
|
||||
std::function<Status(int, cudaStream_t, int)> const& func,
|
||||
std::vector<cudaStream_t> const& streams);
|
||||
|
||||
Status profile_kernel_(
|
||||
PerformanceResult &result,
|
||||
Options const &options,
|
||||
const std::function<Status(int, cudaStream_t, int)> &func,
|
||||
const std::vector<cudaStream_t> &streams);
|
||||
PerformanceResult& result,
|
||||
Options const& options,
|
||||
std::function<Status(int, cudaStream_t, int)> const& func,
|
||||
std::vector<cudaStream_t> const& streams);
|
||||
|
||||
/// Profiles the GPU kernel launched in `func` on the `stream`
|
||||
Status profile_kernel_(
|
||||
PerformanceResult &result,
|
||||
Options const &options,
|
||||
const std::function<Status(cudaStream_t, int)> &func,
|
||||
PerformanceResult& result,
|
||||
Options const& options,
|
||||
std::function<Status(cudaStream_t, int)> const& func,
|
||||
cudaStream_t stream = nullptr);
|
||||
|
||||
/// Profiles the GPU kernel launched in `func` on the `stream`
|
||||
Status profile_kernel_no_cuda_graphs_(
|
||||
PerformanceResult& result,
|
||||
Options const& options,
|
||||
std::function<Status(cudaStream_t, int)> const& func,
|
||||
cudaStream_t stream = nullptr);
|
||||
|
||||
private:
|
||||
|
||||
@ -208,6 +208,8 @@ public:
|
||||
/// Minimum number of iterations to profile
|
||||
int min_iterations{10};
|
||||
|
||||
bool use_cuda_graphs{false};
|
||||
|
||||
/// Number of ms to sleep between profiling periods (ms)
|
||||
int sleep_duration{50};
|
||||
|
||||
|
||||
@ -988,6 +988,12 @@ bool arg_as_scalar(
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem);
|
||||
|
||||
bool arg_as_string(
|
||||
std::string& arg,
|
||||
char const* name,
|
||||
ProblemSpace const& problem_space,
|
||||
ProblemSpace::Problem const& problem);
|
||||
|
||||
/// Returns true if a tensor description satisfies a `tensor` value
|
||||
bool tensor_description_satisfies(
|
||||
library::TensorDescription const &tensor_desc,
|
||||
|
||||
@ -75,7 +75,6 @@ BlockScaledGemmOperationProfiler::BlockScaledGemmOperationProfiler(Options const
|
||||
{ArgumentTypeID::kTensor, {"D"}, "Tensor storing the D output"},
|
||||
{ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"},
|
||||
{ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"},
|
||||
// TODO: Bring these back once SM100 future audits are complete
|
||||
{ArgumentTypeID::kEnumerated, {"split_k_mode", "split-k-mode"}, "Variant of split K mode(serial, parallel)"},
|
||||
{ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"},
|
||||
{ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of GEMMs computed in one batch"},
|
||||
@ -113,14 +112,11 @@ void BlockScaledGemmOperationProfiler::print_examples(std::ostream &out) const {
|
||||
<< "Schmoo over problem size and beta:\n"
|
||||
<< " $ cutlass_profiler --operation=block_scaled_gemm --m=1024:4096:256 --n=1024:4096:256 --k=128:8192:128 --beta=0,1,2.5\n\n"
|
||||
|
||||
// TODO: Bring these back once SM100 future audits are complete
|
||||
#if 0
|
||||
<< "Run when A is f16 with column-major and B is any datatype with row-major (For column major, use column, col, or n. For row major use, row or t):\n"
|
||||
<< "For column major, use column, col, or n. For row major use, row or t:\n"
|
||||
<< " $ cutlass_profiler --operation=Gemm --A=f16:column --B=*:row\n\n"
|
||||
|
||||
<< "Profile a particular problem size with split K and parallel reduction:\n"
|
||||
<< " $ cutlass_profiler --operation=Gemm --split_k_mode=parallel --split_k_slices=2 --m=1024 --n=1024 --k=128\n\n"
|
||||
#endif
|
||||
|
||||
<< "Using various input value distribution:\n"
|
||||
<< " $ cutlass_profiler --operation=Gemm --dist=uniform,min:0,max:3\n"
|
||||
@ -225,7 +221,6 @@ Status BlockScaledGemmOperationProfiler::GemmProblem::parse(
|
||||
this->split_k_slices = 1;
|
||||
}
|
||||
|
||||
// TODO: Bring these back once SM100 future audits are complete
|
||||
if (this->split_k_mode != library::SplitKMode::kSerial) {
|
||||
std::cout<<"SplitK/StreamK feature is not supported yet!";
|
||||
return Status::kErrorInvalidProblem;
|
||||
@ -403,7 +398,6 @@ void BlockScaledGemmOperationProfiler::GemmProblem::initialize_result(
|
||||
set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback);
|
||||
|
||||
|
||||
// TODO: Bring these back once SM100 future audits are complete
|
||||
set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode));
|
||||
set_argument(result, "split_k_slices", problem_space, split_k_slices);
|
||||
set_argument(result, "batch_count", problem_space, batch_count);
|
||||
@ -536,8 +530,6 @@ bool BlockScaledGemmOperationProfiler::initialize_reduction_configuration_(
|
||||
library::Operation const *operation,
|
||||
ProblemSpace::Problem const &problem) {
|
||||
|
||||
// TODO: Bring these back once SM100 future audits are complete
|
||||
#if 1
|
||||
library::BlockScaledGemmDescription const &gemm_desc =
|
||||
static_cast<library::BlockScaledGemmDescription const&>(operation->description());
|
||||
|
||||
@ -577,8 +569,6 @@ bool BlockScaledGemmOperationProfiler::initialize_reduction_configuration_(
|
||||
|
||||
// reduction operation found and initialized
|
||||
return true;
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Initializes workspace
|
||||
|
||||
@ -545,6 +545,7 @@ bool cublasLtGemmExDispatcher::get_cublaslt_algo(cublasLtHandle_t handle,
|
||||
cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, requestedAlgoCount, heuristicResult, &returnedResults);
|
||||
|
||||
if (returnedResults == 0) {
|
||||
cudaFree(workspaceHeuristic);
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -589,6 +590,7 @@ bool cublasLtGemmExDispatcher::get_cublaslt_algo(cublasLtHandle_t handle,
|
||||
// Handle errors
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
std::cerr << "cublasLtMatmul AutoTuning failed with status: " << cublasLtGetStatusName(status) << std::endl;
|
||||
cudaFree(workspaceHeuristic);
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -653,6 +655,7 @@ bool cublasLtGemmExDispatcher::get_cublaslt_algo(cublasLtHandle_t handle,
|
||||
throw std::bad_alloc();
|
||||
}
|
||||
|
||||
cudaFree(workspaceHeuristic);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@ -36,16 +36,17 @@
|
||||
#include <stdexcept>
|
||||
|
||||
// Profiler includes
|
||||
#include "cutlass/profiler/cutlass_profiler.h"
|
||||
#include "cutlass/profiler/gemm_operation_profiler.h"
|
||||
#include "cutlass/profiler/block_scaled_gemm_operation_profiler.h"
|
||||
#include "cutlass/profiler/rank_k_operation_profiler.h"
|
||||
#include "cutlass/profiler/rank_2k_operation_profiler.h"
|
||||
#include "cutlass/profiler/trmm_operation_profiler.h"
|
||||
#include "cutlass/profiler/symm_operation_profiler.h"
|
||||
#include "cutlass/profiler/block_scaled_gemm_operation_profiler.h"
|
||||
#include "cutlass/profiler/conv2d_operation_profiler.h"
|
||||
#include "cutlass/profiler/conv3d_operation_profiler.h"
|
||||
#include "cutlass/profiler/cutlass_profiler.h"
|
||||
#include "cutlass/profiler/gemm_operation_profiler.h"
|
||||
#include "cutlass/profiler/grouped_gemm_operation_profiler.h"
|
||||
#include "cutlass/profiler/rank_2k_operation_profiler.h"
|
||||
#include "cutlass/profiler/rank_k_operation_profiler.h"
|
||||
#include "cutlass/profiler/sparse_gemm_operation_profiler.h"
|
||||
#include "cutlass/profiler/symm_operation_profiler.h"
|
||||
#include "cutlass/profiler/trmm_operation_profiler.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -76,6 +77,8 @@ CutlassProfiler::CutlassProfiler(
|
||||
operation_profilers_.emplace_back(new TrmmOperationProfiler(options));
|
||||
|
||||
operation_profilers_.emplace_back(new SymmOperationProfiler(options));
|
||||
|
||||
operation_profilers_.emplace_back(new GroupedGemmOperationProfiler(options));
|
||||
}
|
||||
|
||||
CutlassProfiler::~CutlassProfiler() {
|
||||
@ -201,6 +204,7 @@ void CutlassProfiler::print_usage_(std::ostream &out) {
|
||||
<< " $ cutlass_profiler --operation=Conv3d --help\n\n"
|
||||
<< " $ cutlass_profiler --operation=Conv2d --help\n\n"
|
||||
<< " $ cutlass_profiler --operation=SparseGemm --help\n\n"
|
||||
<< " $ cutlass_profiler --operation=GroupedGemm --help\n\n"
|
||||
;
|
||||
}
|
||||
|
||||
|
||||
@ -616,7 +616,6 @@ void DeviceAllocation::initialize_random_device(int seed, Distribution dist) {
|
||||
dist
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kFUE4M3:
|
||||
cutlass::reference::device::BlockFillRandom<cutlass::float_ue4m3_t>(
|
||||
reinterpret_cast<cutlass::float_ue4m3_t *>(pointer_),
|
||||
@ -657,7 +656,6 @@ void DeviceAllocation::initialize_random_device(int seed, Distribution dist) {
|
||||
dist
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kF64:
|
||||
cutlass::reference::device::BlockFillRandom<double>(
|
||||
reinterpret_cast<double *>(pointer_),
|
||||
@ -823,7 +821,6 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) {
|
||||
);
|
||||
break;
|
||||
|
||||
|
||||
case library::NumericTypeID::kFE2M3:
|
||||
cutlass::reference::host::BlockFillRandom<cutlass::float_e2m3_t>(
|
||||
reinterpret_cast<cutlass::float_e2m3_t *>(host_data.data()),
|
||||
@ -856,7 +853,6 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) {
|
||||
dist
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kF16:
|
||||
cutlass::reference::host::BlockFillRandom<cutlass::half_t>(
|
||||
reinterpret_cast<cutlass::half_t *>(host_data.data()),
|
||||
@ -1086,7 +1082,6 @@ void DeviceAllocation::initialize_sequential_device(Distribution dist) {
|
||||
);
|
||||
break;
|
||||
|
||||
|
||||
case library::NumericTypeID::kFE2M3:
|
||||
cutlass::reference::device::BlockFillSequential<cutlass::float_e2m3_t>(
|
||||
reinterpret_cast<cutlass::float_e2m3_t *>(pointer_),
|
||||
@ -1119,7 +1114,6 @@ void DeviceAllocation::initialize_sequential_device(Distribution dist) {
|
||||
static_cast<cutlass::float_ue8m0_t>(dist.sequential.start)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kF16:
|
||||
cutlass::reference::device::BlockFillSequential<cutlass::half_t>(
|
||||
reinterpret_cast<cutlass::half_t *>(pointer_),
|
||||
@ -1360,7 +1354,6 @@ void DeviceAllocation::initialize_sequential_host(Distribution dist) {
|
||||
);
|
||||
break;
|
||||
|
||||
|
||||
case library::NumericTypeID::kFE2M3:
|
||||
cutlass::reference::host::BlockFillSequential<cutlass::float_e2m3_t>(
|
||||
reinterpret_cast<cutlass::float_e2m3_t *>(host_data.data()),
|
||||
@ -1393,7 +1386,6 @@ void DeviceAllocation::initialize_sequential_host(Distribution dist) {
|
||||
static_cast<cutlass::float_ue8m0_t>(dist.sequential.start)
|
||||
);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kF16:
|
||||
cutlass::reference::host::BlockFillSequential<cutlass::half_t>(
|
||||
reinterpret_cast<cutlass::half_t *>(host_data.data()),
|
||||
@ -1690,7 +1682,6 @@ bool DeviceAllocation::block_compare_equal(
|
||||
reinterpret_cast<float_e5m2_t const *>(ptr_A),
|
||||
reinterpret_cast<float_e5m2_t const *>(ptr_B),
|
||||
capacity);
|
||||
|
||||
case library::NumericTypeID::kFUE4M3:
|
||||
return reference::device::BlockCompareEqual<float_ue4m3_t>(
|
||||
reinterpret_cast<float_ue4m3_t const *>(ptr_A),
|
||||
@ -1717,7 +1708,6 @@ bool DeviceAllocation::block_compare_equal(
|
||||
reinterpret_cast<float_e2m1_t const *>(ptr_A),
|
||||
reinterpret_cast<float_e2m1_t const *>(ptr_B),
|
||||
capacity);
|
||||
|
||||
case library::NumericTypeID::kF16:
|
||||
return reference::device::BlockCompareEqual<half_t>(
|
||||
reinterpret_cast<half_t const *>(ptr_A),
|
||||
@ -1886,7 +1876,6 @@ bool DeviceAllocation::block_compare_relatively_equal(
|
||||
capacity,
|
||||
static_cast<float_e5m2_t>(epsilon),
|
||||
static_cast<float_e5m2_t>(nonzero_floor));
|
||||
|
||||
case library::NumericTypeID::kFUE4M3:
|
||||
return reference::device::BlockCompareRelativelyEqual<float_ue4m3_t>(
|
||||
reinterpret_cast<float_ue4m3_t const *>(ptr_A),
|
||||
@ -1925,7 +1914,6 @@ bool DeviceAllocation::block_compare_relatively_equal(
|
||||
capacity,
|
||||
static_cast<float_e2m1_t>(epsilon),
|
||||
static_cast<float_e2m1_t>(nonzero_floor));
|
||||
|
||||
case library::NumericTypeID::kF16:
|
||||
return reference::device::BlockCompareRelativelyEqual<half_t>(
|
||||
reinterpret_cast<half_t const *>(ptr_A),
|
||||
@ -2273,7 +2261,6 @@ void DeviceAllocation::write_tensor_csv(
|
||||
write_tensor_csv_static_type<float_ue4m3_t>(out, *this);
|
||||
break;
|
||||
|
||||
|
||||
case library::NumericTypeID::kFE2M3:
|
||||
write_tensor_csv_static_type<float_e2m3_t>(out, *this);
|
||||
break;
|
||||
@ -2288,7 +2275,6 @@ void DeviceAllocation::write_tensor_csv(
|
||||
case library::NumericTypeID::kFUE8M0:
|
||||
write_tensor_csv_static_type<float_ue8m0_t>(out, *this);
|
||||
break;
|
||||
|
||||
case library::NumericTypeID::kF16:
|
||||
write_tensor_csv_static_type<half_t>(out, *this);
|
||||
break;
|
||||
@ -2475,7 +2461,6 @@ void DeviceAllocation::fill_device(double val = 0.0) {
|
||||
case library::NumericTypeID::kFE2M1:
|
||||
tensor_fill<float_e2m1_t>(*this, static_cast<float_e2m1_t>(val));
|
||||
break;
|
||||
|
||||
|
||||
case library::NumericTypeID::kF16:
|
||||
tensor_fill<half_t>(*this, static_cast<half_t>(val));
|
||||
@ -2611,7 +2596,6 @@ void DeviceAllocation::fill_host(double val = 0.0) {
|
||||
static_cast<float_e2m1_t>(val)
|
||||
);
|
||||
break;
|
||||
|
||||
|
||||
case library::NumericTypeID::kFE4M3:
|
||||
cutlass::reference::host::BlockFill<float_e4m3_t>(
|
||||
|
||||
@ -75,6 +75,35 @@ DeviceAllocation *DeviceContext::allocate_tensor(
|
||||
return allocation;
|
||||
}
|
||||
|
||||
static void initialize_allocation_with_data_distribution(
|
||||
Options const &options,
|
||||
int seed_shift,
|
||||
DeviceAllocation *allocation,
|
||||
Distribution &data_distribution) {
|
||||
if (options.initialization.provider == library::Provider::kReferenceDevice) {
|
||||
if (data_distribution.kind == Distribution::Sequential) {
|
||||
allocation->initialize_sequential_device(
|
||||
data_distribution);
|
||||
}
|
||||
else {
|
||||
allocation->initialize_random_device(
|
||||
options.initialization.seed + seed_shift,
|
||||
data_distribution);
|
||||
}
|
||||
}
|
||||
else if (options.initialization.provider == library::Provider::kReferenceHost) {
|
||||
if (data_distribution.kind == Distribution::Sequential) {
|
||||
allocation->initialize_sequential_host(
|
||||
data_distribution);
|
||||
}
|
||||
else {
|
||||
allocation->initialize_random_host(
|
||||
options.initialization.seed + seed_shift,
|
||||
data_distribution);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocates memory of a given type, capacity (elements), and name
|
||||
DeviceAllocation *DeviceContext::allocate_and_initialize_tensor(
|
||||
Options const &options,
|
||||
@ -122,7 +151,6 @@ DeviceAllocation *DeviceContext::allocate_and_initialize_tensor(
|
||||
data_distribution.set_uniform(1, 4, 0);
|
||||
break;
|
||||
|
||||
|
||||
case library::NumericTypeID::kF16:
|
||||
data_distribution.set_uniform(-3, 3, 0);
|
||||
break;
|
||||
@ -168,28 +196,9 @@ DeviceAllocation *DeviceContext::allocate_and_initialize_tensor(
|
||||
}
|
||||
}
|
||||
|
||||
if (options.initialization.provider == library::Provider::kReferenceDevice) {
|
||||
if (data_distribution.kind == Distribution::Sequential) {
|
||||
allocation->initialize_sequential_device(
|
||||
data_distribution);
|
||||
}
|
||||
else {
|
||||
allocation->initialize_random_device(
|
||||
options.initialization.seed + seed_shift,
|
||||
data_distribution);
|
||||
}
|
||||
}
|
||||
else if (options.initialization.provider == library::Provider::kReferenceHost) {
|
||||
if (data_distribution.kind == Distribution::Sequential) {
|
||||
allocation->initialize_sequential_host(
|
||||
data_distribution);
|
||||
}
|
||||
else {
|
||||
allocation->initialize_random_host(
|
||||
options.initialization.seed + seed_shift,
|
||||
data_distribution);
|
||||
}
|
||||
}
|
||||
initialize_allocation_with_data_distribution(
|
||||
options, seed_shift, allocation, data_distribution
|
||||
);
|
||||
}
|
||||
|
||||
return allocation;
|
||||
|
||||
@ -79,6 +79,7 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options):
|
||||
{ArgumentTypeID::kEnumerated, {"runtime_input_datatype_a", "runtime-input-datatype::a"}, "Runtime datatype (e4m3, e5m2, e3m2, e2m3, e2m1)"},
|
||||
{ArgumentTypeID::kEnumerated, {"runtime_input_datatype_b", "runtime-input-datatype::b"}, "Runtime datatype (e4m3, e5m2, e3m2, e2m3, e2m1)"},
|
||||
{ArgumentTypeID::kInteger, {"use_pdl", "use-pdl"}, "Use PDL (true, false)"},
|
||||
{ArgumentTypeID::kEnumerated, {"enable_sm90_mixed_dtype_shuffle_test", "enable-sm90-mixed-dtype-shuffle-test"}, "Enable SM90 mixed input data type kernel shuffle layout test (true, false)"},
|
||||
{ArgumentTypeID::kInteger, {"swizzle_size", "swizzle-size"}, "Size to swizzle"},
|
||||
},
|
||||
{ library::Provider::kCUBLAS}
|
||||
@ -211,6 +212,11 @@ Status GemmOperationProfiler::GemmProblem::parse(
|
||||
this->use_pdl = false;
|
||||
}
|
||||
|
||||
if (!arg_as_bool(this->enable_sm90_mixed_dtype_shuffle_test, "enable_sm90_mixed_dtype_shuffle_test", problem_space, problem)) {
|
||||
// default value
|
||||
this->enable_sm90_mixed_dtype_shuffle_test = false;
|
||||
}
|
||||
|
||||
if (!arg_as_SplitKModeID(this->split_k_mode, "split_k_mode", problem_space, problem)) {
|
||||
// default value
|
||||
this->split_k_mode = library::SplitKMode::kSerial;
|
||||
@ -399,6 +405,7 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
|
||||
set_argument(result, "raster_order", problem_space, library::to_string(raster_order));
|
||||
set_argument(result, "swizzle_size", problem_space, swizzle_size);
|
||||
set_argument(result, "use_pdl", problem_space, library::to_string(use_pdl));
|
||||
set_argument(result, "enable_sm90_mixed_dtype_shuffle_test", problem_space, library::to_string(enable_sm90_mixed_dtype_shuffle_test));
|
||||
|
||||
|
||||
set_argument(result, "runtime_input_datatype_a", problem_space, library::to_string(runtime_input_datatype_a));
|
||||
@ -432,14 +439,26 @@ Status GemmOperationProfiler::initialize_configuration(
|
||||
|
||||
Status status = problem_.parse(operation_desc, problem_space, problem);
|
||||
|
||||
// Note: this is a temporary workaround
|
||||
bool is_current_operation_sm90_mixed_dtype_shuffle = (strstr(operation_desc.name, "_shfl") != NULL);
|
||||
if (is_current_operation_sm90_mixed_dtype_shuffle && (problem_.enable_sm90_mixed_dtype_shuffle_test == false)) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
const auto device_count = options.device.devices.size();
|
||||
auto const device_count = options.device.devices.size();
|
||||
|
||||
gemm_workspace_.clear();
|
||||
|
||||
library::NumericTypeID a_elem = library::get_real_type(operation_desc.A.element);
|
||||
library::NumericTypeID b_elem = library::get_real_type(operation_desc.B.element);
|
||||
int a_elem_bits = library::sizeof_bits(a_elem);
|
||||
int b_elem_bits = library::sizeof_bits(b_elem);
|
||||
bool is_mixed_input = (a_elem_bits != b_elem_bits);
|
||||
|
||||
for (size_t i = 0; i < device_count; ++i) {
|
||||
cudaSetDevice(options.device.device_id(i));
|
||||
gemm_workspace_.emplace_back();
|
||||
@ -455,7 +474,6 @@ Status GemmOperationProfiler::initialize_configuration(
|
||||
gemm_workspace_[i].configuration.cluster_shape_fallback.m() = int(problem_.cluster_m_fallback);
|
||||
gemm_workspace_[i].configuration.cluster_shape_fallback.n() = int(problem_.cluster_n_fallback);
|
||||
gemm_workspace_[i].configuration.cluster_shape_fallback.k() = int(problem_.cluster_k_fallback);
|
||||
|
||||
gemm_workspace_[i].configuration.lda = problem_.lda;
|
||||
gemm_workspace_[i].configuration.ldb = problem_.ldb;
|
||||
gemm_workspace_[i].configuration.ldc = problem_.ldc;
|
||||
@ -501,7 +519,77 @@ Status GemmOperationProfiler::initialize_configuration(
|
||||
|
||||
|
||||
initialize_result_(this->model_result_, options, operation_desc, problem_space);
|
||||
if (const auto can_implement = operation->can_implement(&gemm_workspace_[i].configuration, &gemm_workspace_[i].arguments); can_implement != Status::kSuccess) {
|
||||
if (is_mixed_input)
|
||||
{
|
||||
const int options_g = problem_.k;
|
||||
const int options_l = problem_.batch_count;
|
||||
const int scale_k = (problem_.k + options_g - 1) / options_g;
|
||||
// We cannot get the mainloop's ElementScale and ElementZero here,
|
||||
// use the wide type to allocate a large enough workspace for S and Z.
|
||||
library::NumericTypeID wide_dtype;
|
||||
size_t SZ_mat_size = 0;
|
||||
if (a_elem_bits > b_elem_bits) {
|
||||
wide_dtype = a_elem;
|
||||
SZ_mat_size = static_cast<size_t>(problem_.n * scale_k);
|
||||
}
|
||||
else {
|
||||
wide_dtype = b_elem;
|
||||
SZ_mat_size = static_cast<size_t>(problem_.m * scale_k);
|
||||
}
|
||||
|
||||
gemm_workspace_[i].Scale = device_context.allocate_tensor(
|
||||
options,
|
||||
"Scale",
|
||||
wide_dtype,
|
||||
library::LayoutTypeID::kRowMajor,
|
||||
{int(SZ_mat_size), int(options_l)},
|
||||
{int(options_l)},
|
||||
problem_.batch_count * gemm_workspace_[i].problem_count,
|
||||
i // device_index
|
||||
);
|
||||
gemm_workspace_[i].Zero = device_context.allocate_tensor(
|
||||
options,
|
||||
"Zero",
|
||||
wide_dtype,
|
||||
library::LayoutTypeID::kRowMajor,
|
||||
{int(SZ_mat_size), int(options_l)},
|
||||
{int(options_l)},
|
||||
problem_.batch_count * gemm_workspace_[i].problem_count,
|
||||
i // device_index
|
||||
);
|
||||
|
||||
// Packed scale is for int4 * fp8, where the original scale is fp8, and
|
||||
// each scale element will be packed into an Array<fp8, 8> which is 64-bit
|
||||
gemm_workspace_[i].packed_Scale = device_context.allocate_tensor(
|
||||
options,
|
||||
"packed-Scale",
|
||||
library::NumericTypeID::kU64,
|
||||
library::LayoutTypeID::kRowMajor,
|
||||
{int(SZ_mat_size), int(options_l)},
|
||||
{int(options_l)},
|
||||
problem_.batch_count * gemm_workspace_[i].problem_count,
|
||||
i // device_index
|
||||
);
|
||||
|
||||
gemm_workspace_[i].arguments.problem_size = {int(problem_.m), int(problem_.n), int(problem_.k)};
|
||||
gemm_workspace_[i].arguments.batch_count = problem_.batch_count;
|
||||
|
||||
// Here is the first touch of the arguments, mark the mixed dtype,
|
||||
// populate the scale and zero tensors in the following can_implement() call later.
|
||||
// A and B are not populated at this moment, so do not update the dequantized A or B
|
||||
gemm_workspace_[i].arguments.is_mixed_dtype = true;
|
||||
gemm_workspace_[i].arguments.wider_operand = (a_elem_bits > b_elem_bits) ? cutlass::library::Sm90MixedInputWiderOperand::A : cutlass::library::Sm90MixedInputWiderOperand::B;
|
||||
gemm_workspace_[i].arguments.generate_scale_and_zero = true;
|
||||
gemm_workspace_[i].arguments.generate_dequantized_AB = false;
|
||||
gemm_workspace_[i].arguments.dequantized_AB_ready = (bool *) malloc(sizeof(bool));
|
||||
gemm_workspace_[i].arguments.dequantized_AB_ready[0] = false;
|
||||
gemm_workspace_[i].arguments.Scale = gemm_workspace_[i].Scale->data();
|
||||
gemm_workspace_[i].arguments.Zero = gemm_workspace_[i].Zero->data();
|
||||
gemm_workspace_[i].arguments.packed_Scale = gemm_workspace_[i].packed_Scale->data();
|
||||
} // End of "if (is_mixed_input)"
|
||||
|
||||
const auto can_implement = operation->can_implement(&gemm_workspace_[i].configuration, &gemm_workspace_[i].arguments);
|
||||
if (can_implement != Status::kSuccess) {
|
||||
return can_implement;
|
||||
}
|
||||
}
|
||||
@ -693,6 +781,56 @@ Status GemmOperationProfiler::initialize_workspace(
|
||||
problem_.batch_count * gemm_workspace_[i].problem_count,
|
||||
i // device_index
|
||||
);
|
||||
|
||||
if (gemm_workspace_[i].arguments.is_mixed_dtype) {
|
||||
// Dequantized tensor has the same shape of the narrow data type tensor,
|
||||
// and the same data type as the wide data type tensor
|
||||
// Encoded tensor has the same shape and data type of the narrow data type tensor
|
||||
if (gemm_workspace_[i].arguments.wider_operand == cutlass::library::Sm90MixedInputWiderOperand::A) {
|
||||
gemm_workspace_[i].dequantized_AB = device_context.allocate_tensor(
|
||||
options,
|
||||
"dequantized-B",
|
||||
operation_desc.A.element,
|
||||
operation_desc.B.layout,
|
||||
{int(problem_.k), int(problem_.n)},
|
||||
{int(problem_.ldb)},
|
||||
problem_.batch_count * gemm_workspace_[i].problem_count,
|
||||
i // device_index
|
||||
);
|
||||
gemm_workspace_[i].encoded_AB = device_context.allocate_tensor(
|
||||
options,
|
||||
"encoded-B",
|
||||
operation_desc.B.element,
|
||||
operation_desc.B.layout,
|
||||
{int(problem_.k), int(problem_.n)},
|
||||
{int(problem_.ldb)},
|
||||
problem_.batch_count * gemm_workspace_[i].problem_count,
|
||||
i // device_index
|
||||
);
|
||||
}
|
||||
else {
|
||||
gemm_workspace_[i].dequantized_AB = device_context.allocate_tensor(
|
||||
options,
|
||||
"dequantized-A",
|
||||
operation_desc.B.element,
|
||||
operation_desc.A.layout,
|
||||
{int(problem_.m), int(problem_.k)},
|
||||
{int(problem_.lda)},
|
||||
problem_.batch_count * gemm_workspace_[i].problem_count,
|
||||
i // device_index
|
||||
);
|
||||
gemm_workspace_[i].encoded_AB = device_context.allocate_tensor(
|
||||
options,
|
||||
"encoded-A",
|
||||
operation_desc.A.element,
|
||||
operation_desc.A.layout,
|
||||
{int(problem_.m), int(problem_.k)},
|
||||
{int(problem_.lda)},
|
||||
problem_.batch_count * gemm_workspace_[i].problem_count,
|
||||
i // device_index
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (options.execution_mode != ExecutionMode::kDryRun) {
|
||||
@ -712,7 +850,7 @@ Status GemmOperationProfiler::initialize_workspace(
|
||||
gemm_workspace_[i].arguments.batch_stride_D = gemm_workspace_[i].Computed->batch_stride();
|
||||
|
||||
/* Query device SM count to pass onto the kernel as an argument, where needed */
|
||||
gemm_workspace_[i].arguments.sm_count = options.device.properties[0].multiProcessorCount;
|
||||
gemm_workspace_[i].arguments.sm_count = options.device.properties[i].multiProcessorCount;
|
||||
gemm_workspace_[i].arguments.device_index = static_cast<int>(i);
|
||||
}
|
||||
}
|
||||
@ -836,6 +974,17 @@ bool GemmOperationProfiler::verify_cutlass(
|
||||
gemm_workspace_[i].arguments.batch_stride_C = gemm_workspace_[i].C->batch_stride();
|
||||
gemm_workspace_[i].arguments.batch_stride_D = gemm_workspace_[i].Computed->batch_stride();
|
||||
|
||||
if (gemm_workspace_[i].arguments.is_mixed_dtype) {
|
||||
// Scale and zero already generated in initialize_configuration(),
|
||||
// A and B already generated in initialize_workspace(), signal
|
||||
// GemmUniversal3xOperation::update_arguments_() (trigger by underlying_operation->run())
|
||||
// to generate the dequantized matrix for verification
|
||||
gemm_workspace_[i].arguments.generate_scale_and_zero = false;
|
||||
gemm_workspace_[i].arguments.generate_dequantized_AB = true;
|
||||
gemm_workspace_[i].arguments.dequantized_AB = gemm_workspace_[i].dequantized_AB->data();
|
||||
gemm_workspace_[i].arguments.encoded_AB = gemm_workspace_[i].encoded_AB->data();
|
||||
}
|
||||
|
||||
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||
gemm_workspace_[i].arguments.D = gemm_workspace_[i].device_workspace.data();
|
||||
gemm_workspace_[i].arguments.alpha = problem_.alpha_one.data();
|
||||
@ -1133,7 +1282,6 @@ bool GemmOperationProfiler::verify_with_reference_(
|
||||
//
|
||||
// Initialize state
|
||||
//
|
||||
|
||||
for (auto provider : options.verification.providers) {
|
||||
|
||||
// Skip providers that are not enabled
|
||||
@ -1149,6 +1297,21 @@ bool GemmOperationProfiler::verify_with_reference_(
|
||||
void *ptr_C = gemm_workspace_[i].C->data();
|
||||
void *ptr_D = gemm_workspace_[i].Reference->data();
|
||||
|
||||
cutlass::library::NumericTypeID element_A_for_reference = element_A;
|
||||
cutlass::library::NumericTypeID element_B_for_reference = element_B;
|
||||
if (gemm_workspace_[i].arguments.is_mixed_dtype && gemm_workspace_[i].arguments.dequantized_AB_ready[0]) {
|
||||
// Dequantized tensor has the same shape of the narrow data type tensor,
|
||||
// and the same data type as the wide data type tensor
|
||||
if (gemm_workspace_[i].arguments.wider_operand == cutlass::library::Sm90MixedInputWiderOperand::A) {
|
||||
ptr_B = gemm_workspace_[i].dequantized_AB->data();
|
||||
element_B_for_reference = element_A;
|
||||
}
|
||||
else {
|
||||
ptr_A = gemm_workspace_[i].dequantized_AB->data();
|
||||
element_A_for_reference = element_B;
|
||||
}
|
||||
}
|
||||
|
||||
// To support the host-side reference, conditionally allocate and
|
||||
// copy tensors to host memory.
|
||||
std::vector<uint8_t> host_data_A;
|
||||
@ -1200,13 +1363,13 @@ bool GemmOperationProfiler::verify_with_reference_(
|
||||
|
||||
problem_.alpha.data(),
|
||||
|
||||
element_A,
|
||||
element_A_for_reference,
|
||||
gemm_desc.A.layout,
|
||||
gemm_desc.transform_A,
|
||||
ptr_A,
|
||||
int(gemm_workspace_[i].configuration.lda),
|
||||
|
||||
element_B,
|
||||
element_B_for_reference,
|
||||
gemm_desc.B.layout,
|
||||
gemm_desc.transform_B,
|
||||
ptr_B,
|
||||
@ -1349,6 +1512,13 @@ Status GemmOperationProfiler::profile_cutlass_(
|
||||
gemm_workspace_[dev_id].arguments.C = gemm_workspace_[dev_id].C->batch_data(problem_idx);
|
||||
gemm_workspace_[dev_id].arguments.D = gemm_workspace_[dev_id].Computed->batch_data(problem_idx);
|
||||
|
||||
if (gemm_workspace_[dev_id].arguments.is_mixed_dtype) {
|
||||
// Scale, zero, and dequantized tensors are already generated in
|
||||
// verify_cutlass(), no need to re-generate them in profiling
|
||||
gemm_workspace_[dev_id].arguments.generate_scale_and_zero = false;
|
||||
gemm_workspace_[dev_id].arguments.generate_dequantized_AB = false;
|
||||
}
|
||||
|
||||
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||
gemm_workspace_[dev_id].arguments.D = gemm_workspace_[dev_id].device_workspace.data();
|
||||
|
||||
@ -1383,11 +1553,6 @@ Status GemmOperationProfiler::profile_cutlass_(
|
||||
return Status::kSuccess;
|
||||
};
|
||||
|
||||
if (options.device.devices.size() == 1) {
|
||||
auto func = [&](cudaStream_t stream, int iteration) { return launch_gemm(0, stream, iteration); };
|
||||
return profile_kernel_(result, options, func, gemm_workspace_[0].stream);
|
||||
}
|
||||
|
||||
std::vector<cudaStream_t> streams(gemm_workspace_.size());
|
||||
for (size_t i = 0; i < streams.size(); i++) {
|
||||
streams[i] = gemm_workspace_[i].stream;
|
||||
|
||||
1034
tools/profiler/src/grouped_gemm_operation_profiler.cu
Normal file
1034
tools/profiler/src/grouped_gemm_operation_profiler.cu
Normal file
File diff suppressed because it is too large
Load Diff
@ -57,16 +57,6 @@
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define CUDA_CHECK(call) \
|
||||
do { \
|
||||
cudaError_t err = call; \
|
||||
if (err != cudaSuccess) { \
|
||||
std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__ << " code=" << err << " \"" \
|
||||
<< cudaGetErrorString(err) << "\"\n"; \
|
||||
return Status::kErrorInternal; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace cutlass {
|
||||
namespace profiler {
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -304,42 +294,43 @@ std::ostream& operator<<(std::ostream& out, library::Provider provider) {
|
||||
return out;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, library::OperationKind provider) {
|
||||
if (provider == library::OperationKind::kGemm) {
|
||||
std::ostream& operator<<(std::ostream& out, library::OperationKind op_kind) {
|
||||
if (op_kind == library::OperationKind::kGemm) {
|
||||
out << "kGemm";
|
||||
}
|
||||
|
||||
else if (provider == library::OperationKind::kBlockScaledGemm) {
|
||||
else if (op_kind == library::OperationKind::kBlockScaledGemm) {
|
||||
out << "kBlockScaledGemm";
|
||||
}
|
||||
|
||||
else if (provider == library::OperationKind::kRankK) {
|
||||
else if (op_kind == library::OperationKind::kRankK) {
|
||||
out << "kRankK";
|
||||
}
|
||||
else if (provider == library::OperationKind::kRank2K) {
|
||||
else if (op_kind == library::OperationKind::kRank2K) {
|
||||
out << "kRank2K";
|
||||
}
|
||||
else if (provider == library::OperationKind::kTrmm) {
|
||||
else if (op_kind == library::OperationKind::kTrmm) {
|
||||
out << "kTrmm";
|
||||
}
|
||||
else if (provider == library::OperationKind::kSymm) {
|
||||
else if (op_kind == library::OperationKind::kSymm) {
|
||||
out << "kSymm";
|
||||
}
|
||||
else if (provider == library::OperationKind::kConv2d) {
|
||||
else if (op_kind == library::OperationKind::kConv2d) {
|
||||
out << "kConv2d";
|
||||
}
|
||||
else if (provider == library::OperationKind::kConv3d) {
|
||||
else if (op_kind == library::OperationKind::kConv3d) {
|
||||
out << "kConv3d";
|
||||
}
|
||||
else if (provider == library::OperationKind::kEqGemm) {
|
||||
else if (op_kind == library::OperationKind::kEqGemm) {
|
||||
out << "kEqGemm";
|
||||
}
|
||||
else if (provider == library::OperationKind::kSparseGemm) {
|
||||
else if (op_kind == library::OperationKind::kSparseGemm) {
|
||||
out << "kSparseGemm";
|
||||
}
|
||||
else if (provider == library::OperationKind::kReduction) {
|
||||
else if (op_kind == library::OperationKind::kReduction) {
|
||||
out << "kReduction";
|
||||
}
|
||||
else if (op_kind == library::OperationKind::kGroupedGemm) {
|
||||
out << "kGroupedGemm";
|
||||
}
|
||||
else {
|
||||
out << "kInvalid";
|
||||
}
|
||||
@ -660,6 +651,11 @@ void OperationProfiler::save_workspace(
|
||||
|
||||
DeviceAllocation *allocation = named_allocation.second;
|
||||
|
||||
if (allocation->layout() == library::LayoutTypeID::kUnknown) {
|
||||
continue; // write_tensor not set up to handle DeviceAllocations initialized using
|
||||
// allocate_block()
|
||||
}
|
||||
|
||||
std::stringstream filename;
|
||||
|
||||
filename << desc.name << "_" << library::to_string(provider) << "_";
|
||||
@ -736,15 +732,20 @@ Status predict_iters(
|
||||
/// CUDA graphs allows you to record the launch of large numbers of kernels without
|
||||
/// blocking and therefore avoids a deadlock which happens if you try to enqueue too
|
||||
/// many kernels behind the spinloop kernel.
|
||||
Status OperationProfiler::profile_kernel_(
|
||||
PerformanceResult &result,
|
||||
Options const &options,
|
||||
const std::function<Status(int, cudaStream_t, int)> &func,
|
||||
const std::vector<cudaStream_t> &streams) {
|
||||
Status OperationProfiler::profile_kernel_w_cuda_graphs_(
|
||||
PerformanceResult& result,
|
||||
Options const& options,
|
||||
std::function<Status(int, cudaStream_t, int)> const& func,
|
||||
std::vector<cudaStream_t> const& streams) {
|
||||
|
||||
auto dev_count = streams.size();
|
||||
|
||||
cuda::atomic<bool> *release;
|
||||
CUDA_CHECK(cudaHostAlloc(&release, sizeof(*release), cudaHostAllocPortable));
|
||||
release->store(false, cuda::memory_order_release);
|
||||
|
||||
if (dev_count > 1) {
|
||||
CUDA_CHECK(cudaHostAlloc(&release, sizeof(*release), cudaHostAllocPortable));
|
||||
release->store(false, cuda::memory_order_release);
|
||||
}
|
||||
|
||||
std::vector<GpuTimer> timer;
|
||||
for (size_t i = 0; i < dev_count; ++i) {
|
||||
@ -774,9 +775,11 @@ Status OperationProfiler::profile_kernel_(
|
||||
for (size_t i = 0; i < dev_count; ++i) {
|
||||
CUDA_CHECK(cudaSetDevice(options.device.device_id(i)));
|
||||
CUDA_CHECK(cudaStreamBeginCapture(streams[i], cudaStreamCaptureModeGlobal));
|
||||
// Halt execution until all GPUs are ready to precede.
|
||||
// It allows the CPU to trigger the GPUs all start at the same time.
|
||||
delay<<<1, 1, 0, streams[i]>>>(release);
|
||||
if (dev_count > 1) {
|
||||
// Halt execution until all GPUs are ready to precede.
|
||||
// It allows the CPU to trigger the GPUs all start at the same time.
|
||||
delay<<<1, 1, 0, streams[i]>>>(release);
|
||||
}
|
||||
for (int iteration = 0; iteration < options.profiling.warmup_iterations; ++iteration) {
|
||||
Status status = func(i, streams[i], iteration);
|
||||
if (status != Status::kSuccess) {
|
||||
@ -803,8 +806,10 @@ Status OperationProfiler::profile_kernel_(
|
||||
CUDA_CHECK(cudaGraphLaunch(graphExecs[i], streams[i]));
|
||||
}
|
||||
|
||||
// release the enqueued kernels
|
||||
release->store(true, cuda::memory_order_release);
|
||||
if (dev_count > 1) {
|
||||
// release the enqueued kernels
|
||||
release->store(true, cuda::memory_order_release);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < dev_count; ++i) {
|
||||
CUDA_CHECK(cudaSetDevice(options.device.device_id(i)));
|
||||
@ -819,7 +824,9 @@ Status OperationProfiler::profile_kernel_(
|
||||
}
|
||||
result.runtime /= static_cast<double>(dev_count);
|
||||
|
||||
CUDA_CHECK(cudaFreeHost(release));
|
||||
if (dev_count > 1) {
|
||||
CUDA_CHECK(cudaFreeHost(release));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < dev_count; ++i) {
|
||||
CUDA_CHECK(cudaSetDevice(options.device.device_id(i)));
|
||||
@ -835,11 +842,47 @@ Status OperationProfiler::profile_kernel_(
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Method to profile GPU execution time of a kernel launched in func
|
||||
Status OperationProfiler::profile_kernel_(
|
||||
PerformanceResult &result,
|
||||
Options const &options,
|
||||
const std::function<Status(cudaStream_t, int)> &func,
|
||||
const std::function<Status(int, cudaStream_t, int)> &func,
|
||||
const std::vector<cudaStream_t> &streams) {
|
||||
|
||||
if (options.profiling.use_cuda_graphs) {
|
||||
return profile_kernel_w_cuda_graphs_(result, options, func, streams);
|
||||
}
|
||||
else if (streams.size() == 1) {
|
||||
auto single_device_func = [&](cudaStream_t stream, int iteration) {
|
||||
return func(0, stream, iteration);
|
||||
};
|
||||
return profile_kernel_no_cuda_graphs_(result, options, single_device_func, streams[0]);
|
||||
}
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
|
||||
/// Method to profile GPU execution time of a kernel launched in func
|
||||
Status OperationProfiler::profile_kernel_(
|
||||
PerformanceResult& result,
|
||||
Options const& options,
|
||||
std::function<Status(cudaStream_t, int)> const& func,
|
||||
cudaStream_t stream) {
|
||||
|
||||
if (options.profiling.use_cuda_graphs) {
|
||||
auto graph_func = [&](int dev_id, cudaStream_t stream, int iteration) {
|
||||
return func(stream, iteration);
|
||||
};
|
||||
return profile_kernel_w_cuda_graphs_(result, options, graph_func, {stream});
|
||||
} else {
|
||||
return profile_kernel_no_cuda_graphs_(result, options, func, stream);
|
||||
}
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Method to profile GPU execution time of a kernel launched in func
|
||||
Status OperationProfiler::profile_kernel_no_cuda_graphs_(
|
||||
PerformanceResult& result,
|
||||
Options const& options,
|
||||
std::function<Status(cudaStream_t, int)> const& func,
|
||||
cudaStream_t stream) {
|
||||
|
||||
GpuTimer timer;
|
||||
|
||||
@ -477,6 +477,7 @@ Options::Profiling::Profiling(cutlass::CommandLine const &cmdline) {
|
||||
cmdline.get_cmd_line_argument("profiling-enabled", enabled, true);
|
||||
cmdline.get_cmd_line_argument("profiling-duration", duration, 10);
|
||||
cmdline.get_cmd_line_argument("min-iterations", min_iterations, 10);
|
||||
cmdline.get_cmd_line_argument("use-cuda-graphs", use_cuda_graphs, false);
|
||||
|
||||
if (cmdline.check_cmd_line_flag("providers")) {
|
||||
|
||||
|
||||
@ -1203,6 +1203,34 @@ bool arg_as_scalar(
|
||||
return arg_as_scalar(bytes, numeric_type, value_ptr);
|
||||
}
|
||||
|
||||
/// Returns a copy of the string passed to the argument.
|
||||
/// (kScalar arguments are stored as strings).
|
||||
bool arg_as_string(
|
||||
std::string& arg,
|
||||
char const* name,
|
||||
ProblemSpace const& problem_space,
|
||||
ProblemSpace::Problem const& problem) {
|
||||
|
||||
size_t idx = problem_space.argument_index(name);
|
||||
KernelArgument::Value const* value_ptr = problem.at(idx).get();
|
||||
|
||||
if (value_ptr->not_null) {
|
||||
if (value_ptr->argument->description->type == ArgumentTypeID::kScalar) {
|
||||
std::string const& str_value =
|
||||
static_cast<ScalarArgument::ScalarValue const*>(value_ptr)->value;
|
||||
arg = std::string(str_value);
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error(
|
||||
"arg_as_string() - illegal cast. Problem space argument must be scalar");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Returns true if a tensor description satisfies a `tensor` value
|
||||
|
||||
@ -56,9 +56,7 @@ template <typename T>
|
||||
T* allocate(size_t count = 1) {
|
||||
|
||||
T* ptr = 0;
|
||||
size_t bytes = 0;
|
||||
|
||||
bytes = count * sizeof(T);
|
||||
size_t bytes = count * sizeof_bits<T>::value / 8;
|
||||
|
||||
cudaError_t cuda_error = cudaMalloc((void**)&ptr, bytes);
|
||||
|
||||
|
||||
480
tools/util/include/cutlass/util/mixed_dtype_utils.hpp
Normal file
480
tools/util/include/cutlass/util/mixed_dtype_utils.hpp
Normal file
@ -0,0 +1,480 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Utilities for mixed input data type kernels.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include "cute/layout.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/arch/mma_sm90.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/util/device_memory.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cute/util/type_traits.hpp"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
#define CUDA_CHECK(status) \
|
||||
{ \
|
||||
cudaError_t error = status; \
|
||||
if (error != cudaSuccess) { \
|
||||
std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \
|
||||
<< " at line: " << __LINE__ << std::endl; \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
}
|
||||
|
||||
template <
|
||||
class QuantizedElement,
|
||||
class DequantizedElement,
|
||||
class OperandLayout,
|
||||
class ElementScale,
|
||||
class ElementZero,
|
||||
class ScaleBroadCastLayout,
|
||||
class ThrLayout>
|
||||
__global__ void dequantize_kernel(DequantizedElement* dq_buffer,
|
||||
QuantizedElement const* q_buffer,
|
||||
OperandLayout const operand_layout,
|
||||
ElementScale const* scale_buffer,
|
||||
ElementZero const* zero_buffer,
|
||||
ScaleBroadCastLayout const broadcasted_scale_layout,
|
||||
ThrLayout thr_layout) {
|
||||
using namespace cute;
|
||||
|
||||
// Represent the full tensors to gmem elements.
|
||||
// These are expected to have shape [MN, K, L]
|
||||
cute::Tensor gmem_op_dq = cute::make_tensor(cute::make_gmem_ptr(dq_buffer), operand_layout);
|
||||
auto init_quantized_iterator = [&]() {
|
||||
if constexpr (cute::sizeof_bits_v<QuantizedElement> >= 8) {
|
||||
return cute::make_gmem_ptr(q_buffer);
|
||||
}
|
||||
else {
|
||||
return cute::subbyte_iterator<const QuantizedElement>(q_buffer);
|
||||
}
|
||||
};
|
||||
cute::Tensor gmem_op_q = cute::make_tensor(init_quantized_iterator(), operand_layout);
|
||||
// While the scales are expected to have shape [MN, G, L] but with a stride to allow broadcasting
|
||||
// It is expected that K % G == 0
|
||||
cute::Tensor gmem_scale_broadcasted = cute::make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout);
|
||||
cute::Tensor gmem_zero_broadcasted = cute::make_tensor(make_gmem_ptr(zero_buffer), broadcasted_scale_layout);
|
||||
|
||||
// Assign 1 thread per element in the thread block
|
||||
auto blk_shape = cute::make_shape(size<0>(thr_layout), _1{}, _1{}); //
|
||||
auto blk_coord = cute::make_coord(_, blockIdx.x, blockIdx.y); // (MN, K, L)
|
||||
|
||||
// Tile across the block
|
||||
auto gOp_dq = cute::local_tile(gmem_op_dq, blk_shape, blk_coord);
|
||||
auto gScale = cute::local_tile(gmem_scale_broadcasted, blk_shape, blk_coord);
|
||||
auto gZero = cute::local_tile(gmem_zero_broadcasted, blk_shape, blk_coord);
|
||||
auto gOp_q = cute::local_tile(gmem_op_q, blk_shape, blk_coord);
|
||||
|
||||
auto tOpDq_gOpDq = cute::local_partition(gOp_dq, thr_layout, threadIdx.x);
|
||||
auto tScale_gScale = cute::local_partition(gScale, thr_layout, threadIdx.x);
|
||||
auto tZero_gZero = cute::local_partition(gZero, thr_layout, threadIdx.x);
|
||||
auto tOpQ_gOpQ = cute::local_partition(gOp_q, thr_layout, threadIdx.x);
|
||||
|
||||
// Make a fragment of registers to hold gmem loads
|
||||
cute::Tensor rmem_op_q = cute::make_fragment_like(tOpQ_gOpQ(_, _, _, 0));
|
||||
cute::Tensor rmem_scale = cute::make_fragment_like(tScale_gScale(_, _, _, 0));
|
||||
cute::Tensor rmem_zero = cute::make_fragment_like(tZero_gZero(_, _, _, 0));
|
||||
cute::Tensor rmem_op_dq = cute::make_fragment_like(tOpDq_gOpDq(_, _, _, 0));
|
||||
cute::Tensor rmem_op_scaled = cute::make_fragment_like<ElementScale>(rmem_op_dq);
|
||||
cute::Tensor rmem_zero_buf = cute::make_fragment_like<ElementScale>(rmem_zero);
|
||||
|
||||
cute::Tensor pred_id = cute::make_identity_tensor(shape(operand_layout));
|
||||
auto pred_blk_tile = cute::local_tile(pred_id, blk_shape, blk_coord);
|
||||
auto pred_thr_partition = cute::local_partition(pred_blk_tile, thr_layout, threadIdx.x);
|
||||
|
||||
const auto num_iters = cute::size<3>(tOpDq_gOpDq);
|
||||
|
||||
for (int ii = 0; ii < num_iters; ++ii) {
|
||||
const auto thread_offset = cute::get<0>(pred_thr_partition(0, 0, 0, ii));
|
||||
if (thread_offset < cute::size<0>(operand_layout)) {
|
||||
cute::copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q);
|
||||
cute::copy(tScale_gScale(_, _, _, ii), rmem_scale);
|
||||
cute::copy(tZero_gZero(_, _, _, ii), rmem_zero);
|
||||
cute::transform(rmem_op_q, rmem_op_scaled, [] (const QuantizedElement& elt) { return ElementScale(elt); } );
|
||||
cute::transform(rmem_zero, rmem_zero_buf, [] (const ElementZero& elt) { return ElementScale(elt); } );
|
||||
cute::transform(rmem_op_scaled, rmem_scale, rmem_op_scaled, cute::multiplies{});
|
||||
cute::transform(rmem_op_scaled, rmem_zero_buf, rmem_op_scaled, cute::plus{});
|
||||
cute::transform(rmem_op_scaled, rmem_op_dq, [] (const ElementScale& elt) { return DequantizedElement(elt); } );
|
||||
cute::copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
class QuantizedElement,
|
||||
class DequantizedElement,
|
||||
class OperandLayout,
|
||||
class ElementScale,
|
||||
class ElementZero,
|
||||
class ScaleLayout>
|
||||
static void dequantize(DequantizedElement* dq_buffer,
|
||||
QuantizedElement const* q_buffer,
|
||||
OperandLayout const operand_layout,
|
||||
ElementScale const* scale_buffer,
|
||||
ElementZero const* zero_buffer,
|
||||
ScaleLayout const scale_layout,
|
||||
int const group_size,
|
||||
cudaStream_t &stream) {
|
||||
using namespace cute;
|
||||
|
||||
constexpr int tpb = 128;
|
||||
auto thr_layout = make_layout(make_shape(Int<tpb>{}));
|
||||
|
||||
const auto num_rows = get<0>(shape(operand_layout));
|
||||
const auto gemm_k = get<1>(shape(operand_layout)); // [MN, K, L]
|
||||
const auto batches = get<2>(shape(operand_layout)); // [MN, K, L]
|
||||
const auto scale_k = get<1>(shape(scale_layout)); // [MN, Scale_K, L]
|
||||
|
||||
if (num_rows != size<0>(scale_layout)) {
|
||||
std::cerr << "Invalid first dimension for scales. Must match first dim for weights."
|
||||
<< " But got shapes " << shape(operand_layout) << " " << shape(scale_layout)
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
const auto scale_stride0 = get<0>(stride(scale_layout));
|
||||
const auto scale_stride1 = get<1>(stride(scale_layout));
|
||||
const auto scale_stride2 = get<2>(stride(scale_layout));
|
||||
|
||||
auto scale_shape_bcast = make_shape(num_rows, make_shape(group_size, scale_k), batches);
|
||||
auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2);
|
||||
auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast);
|
||||
|
||||
const auto blocks_x = gemm_k;
|
||||
const auto blocks_y = batches;
|
||||
|
||||
dim3 blocks(blocks_x, blocks_y, 1);
|
||||
dequantize_kernel<<<blocks, tpb, 0, stream>>>(dq_buffer, q_buffer, operand_layout, scale_buffer, zero_buffer, scale_layout_bcast, thr_layout);
|
||||
CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class packed_scale_t {
|
||||
public:
|
||||
static_assert(cute::is_same_v<T, cutlass::int8_t> ||
|
||||
cute::is_same_v<T, cutlass::uint8_t> ||
|
||||
cute::is_same_v<T, cutlass::float_e4m3_t> ||
|
||||
cute::is_same_v<T, cutlass::float_e5m2_t>,
|
||||
"only 8 bit arithmetic types are supported.");
|
||||
CUTLASS_HOST_DEVICE
|
||||
explicit packed_scale_t(T val) {
|
||||
if constexpr (!cute::is_unsigned_v<T>) {
|
||||
// Only pack negative values. The positive values are generated in flight in the mainloop.
|
||||
storage[0] = pack4(T(float(val) * -8.f), T(float(val) * -7.f), T(float(val) * -6.f), T(float(val) * -5.f));
|
||||
storage[1] = pack4(T(float(val) * -4.f), T(float(val) * -3.f), T(float(val) * -2.f), -val);
|
||||
}
|
||||
else {
|
||||
storage[0] = pack4(T(float(val) * 8.f), T(float(val) * 7.f), T(float(val) * 6.f), T(float(val) * 5.f));
|
||||
storage[1] = pack4(T(float(val) * 4.f), T(float(val) * 3.f), T(float(val) * 2.f), val);
|
||||
}
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
packed_scale_t() = default;
|
||||
CUTLASS_HOST_DEVICE
|
||||
explicit operator float() const {
|
||||
return float(get());
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator==(packed_scale_t const& rhs) const {
|
||||
return storage[0] == rhs.storage[0] && storage[1] == rhs.storage[1];
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator!=(packed_scale_t const& rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
friend packed_scale_t operator+(packed_scale_t const& lhs, packed_scale_t const& rhs) {
|
||||
return packed_scale_t(lhs.get() + rhs.get());
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
friend packed_scale_t operator-(packed_scale_t const& lhs, packed_scale_t const& rhs) {
|
||||
return packed_scale_t(lhs.get() - rhs.get());
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
friend packed_scale_t operator*(packed_scale_t const& lhs, packed_scale_t const& rhs) {
|
||||
return packed_scale_t(lhs.get() * rhs.get());
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
friend packed_scale_t operator/(packed_scale_t const& lhs, packed_scale_t const& rhs) {
|
||||
return packed_scale_t(lhs.get() / rhs.get());
|
||||
}
|
||||
|
||||
private:
|
||||
using Storage = uint32_t;
|
||||
using Stage = uint8_t;
|
||||
|
||||
Storage storage[2] {};
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Storage pack4(T c1, T c2, T c3, T c4) {
|
||||
Storage result = 0;
|
||||
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c4)) << 24);
|
||||
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c3)) << 16);
|
||||
result |= (static_cast<Storage>(reinterpret_cast<Stage const&>(c2)) << 8);
|
||||
result |= static_cast<Storage>(reinterpret_cast<Stage const&>(c1));
|
||||
return result;
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
T get() const {
|
||||
auto stage = static_cast<Stage>(storage[0] >> 8);
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return reinterpret_cast<T const&>(stage);
|
||||
#else
|
||||
T tmp;
|
||||
std::memcpy(&tmp, &stage, sizeof(Stage));
|
||||
return tmp;
|
||||
#endif
|
||||
}
|
||||
CUTLASS_HOST_DEVICE
|
||||
T get(int idx) const {
|
||||
Stage stage;
|
||||
if (idx < 4) stage = static_cast<Stage>(storage[0] >> (8 * idx));
|
||||
else stage = static_cast<Stage>(storage[1] >> (8 * idx - 32));
|
||||
#if defined(__CUDA_ARCH__)
|
||||
return reinterpret_cast<T const&>(stage);
|
||||
#else
|
||||
T tmp;
|
||||
std::memcpy(&tmp, &stage, sizeof(Stage));
|
||||
return tmp;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT.
|
||||
// Here the encodings of positive values and negative values are unified (except for the sign bit).
|
||||
// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111).
|
||||
static bool unified_encode_int4b(cutlass::int4b_t const *block_in, cutlass::int4b_t *block_out, const size_t block_size) {
|
||||
|
||||
using StorageType = cutlass::int4b_t::Storage;
|
||||
constexpr int pack = cute::sizeof_bits_v<StorageType> / 4;
|
||||
const size_t host_buf_size = block_size / pack;
|
||||
std::vector<StorageType> host_buf(host_buf_size);
|
||||
cutlass::device_memory::copy_to_host(host_buf.data(), (StorageType *) block_in, host_buf_size);
|
||||
|
||||
for (auto&& d : host_buf) {
|
||||
StorageType out = 0;
|
||||
StorageType mask = 0x0f;
|
||||
for (int i = 0; i < pack; i++) {
|
||||
cutlass::int4b_t curr;
|
||||
curr.storage = (d >> (i * 4)) & 0x0f;
|
||||
switch (curr) {
|
||||
case 1: curr.storage = StorageType(0b0111); break; // 2's complement
|
||||
case 2: curr.storage = StorageType(0b0110); break; // 2's complement
|
||||
case 3: curr.storage = StorageType(0b0101); break; // 2's complement
|
||||
case 4: curr.storage = StorageType(0b0100); break; // 2's complement
|
||||
case 5: curr.storage = StorageType(0b0011); break; // 2's complement
|
||||
case 6: curr.storage = StorageType(0b0010); break; // 2's complement
|
||||
case 7: curr.storage = StorageType(0b0001); break; // 2's complement
|
||||
default: break;
|
||||
}
|
||||
out |= (curr.storage << (4 * i)) & mask;
|
||||
mask <<= 4;
|
||||
}
|
||||
d = out;
|
||||
}
|
||||
|
||||
cutlass::device_memory::copy_to_device((StorageType*) block_out, host_buf.data(), host_buf_size);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ElementScale>
|
||||
static bool pack_scale_fp8(ElementScale const *block_in, cutlass::Array<ElementScale, 8> *block_out, const size_t block_size) {
|
||||
std::vector<ElementScale> data_in(block_size);
|
||||
std::vector<cutlass::Array<ElementScale, 8>> data_out(block_size);
|
||||
|
||||
try {
|
||||
cutlass::device_memory::copy_to_host(data_in.data(), block_in, block_size);
|
||||
}
|
||||
catch (cutlass::cuda_exception const& e) {
|
||||
std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < block_size; i++) {
|
||||
cutlass::packed_scale_t<ElementScale> tmp(data_in[i]);
|
||||
data_out[i] = reinterpret_cast<cutlass::Array<ElementScale, 8> const&>(tmp);
|
||||
}
|
||||
|
||||
try {
|
||||
cutlass::device_memory::copy_to_device(block_out, data_out.data(), block_size);
|
||||
}
|
||||
catch (cutlass::cuda_exception const& e) {
|
||||
std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class T, class = void>
|
||||
struct UnderlyingElement {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct UnderlyingElement<T, cute::void_t<typename T::Element>> {
|
||||
using type = typename T::Element;
|
||||
};
|
||||
|
||||
// Given a type of MMA instruction, compute a memory reordering atom that places all values
|
||||
// owned by each thread in contiguous memory locations. This improves smem load vectorization,
|
||||
// particularly for mixed dtype GEMMs where a narrow type is loaded in the thread/value order
|
||||
// of the wider type and may result in inefficient sub-bank (8-bit or 16-bit) accesses.
|
||||
// In addition, we can reorder the values across several MMA instructions to get even wider
|
||||
// vectorization (AtomLayout parameter) and permute the values within each instruction to get
|
||||
// more optimal conversion instruction sequences (ValLayout parameter).
|
||||
template <class ElementMma,
|
||||
class AtomLayout = cute::Layout<cute::_1>,
|
||||
class ValLayout = cute::Layout<cute::_1>>
|
||||
constexpr auto compute_memory_reordering_atom(AtomLayout atom_layout = {}, ValLayout val_layout = {})
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
static_assert(is_static_v<ValLayout>, "ValLayout must be static");
|
||||
static_assert(is_static_v<AtomLayout>, "AtomLayout must be static");
|
||||
|
||||
// 1. Choose an MMA atom to access TV layout and MN shape
|
||||
// Note: parameters like GMMA Major, TileShape, ElementC don't affect TV layout of A, use arbitrary
|
||||
using MmaAtom = decltype(SM90::GMMA::rs_op_selector<ElementMma, ElementMma, float, Shape<_64,_16,_32>>());
|
||||
using MmaTraits = MMA_Traits<MmaAtom>;
|
||||
auto mk_shape_mma = select<0,2>(typename MmaTraits::Shape_MNK{});
|
||||
auto tv_layout_mma = typename MmaTraits::ALayout{};
|
||||
static_assert(size<1>(tv_layout_mma) % size(val_layout) == 0, "Value layout must evenly divide the MMA value layout");
|
||||
|
||||
// 2. Create a single warp's TV layout from that of the whole MMA and invert to get (m,k -> thr,val)
|
||||
// Note: this assumes A is partitioned between warps along M mode
|
||||
auto tv_tiler_warp = make_shape(Int<32>{}, size<1>(tv_layout_mma));
|
||||
auto mk_shape_warp = shape_div(mk_shape_mma, size(typename MmaTraits::ThrID{}) / Int<32>{});
|
||||
auto tv_layout_mma_warp = make_layout_like(composition(tv_layout_mma, tv_tiler_warp));
|
||||
auto mk_layout_mma_warp = right_inverse(tv_layout_mma_warp).with_shape(mk_shape_warp);
|
||||
|
||||
// 3. Repeat the warp layout NumAtoms times along K mode to get wider vectorization
|
||||
auto mk_layout_mma_trgt = blocked_product(mk_layout_mma_warp, atom_layout);
|
||||
|
||||
// 4. Compose with a contiguous layout of values in each thread (required for smem vectorization)
|
||||
auto val_to_offset = logical_product(val_layout, size<1>(tv_layout_mma) / size(val_layout) * size(atom_layout));
|
||||
auto thr_to_offset = make_layout(size<0>(tv_layout_mma_warp));
|
||||
auto tv_to_offset = select<1,0>(logical_product(val_to_offset, thr_to_offset));
|
||||
auto layout_atom = composition(tv_to_offset, mk_layout_mma_trgt);
|
||||
|
||||
return layout_atom;
|
||||
}
|
||||
|
||||
template <class TileShape, class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst, class TiledCopy>
|
||||
__global__ void reorder_tensor_kernel(
|
||||
cute::Tensor<EngineSrc, LayoutSrc> S,
|
||||
cute::Tensor<EngineDst, LayoutDst> D,
|
||||
TiledCopy tiled_copy)
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
using T = typename EngineDst::value_type;
|
||||
|
||||
Tensor gS = local_tile(S, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z));
|
||||
Tensor gD = local_tile(D, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z));
|
||||
|
||||
auto thread_copy = tiled_copy.get_slice(threadIdx.x);
|
||||
Tensor tS = thread_copy.partition_S(gS);
|
||||
Tensor tD = thread_copy.partition_D(gD);
|
||||
|
||||
copy(tiled_copy, tS, tD);
|
||||
}
|
||||
|
||||
template <class EngineSrc, class LayoutSrc, class EngineDst, class LayoutDst>
|
||||
void reorder_tensor(
|
||||
cute::Tensor<EngineSrc, LayoutSrc> S,
|
||||
cute::Tensor<EngineDst, LayoutDst> D)
|
||||
{
|
||||
using namespace cute;
|
||||
|
||||
using T = typename EngineDst::value_type;
|
||||
static_assert(is_same_v<remove_const_t<typename EngineSrc::value_type>, T>, "Type mismatch");
|
||||
|
||||
// Construct a value layout that assigns at least 8 bits of contiguous elements in destination tensor to a thread
|
||||
// This avoids a race condition when writing out subbyte types (e.g. int4b_t).
|
||||
auto has_major_mode = [](auto s) {
|
||||
return any_of(flatten(s), [](auto a){ return is_constant<1, decltype(a)>{}; });
|
||||
};
|
||||
static_assert(has_major_mode(stride<0>(LayoutDst{})) ^ has_major_mode(stride<1>(LayoutDst{})),
|
||||
"Could not find stride-1 mode in destination layout");
|
||||
constexpr int N = shape_div(Int<8>{}, sizeof_bits<T>{});
|
||||
auto val_layout = conditional_return<has_major_mode(stride<0>(LayoutDst{}))>(
|
||||
make_layout(make_shape(Int<N>{}, Int<1>{}), GenColMajor{}),
|
||||
make_layout(make_shape(Int<1>{}, Int<N>{}), GenRowMajor{}));
|
||||
|
||||
// Make a tiled copy with a simple row-major thread order and above layout
|
||||
int constexpr NumThreads = 128;
|
||||
auto const thr_layout = make_layout(make_shape(Int<1>{}, Int<NumThreads>{}));
|
||||
auto tiled_copy = make_tiled_copy(Copy_Atom<DefaultCopy, T>{}, thr_layout, val_layout);
|
||||
|
||||
// Assign a group of 16 rows to a threadblock; this matches the shuffle atom size for Hopper
|
||||
using TileShape = Shape<_16>;
|
||||
auto tiled_D = group_modes<3,rank_v<LayoutDst>>(tiled_divide(D, TileShape{}));
|
||||
dim3 blocks{unsigned(size<1>(tiled_D)), 1u, unsigned(size<3>(tiled_D))};
|
||||
|
||||
reorder_tensor_kernel<TileShape><<<blocks, NumThreads>>>(S, D, tiled_copy);
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
// Out-of-place version: reorders `src` (described by layout_src) into the
// separate destination buffer `dst` (described by layout_dst).
// NOTE(review): original comment said "In-place version", but this overload
// takes distinct src/dst pointers. Assumes both are device pointers covering
// size(layout_src) elements and do not overlap — confirm with callers.
template <class T, class LayoutSrc, class LayoutDst>
void reorder_tensor(
  T const* src,
  LayoutSrc const& layout_src,
  T * dst,
  LayoutDst const& layout_dst)
{
  using namespace cute;
  // Wrap the raw global-memory pointers in cute tensors and dispatch to the
  // tensor-based overload, which performs the device-side tiled copy.
  reorder_tensor(make_tensor(make_gmem_ptr<T>(src), layout_src),
                 make_tensor(make_gmem_ptr<T>(dst), layout_dst));
}
|
||||
|
||||
// In-place version: reorders `data` from layout_src to layout_dst using a
// temporary device buffer, then copies the result back over `data`.
template <class T, class LayoutSrc, class LayoutDst>
void reorder_tensor(
  T * data,
  LayoutSrc const& layout_src,
  LayoutDst const& layout_dst)
{
  // Stage the reordered result in scratch device memory, then overwrite the input.
  auto const num_elements = static_cast<size_t>(cute::size(layout_src));
  cutlass::DeviceAllocation<T> staging(cute::size(layout_src));
  reorder_tensor(data, layout_src, staging.get(), layout_dst);
  cutlass::device_memory::copy_device_to_device(data, staging.get(), num_elements);
}
|
||||
|
||||
#undef CUDA_CHECK
|
||||
|
||||
} // namespace cutlass
|
||||
@ -388,7 +388,7 @@ void compute_1d_scaling_factor_and_quantized_output(
|
||||
absolute_value_op<ElementCompute> abs_op;
|
||||
maximum_with_nan_propogation<ElementCompute> max_op;
|
||||
|
||||
if constexpr (cute::is_constant<1, decltype(cute::stride<0,1>(tensor_SfD))>::value) {
|
||||
if constexpr (cute::is_constant<1, decltype(cute::stride<0,0,1>(tensor_SfD))>::value) {
|
||||
// MN major output
|
||||
int const NumVecPerBlock = ceil_div(kBlockM, kVectorSize);
|
||||
// Col major output
|
||||
@ -705,7 +705,7 @@ void gett_epilogue(
|
||||
if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) {
|
||||
// Convert every type to ElementCompute first, do compute, convert to output type, write it out
|
||||
ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]);
|
||||
// per-row alpha
|
||||
// vector alpha
|
||||
if (raw_pointer_cast(epilogue_params.Valpha.data())) {
|
||||
converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b, n + n_b, l));
|
||||
converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b));
|
||||
@ -719,7 +719,7 @@ void gett_epilogue(
|
||||
|
||||
if (raw_pointer_cast(epilogue_params.C.data())) {
|
||||
ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l));
|
||||
// per-row beta
|
||||
// vector beta
|
||||
if (epilogue_params.Vbeta.data()) {
|
||||
converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b, n + n_b, l));
|
||||
converted_beta = mul(converted_beta, converted_scale_c);
|
||||
|
||||
Reference in New Issue
Block a user